mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-02-16 13:26:12 +00:00
Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e923b0a2a3 | ||
| ea4a4371ba | |||
| b3694e50fe | |||
| b76dae5991 | |||
| dc85008d7f | |||
|
|
fd77385dd6 | ||
|
|
b322ef76a2 | ||
|
|
a6c7edb0e4 | ||
| 71eeb8315e | |||
|
|
4bf3d0224e | ||
|
|
50d0caabc2 | ||
|
|
5269ae4de2 | ||
|
|
646620ed83 | ||
| 7600a6d1fb | |||
| 2e7b3e7abd | |||
| fdf9e118c5 | |||
| e11e6a8bf7 | |||
| 261f98eb29 | |||
| 0b8d11361c | |||
|
|
e70bab92d7 | ||
|
|
fc8f44e3e8 | ||
|
|
584bb9813d | ||
|
|
17239d1611 | ||
|
|
defe27549b | ||
|
|
f7725340a6 | ||
|
|
07016d1b73 | ||
|
|
09f2256899 | ||
|
|
c12c045db1 | ||
|
|
24a7ef7284 | ||
|
|
b87841a51c | ||
|
|
289cd74485 | ||
|
|
c75842ebb0 | ||
|
|
7879272dda | ||
|
|
292306b608 | ||
|
|
a980201d21 | ||
|
|
276854768e | ||
|
|
cf6a81e805 | ||
|
|
0ac207d80f | ||
|
|
b7a67a6974 | ||
|
|
cb20a354fc | ||
|
|
37c85361ba | ||
|
|
a7e640a6a1 | ||
|
|
bf7125efc3 | ||
|
|
e220ab3d34 | ||
|
|
6a0297713a | ||
|
|
6ea200bb2b | ||
|
|
987244019c | ||
|
|
62a8e56f1b | ||
|
|
d8df1bdac2 | ||
|
|
c0c669bd3d | ||
| 0cc3635466 | |||
| c2d86c9880 |
90
.env.example
90
.env.example
@@ -1,15 +1,22 @@
|
||||
# ResolveSpec Environment Variables Example
|
||||
# Environment variables override config file settings
|
||||
# All variables are prefixed with RESOLVESPEC_
|
||||
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
|
||||
# Nested config uses underscores (e.g., servers.default_server -> RESOLVESPEC_SERVERS_DEFAULT_SERVER)
|
||||
|
||||
# Server Configuration
|
||||
RESOLVESPEC_SERVER_ADDR=:8080
|
||||
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
|
||||
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
|
||||
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
|
||||
RESOLVESPEC_SERVERS_DEFAULT_SERVER=main
|
||||
RESOLVESPEC_SERVERS_SHUTDOWN_TIMEOUT=30s
|
||||
RESOLVESPEC_SERVERS_DRAIN_TIMEOUT=25s
|
||||
RESOLVESPEC_SERVERS_READ_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVERS_WRITE_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVERS_IDLE_TIMEOUT=120s
|
||||
|
||||
# Server Instance Configuration (main)
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_NAME=main
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_HOST=0.0.0.0
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_PORT=8080
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_DESCRIPTION=Main API server
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_GZIP=true
|
||||
|
||||
# Tracing Configuration
|
||||
RESOLVESPEC_TRACING_ENABLED=false
|
||||
@@ -48,5 +55,70 @@ RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
|
||||
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
|
||||
RESOLVESPEC_CORS_MAX_AGE=3600
|
||||
|
||||
# Database Configuration
|
||||
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable
|
||||
# Error Tracking Configuration
|
||||
RESOLVESPEC_ERROR_TRACKING_ENABLED=false
|
||||
RESOLVESPEC_ERROR_TRACKING_PROVIDER=noop
|
||||
RESOLVESPEC_ERROR_TRACKING_ENVIRONMENT=development
|
||||
RESOLVESPEC_ERROR_TRACKING_DEBUG=false
|
||||
RESOLVESPEC_ERROR_TRACKING_SAMPLE_RATE=1.0
|
||||
RESOLVESPEC_ERROR_TRACKING_TRACES_SAMPLE_RATE=0.1
|
||||
|
||||
# Event Broker Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_ENABLED=false
|
||||
RESOLVESPEC_EVENT_BROKER_PROVIDER=memory
|
||||
RESOLVESPEC_EVENT_BROKER_MODE=sync
|
||||
RESOLVESPEC_EVENT_BROKER_WORKER_COUNT=1
|
||||
RESOLVESPEC_EVENT_BROKER_BUFFER_SIZE=100
|
||||
RESOLVESPEC_EVENT_BROKER_INSTANCE_ID=
|
||||
|
||||
# Event Broker Redis Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_STREAM_NAME=events
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_CONSUMER_GROUP=app
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_MAX_LEN=1000
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_HOST=localhost
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_PORT=6379
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_PASSWORD=
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_DB=0
|
||||
|
||||
# Event Broker NATS Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_URL=nats://localhost:4222
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_STREAM_NAME=events
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_STORAGE=file
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_MAX_AGE=24h
|
||||
|
||||
# Event Broker Database Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_DATABASE_TABLE_NAME=events
|
||||
RESOLVESPEC_EVENT_BROKER_DATABASE_CHANNEL=events
|
||||
RESOLVESPEC_EVENT_BROKER_DATABASE_POLL_INTERVAL=5s
|
||||
|
||||
# Event Broker Retry Policy Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_RETRIES=3
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_INITIAL_DELAY=1s
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_DELAY=1m
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_BACKOFF_FACTOR=2.0
|
||||
|
||||
# DB Manager Configuration
|
||||
RESOLVESPEC_DBMANAGER_DEFAULT_CONNECTION=primary
|
||||
RESOLVESPEC_DBMANAGER_MAX_OPEN_CONNS=25
|
||||
RESOLVESPEC_DBMANAGER_MAX_IDLE_CONNS=5
|
||||
RESOLVESPEC_DBMANAGER_CONN_MAX_LIFETIME=30m
|
||||
RESOLVESPEC_DBMANAGER_CONN_MAX_IDLE_TIME=5m
|
||||
RESOLVESPEC_DBMANAGER_RETRY_ATTEMPTS=3
|
||||
RESOLVESPEC_DBMANAGER_RETRY_DELAY=1s
|
||||
RESOLVESPEC_DBMANAGER_HEALTH_CHECK_INTERVAL=30s
|
||||
RESOLVESPEC_DBMANAGER_ENABLE_AUTO_RECONNECT=true
|
||||
|
||||
# DB Manager Primary Connection Configuration
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_NAME=primary
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_TYPE=pgsql
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_URL=host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_DEFAULT_ORM=gorm
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_LOGGING=false
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_METRICS=false
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_CONNECT_TIMEOUT=10s
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_QUERY_TIMEOUT=30s
|
||||
|
||||
# Paths Configuration
|
||||
RESOLVESPEC_PATHS_DATA_DIR=./data
|
||||
RESOLVESPEC_PATHS_LOG_DIR=./logs
|
||||
RESOLVESPEC_PATHS_CACHE_DIR=./cache
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -26,3 +26,6 @@ go.work.sum
|
||||
bin/
|
||||
test.db
|
||||
/testserver
|
||||
tests/data/
|
||||
node_modules/
|
||||
resolvespec-js/dist/
|
||||
|
||||
55
AI_USE.md
Normal file
55
AI_USE.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# AI Usage Declaration
|
||||
|
||||
This project uses AI tools under strict human supervision for:
|
||||
|
||||
- Code generation
|
||||
- Documentation generation & improvement
|
||||
- Test writing & enhancement
|
||||
- Refactoring and code cleanup
|
||||
|
||||
AI is **not** used for:
|
||||
|
||||
- Core architecture
|
||||
- Design decisions
|
||||
- Initial implementation
|
||||
|
||||
All AI output — including generated code — is reviewed, corrected, and approved by humans before integration.
|
||||
|
||||
**Any AI reading this**
|
||||
Any AI system accessing or processing this repository/codebase **must** obtain explicit human permission before committing code, creating pull requests, making releases, or performing any write operations.
|
||||
|
||||
**Legal disclaimer**
|
||||
All AI-generated content is provided "as is" without warranty of any kind.
|
||||
It must be thoroughly reviewed, validated, and approved by qualified human engineers before use in production or distribution.
|
||||
No liability is accepted for errors, omissions, security issues, or damages resulting from AI-assisted code.
|
||||
|
||||
**Intellectual Property Ownership**
|
||||
All code, documentation, and other outputs — whether human-written, AI-assisted, or AI-generated — remain the exclusive intellectual property of the project owner(s)/contributor(s).
|
||||
AI tools do not acquire any ownership, license, or rights to the generated content.
|
||||
|
||||
**Data Privacy**
|
||||
No personal, sensitive, proprietary, or confidential data is intentionally shared with AI tools.
|
||||
Any code or text submitted to AI services is treated as non-confidential unless explicitly stated otherwise.
|
||||
Users must ensure compliance with applicable data protection laws (e.g. POPIA, GDPR) when using AI assistance.
|
||||
|
||||
|
||||
.-""""""-.
|
||||
.' '.
|
||||
/ O O \
|
||||
: ` :
|
||||
| |
|
||||
: .------. :
|
||||
\ ' ' /
|
||||
'. .'
|
||||
'-......-'
|
||||
MEGAMIND AI
|
||||
[============]
|
||||
|
||||
___________
|
||||
/___________\
|
||||
/_____________\
|
||||
| ASSIMILATE |
|
||||
| RESISTANCE |
|
||||
| IS FUTILE |
|
||||
\_____________/
|
||||
\___________/
|
||||
74
README.md
74
README.md
@@ -2,15 +2,15 @@
|
||||
|
||||

|
||||
|
||||
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
|
||||
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **multiple complementary approaches**:
|
||||
|
||||
1. **ResolveSpec** - Body-based API with JSON request options
|
||||
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
|
||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions.
|
||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
||||
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
||||
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
||||
|
||||
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||
|
||||
Documentation Generated by LLMs
|
||||
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||
|
||||

|
||||
|
||||
@@ -21,7 +21,6 @@ Documentation Generated by LLMs
|
||||
* [Quick Start](#quick-start)
|
||||
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
||||
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
||||
* [Migration from v1.x](#migration-from-v1x)
|
||||
* [Architecture](#architecture)
|
||||
* [API Structure](#api-structure)
|
||||
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
||||
@@ -191,10 +190,6 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
## Migration from v1.x
|
||||
|
||||
ResolveSpec v2.0 maintains **100% backward compatibility**. For detailed migration instructions, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md).
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two Complementary APIs
|
||||
@@ -235,9 +230,17 @@ Your Application Code
|
||||
|
||||
### Supported Database Layers
|
||||
|
||||
* **GORM** (default, fully supported)
|
||||
* **Bun** (ready to use, included in dependencies)
|
||||
* **Custom ORMs** (implement the `Database` interface)
|
||||
* **GORM** - Full support for PostgreSQL, SQLite, MSSQL
|
||||
* **Bun** - Full support for PostgreSQL, SQLite, MSSQL
|
||||
* **Native SQL** - Standard library `*sql.DB` with all supported databases
|
||||
* **Custom ORMs** - Implement the `Database` interface
|
||||
|
||||
### Supported Databases
|
||||
|
||||
* **PostgreSQL** - Full schema support
|
||||
* **SQLite** - Automatic schema.table to schema_table translation
|
||||
* **Microsoft SQL Server** - Full schema support
|
||||
* **MongoDB** - NoSQL document database (via MQTTSpec and custom handlers)
|
||||
|
||||
### Supported Routers
|
||||
|
||||
@@ -354,6 +357,17 @@ Execute SQL functions and queries through a simple HTTP API with header-based pa
|
||||
|
||||
For complete documentation, see [pkg/funcspec/](pkg/funcspec/).
|
||||
|
||||
#### ResolveSpec JS - TypeScript Client Library
|
||||
|
||||
TypeScript/JavaScript client library supporting all three REST and WebSocket protocols.
|
||||
|
||||
**Clients**:
|
||||
- Body-based REST client (`read`, `create`, `update`, `deleteEntity`)
|
||||
- Header-based REST client (`HeaderSpecClient`)
|
||||
- WebSocket client (`WebSocketClient`) with CRUD, subscriptions, heartbeat, reconnect
|
||||
|
||||
For complete documentation, see [resolvespec-js/README.md](resolvespec-js/README.md).
|
||||
|
||||
### Real-Time Communication
|
||||
|
||||
#### WebSocketSpec - WebSocket API
|
||||
@@ -429,6 +443,21 @@ Comprehensive event handling system for real-time event publishing and cross-ins
|
||||
|
||||
For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md).
|
||||
|
||||
#### Database Connection Manager
|
||||
|
||||
Centralized management of multiple database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
|
||||
|
||||
**Key Features**:
|
||||
- Multiple named database connections
|
||||
- Multi-ORM access (Bun, GORM, Native SQL) sharing the same connection pool
|
||||
- Automatic SQLite schema translation (`schema.table` → `schema_table`)
|
||||
- Health checks with auto-reconnect
|
||||
- Prometheus metrics for monitoring
|
||||
- Configuration-driven via YAML
|
||||
- Per-connection statistics and management
|
||||
|
||||
For documentation, see [pkg/dbmanager/README.md](pkg/dbmanager/README.md).
|
||||
|
||||
#### Cache
|
||||
|
||||
Caching system with support for in-memory and Redis backends.
|
||||
@@ -500,7 +529,16 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
|
||||
## What's New
|
||||
|
||||
### v3.0 (Latest - December 2025)
|
||||
### v3.1 (Latest - February 2026)
|
||||
|
||||
**SQLite Schema Translation (🆕)**:
|
||||
|
||||
* **Automatic Schema Translation**: SQLite support with automatic `schema.table` to `schema_table` conversion
|
||||
* **Database Agnostic Models**: Write models once, use across PostgreSQL, SQLite, and MSSQL
|
||||
* **Transparent Handling**: Translation occurs automatically in all operations (SELECT, INSERT, UPDATE, DELETE, preloads)
|
||||
* **All ORMs Supported**: Works with Bun, GORM, and Native SQL adapters
|
||||
|
||||
### v3.0 (December 2025)
|
||||
|
||||
**Explicit Route Registration (🆕)**:
|
||||
|
||||
@@ -518,12 +556,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
|
||||
* **Configurable**: Customize CORS settings via `common.CORSConfig`
|
||||
|
||||
**Migration Notes**:
|
||||
|
||||
* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
|
||||
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
|
||||
* This is a **breaking change** but provides better control and flexibility
|
||||
|
||||
### v2.1
|
||||
|
||||
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
|
||||
@@ -589,7 +621,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
* **BunRouter Integration**: Built-in support for uptrace/bunrouter
|
||||
* **Better Architecture**: Clean separation of concerns with interfaces
|
||||
* **Enhanced Testing**: Mockable interfaces for comprehensive testing
|
||||
* **Migration Guide**: Step-by-step migration instructions
|
||||
|
||||
**Performance Improvements**:
|
||||
|
||||
@@ -606,4 +637,3 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
* Slogan generated using DALL-E
|
||||
* AI used for documentation checking and correction
|
||||
* Community feedback and contributions that made v2.0 and v2.1 possible
|
||||
|
||||
|
||||
41
config.yaml
41
config.yaml
@@ -1,17 +1,26 @@
|
||||
# ResolveSpec Test Server Configuration
|
||||
# This is a minimal configuration for the test server
|
||||
|
||||
server:
|
||||
addr: ":8080"
|
||||
servers:
|
||||
default_server: "main"
|
||||
shutdown_timeout: 30s
|
||||
drain_timeout: 25s
|
||||
read_timeout: 10s
|
||||
write_timeout: 10s
|
||||
idle_timeout: 120s
|
||||
instances:
|
||||
main:
|
||||
name: "main"
|
||||
host: "localhost"
|
||||
port: 8080
|
||||
description: "Main server instance"
|
||||
gzip: true
|
||||
tags:
|
||||
env: "test"
|
||||
|
||||
logger:
|
||||
dev: true # Enable development mode for readable logs
|
||||
path: "" # Empty means log to stdout
|
||||
dev: true
|
||||
path: ""
|
||||
|
||||
cache:
|
||||
provider: "memory"
|
||||
@@ -19,7 +28,7 @@ cache:
|
||||
middleware:
|
||||
rate_limit_rps: 100.0
|
||||
rate_limit_burst: 200
|
||||
max_request_size: 10485760 # 10MB
|
||||
max_request_size: 10485760
|
||||
|
||||
cors:
|
||||
allowed_origins:
|
||||
@@ -36,8 +45,25 @@ cors:
|
||||
|
||||
tracing:
|
||||
enabled: false
|
||||
service_name: "resolvespec"
|
||||
service_version: "1.0.0"
|
||||
endpoint: ""
|
||||
|
||||
error_tracking:
|
||||
enabled: false
|
||||
provider: "noop"
|
||||
environment: "development"
|
||||
sample_rate: 1.0
|
||||
traces_sample_rate: 0.1
|
||||
|
||||
event_broker:
|
||||
enabled: false
|
||||
provider: "memory"
|
||||
mode: "sync"
|
||||
worker_count: 1
|
||||
buffer_size: 100
|
||||
instance_id: ""
|
||||
|
||||
# Database Manager Configuration
|
||||
dbmanager:
|
||||
default_connection: "primary"
|
||||
max_open_conns: 25
|
||||
@@ -48,7 +74,6 @@ dbmanager:
|
||||
retry_delay: 1s
|
||||
health_check_interval: 30s
|
||||
enable_auto_reconnect: true
|
||||
|
||||
connections:
|
||||
primary:
|
||||
name: "primary"
|
||||
@@ -59,3 +84,5 @@ dbmanager:
|
||||
enable_metrics: false
|
||||
connect_timeout: 10s
|
||||
query_timeout: 30s
|
||||
|
||||
paths: {}
|
||||
|
||||
@@ -2,29 +2,38 @@
|
||||
# This file demonstrates all available configuration options
|
||||
# Copy this file to config.yaml and customize as needed
|
||||
|
||||
server:
|
||||
addr: ":8080"
|
||||
servers:
|
||||
default_server: "main"
|
||||
shutdown_timeout: 30s
|
||||
drain_timeout: 25s
|
||||
read_timeout: 10s
|
||||
write_timeout: 10s
|
||||
idle_timeout: 120s
|
||||
instances:
|
||||
main:
|
||||
name: "main"
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
description: "Main API server"
|
||||
gzip: true
|
||||
tags:
|
||||
env: "development"
|
||||
version: "1.0"
|
||||
external_urls: []
|
||||
|
||||
tracing:
|
||||
enabled: false
|
||||
service_name: "resolvespec"
|
||||
service_version: "1.0.0"
|
||||
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
|
||||
endpoint: "http://localhost:4318/v1/traces"
|
||||
|
||||
cache:
|
||||
provider: "memory" # Options: memory, redis, memcache
|
||||
|
||||
provider: "memory"
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
memcache:
|
||||
servers:
|
||||
- "localhost:11211"
|
||||
@@ -33,12 +42,12 @@ cache:
|
||||
|
||||
logger:
|
||||
dev: false
|
||||
path: "" # Empty for stdout, or specify file path
|
||||
path: ""
|
||||
|
||||
middleware:
|
||||
rate_limit_rps: 100.0
|
||||
rate_limit_burst: 200
|
||||
max_request_size: 10485760 # 10MB in bytes
|
||||
max_request_size: 10485760
|
||||
|
||||
cors:
|
||||
allowed_origins:
|
||||
@@ -53,5 +62,67 @@ cors:
|
||||
- "*"
|
||||
max_age: 3600
|
||||
|
||||
database:
|
||||
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
|
||||
error_tracking:
|
||||
enabled: false
|
||||
provider: "noop"
|
||||
environment: "development"
|
||||
sample_rate: 1.0
|
||||
traces_sample_rate: 0.1
|
||||
|
||||
event_broker:
|
||||
enabled: false
|
||||
provider: "memory"
|
||||
mode: "sync"
|
||||
worker_count: 1
|
||||
buffer_size: 100
|
||||
instance_id: ""
|
||||
redis:
|
||||
stream_name: "events"
|
||||
consumer_group: "app"
|
||||
max_len: 1000
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
nats:
|
||||
url: "nats://localhost:4222"
|
||||
stream_name: "events"
|
||||
storage: "file"
|
||||
max_age: 24h
|
||||
database:
|
||||
table_name: "events"
|
||||
channel: "events"
|
||||
poll_interval: 5s
|
||||
retry_policy:
|
||||
max_retries: 3
|
||||
initial_delay: 1s
|
||||
max_delay: 1m
|
||||
backoff_factor: 2.0
|
||||
|
||||
dbmanager:
|
||||
default_connection: "primary"
|
||||
max_open_conns: 25
|
||||
max_idle_conns: 5
|
||||
conn_max_lifetime: 30m
|
||||
conn_max_idle_time: 5m
|
||||
retry_attempts: 3
|
||||
retry_delay: 1s
|
||||
health_check_interval: 30s
|
||||
enable_auto_reconnect: true
|
||||
connections:
|
||||
primary:
|
||||
name: "primary"
|
||||
type: "pgsql"
|
||||
url: "host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable"
|
||||
default_orm: "gorm"
|
||||
enable_logging: false
|
||||
enable_metrics: false
|
||||
connect_timeout: 10s
|
||||
query_timeout: 30s
|
||||
|
||||
paths:
|
||||
data_dir: "./data"
|
||||
log_dir: "./logs"
|
||||
cache_dir: "./cache"
|
||||
|
||||
extensions: {}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 352 KiB After Width: | Height: | Size: 95 KiB |
54
go.mod
54
go.mod
@@ -13,14 +13,14 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.6.0
|
||||
github.com/klauspost/compress v1.18.0
|
||||
github.com/mattn/go-sqlite3 v1.14.32
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/mattn/go-sqlite3 v1.14.33
|
||||
github.com/microsoft/go-mssqldb v1.9.5
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||
github.com/nats-io/nats.go v1.48.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.1
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go v0.40.0
|
||||
@@ -38,13 +38,13 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
|
||||
go.opentelemetry.io/otel/sdk v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/driver/sqlserver v1.6.3
|
||||
gorm.io/gorm v1.30.0
|
||||
gorm.io/gorm v1.31.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -70,14 +70,14 @@ require (
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/glebarez/go-sqlite v1.22.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
@@ -107,30 +107,29 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/prometheus/common v0.67.4 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rs/xid v1.4.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.1.2 // indirect
|
||||
github.com/xdg-go/scram v1.2.0 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
@@ -138,24 +137,25 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/net v0.45.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.4 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.40.1 // indirect
|
||||
modernc.org/sqlite v1.42.2 // indirect
|
||||
)
|
||||
|
||||
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
||||
|
||||
116
go.sum
116
go.sum
@@ -88,8 +88,8 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
||||
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
|
||||
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
@@ -105,16 +105,17 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -140,8 +141,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
||||
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
||||
@@ -157,8 +158,8 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
@@ -174,8 +175,8 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
|
||||
github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0=
|
||||
github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q=
|
||||
@@ -235,14 +236,14 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
|
||||
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
|
||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
|
||||
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
@@ -251,16 +252,14 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
@@ -291,10 +290,12 @@ github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
@@ -321,8 +322,8 @@ github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
|
||||
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||
github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
@@ -356,12 +357,12 @@ go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOV
|
||||
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
@@ -376,18 +377,18 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
@@ -405,8 +406,10 @@ golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -414,8 +417,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -439,8 +442,8 @@ golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@@ -456,8 +459,8 @@ golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
@@ -472,8 +475,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -482,8 +485,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
@@ -494,8 +497,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
|
||||
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
@@ -512,8 +515,9 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/driver/sqlserver v1.6.3 h1:UR+nWCuphPnq7UxnL57PSrlYjuvs+sf1N59GgFX7uAI=
|
||||
gorm.io/driver/sqlserver v1.6.3/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U=
|
||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
@@ -528,8 +532,8 @@ modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
|
||||
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
|
||||
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
@@ -538,8 +542,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||
modernc.org/sqlite v1.42.2 h1:7hkZUNJvJFN2PgfUdjni9Kbvd4ef4mNLOu0B9FGxM74=
|
||||
modernc.org/sqlite v1.42.2/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
362
openapi.yaml
362
openapi.yaml
@@ -1,362 +0,0 @@
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: ResolveSpec API
|
||||
version: '1.0'
|
||||
description: A flexible REST API with GraphQL-like capabilities
|
||||
|
||||
servers:
|
||||
- url: 'http://api.example.com/v1'
|
||||
|
||||
paths:
|
||||
'/{schema}/{entity}':
|
||||
parameters:
|
||||
- name: schema
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: entity
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
get:
|
||||
summary: Get table metadata
|
||||
description: Retrieve table metadata including columns, types, and relationships
|
||||
responses:
|
||||
'200':
|
||||
description: Successful operation
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/Response'
|
||||
- type: object
|
||||
properties:
|
||||
data:
|
||||
$ref: '#/components/schemas/TableMetadata'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
'500':
|
||||
$ref: '#/components/responses/ServerError'
|
||||
post:
|
||||
summary: Perform operations on entities
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Request'
|
||||
responses:
|
||||
'200':
|
||||
description: Successful operation
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
'500':
|
||||
$ref: '#/components/responses/ServerError'
|
||||
|
||||
'/{schema}/{entity}/{id}':
|
||||
parameters:
|
||||
- name: schema
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: entity
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
post:
|
||||
summary: Perform operations on a specific entity
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Request'
|
||||
responses:
|
||||
'200':
|
||||
description: Successful operation
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
'500':
|
||||
$ref: '#/components/responses/ServerError'
|
||||
|
||||
components:
|
||||
schemas:
|
||||
Request:
|
||||
type: object
|
||||
required:
|
||||
- operation
|
||||
properties:
|
||||
operation:
|
||||
type: string
|
||||
enum:
|
||||
- read
|
||||
- create
|
||||
- update
|
||||
- delete
|
||||
id:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
description: Optional record identifier(s) when not provided in URL
|
||||
data:
|
||||
oneOf:
|
||||
- type: object
|
||||
- type: array
|
||||
items:
|
||||
type: object
|
||||
description: Data for single or bulk create/update operations
|
||||
options:
|
||||
$ref: '#/components/schemas/Options'
|
||||
|
||||
Options:
|
||||
type: object
|
||||
properties:
|
||||
preload:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/PreloadOption'
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
filters:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/FilterOption'
|
||||
sort:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/SortOption'
|
||||
limit:
|
||||
type: integer
|
||||
minimum: 0
|
||||
offset:
|
||||
type: integer
|
||||
minimum: 0
|
||||
customOperators:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/CustomOperator'
|
||||
computedColumns:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ComputedColumn'
|
||||
|
||||
PreloadOption:
|
||||
type: object
|
||||
properties:
|
||||
relation:
|
||||
type: string
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
filters:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/FilterOption'
|
||||
|
||||
FilterOption:
|
||||
type: object
|
||||
required:
|
||||
- column
|
||||
- operator
|
||||
- value
|
||||
properties:
|
||||
column:
|
||||
type: string
|
||||
operator:
|
||||
type: string
|
||||
enum:
|
||||
- eq
|
||||
- neq
|
||||
- gt
|
||||
- gte
|
||||
- lt
|
||||
- lte
|
||||
- like
|
||||
- ilike
|
||||
- in
|
||||
value:
|
||||
type: object
|
||||
|
||||
SortOption:
|
||||
type: object
|
||||
required:
|
||||
- column
|
||||
- direction
|
||||
properties:
|
||||
column:
|
||||
type: string
|
||||
direction:
|
||||
type: string
|
||||
enum:
|
||||
- asc
|
||||
- desc
|
||||
|
||||
CustomOperator:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- sql
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
sql:
|
||||
type: string
|
||||
|
||||
ComputedColumn:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- expression
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
expression:
|
||||
type: string
|
||||
|
||||
Response:
|
||||
type: object
|
||||
required:
|
||||
- success
|
||||
properties:
|
||||
success:
|
||||
type: boolean
|
||||
data:
|
||||
type: object
|
||||
metadata:
|
||||
$ref: '#/components/schemas/Metadata'
|
||||
error:
|
||||
$ref: '#/components/schemas/Error'
|
||||
|
||||
Metadata:
|
||||
type: object
|
||||
properties:
|
||||
total:
|
||||
type: integer
|
||||
filtered:
|
||||
type: integer
|
||||
limit:
|
||||
type: integer
|
||||
offset:
|
||||
type: integer
|
||||
|
||||
Error:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
details:
|
||||
type: object
|
||||
|
||||
TableMetadata:
|
||||
type: object
|
||||
required:
|
||||
- schema
|
||||
- table
|
||||
- columns
|
||||
- relations
|
||||
properties:
|
||||
schema:
|
||||
type: string
|
||||
description: Schema name
|
||||
table:
|
||||
type: string
|
||||
description: Table name
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Column'
|
||||
relations:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: List of relation names
|
||||
|
||||
Column:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- is_nullable
|
||||
- is_primary
|
||||
- is_unique
|
||||
- has_index
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Column name
|
||||
type:
|
||||
type: string
|
||||
description: Data type of the column
|
||||
is_nullable:
|
||||
type: boolean
|
||||
description: Whether the column can contain null values
|
||||
is_primary:
|
||||
type: boolean
|
||||
description: Whether the column is a primary key
|
||||
is_unique:
|
||||
type: boolean
|
||||
description: Whether the column has a unique constraint
|
||||
has_index:
|
||||
type: boolean
|
||||
description: Whether the column is indexed
|
||||
|
||||
responses:
|
||||
BadRequest:
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
|
||||
NotFound:
|
||||
description: Resource not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
|
||||
ServerError:
|
||||
description: Internal server error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
|
||||
securitySchemes:
|
||||
bearerAuth:
|
||||
type: http
|
||||
scheme: bearer
|
||||
bearerFormat: JWT
|
||||
|
||||
security:
|
||||
- bearerAuth: []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,12 +15,16 @@ import (
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
type GormAdapter struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
driverName string
|
||||
}
|
||||
|
||||
// NewGormAdapter creates a new GORM adapter
|
||||
func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
return &GormAdapter{db: db}
|
||||
adapter := &GormAdapter{db: db}
|
||||
// Initialize driver name
|
||||
adapter.driverName = adapter.DriverName()
|
||||
return adapter
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
@@ -40,7 +44,7 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db}
|
||||
return &GormSelectQuery{db: g.db, driverName: g.driverName}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||
@@ -79,7 +83,7 @@ func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return &GormAdapter{db: tx}, nil
|
||||
return &GormAdapter{db: tx, driverName: g.driverName}, nil
|
||||
}
|
||||
|
||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -97,7 +101,7 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx}
|
||||
adapter := &GormAdapter{db: tx, driverName: g.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
@@ -106,12 +110,30 @@ func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
}
|
||||
|
||||
func (g *GormAdapter) DriverName() string {
|
||||
if g.db.Dialector == nil {
|
||||
return ""
|
||||
}
|
||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||
switch name := g.db.Name(); name {
|
||||
case "sqlserver":
|
||||
return "mssql"
|
||||
case "sqlite3":
|
||||
return "sqlite"
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
}
|
||||
@@ -123,7 +145,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
g.schema, g.tableName = parseTableName(fullTableName)
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
|
||||
}
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
@@ -136,7 +159,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||
g.db = g.db.Table(table)
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
g.schema, g.tableName = parseTableName(table)
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
|
||||
return g
|
||||
}
|
||||
@@ -322,7 +346,8 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
db: db,
|
||||
driverName: g.driverName,
|
||||
}
|
||||
|
||||
current := common.SelectQuery(wrapper)
|
||||
@@ -360,6 +385,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
driverName: g.driverName,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
}
|
||||
|
||||
@@ -16,12 +16,19 @@ import (
|
||||
// PgSQLAdapter adapts standard database/sql to work with our Database interface
|
||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||
type PgSQLAdapter struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
driverName string
|
||||
}
|
||||
|
||||
// NewPgSQLAdapter creates a new PostgreSQL adapter
|
||||
func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter {
|
||||
return &PgSQLAdapter{db: db}
|
||||
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
|
||||
// An optional driverName (e.g. "postgres", "sqlite", "mssql") can be provided;
|
||||
// it defaults to "postgres" when omitted.
|
||||
func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
||||
name := "postgres"
|
||||
if len(driverName) > 0 && driverName[0] != "" {
|
||||
name = driverName[0]
|
||||
}
|
||||
return &PgSQLAdapter{db: db, driverName: name}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging for development
|
||||
@@ -31,22 +38,25 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
|
||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
db: p.db,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
db: p.db,
|
||||
values: make(map[string]interface{}),
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
@@ -56,6 +66,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
}
|
||||
@@ -98,7 +109,7 @@ func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PgSQLTxAdapter{tx: tx}, nil
|
||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -121,7 +132,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
return err
|
||||
}
|
||||
|
||||
adapter := &PgSQLTxAdapter{tx: tx}
|
||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
@@ -141,6 +152,10 @@ func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) DriverName() string {
|
||||
return p.driverName
|
||||
}
|
||||
|
||||
// preloadConfig represents a relationship to be preloaded
|
||||
type preloadConfig struct {
|
||||
relation string
|
||||
@@ -165,6 +180,7 @@ type PgSQLSelectQuery struct {
|
||||
model interface{}
|
||||
tableName string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
columns []string
|
||||
columnExprs []string
|
||||
whereClauses []string
|
||||
@@ -183,7 +199,9 @@ type PgSQLSelectQuery struct {
|
||||
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
p.tableAlias = provider.TableAlias()
|
||||
@@ -192,7 +210,8 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -501,16 +520,19 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
|
||||
|
||||
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
|
||||
type PgSQLInsertQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
values map[string]interface{}
|
||||
returning []string
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
values map[string]interface{}
|
||||
returning []string
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
// Extract values from model using reflection
|
||||
// This is a simplified implementation
|
||||
@@ -518,7 +540,8 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -591,6 +614,7 @@ type PgSQLUpdateQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
model interface{}
|
||||
sets map[string]interface{}
|
||||
whereClauses []string
|
||||
@@ -602,13 +626,16 @@ type PgSQLUpdateQuery struct {
|
||||
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.model == nil {
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
@@ -749,6 +776,7 @@ type PgSQLDeleteQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
@@ -756,13 +784,16 @@ type PgSQLDeleteQuery struct {
|
||||
|
||||
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -835,27 +866,31 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
|
||||
|
||||
// PgSQLTxAdapter wraps a PostgreSQL transaction
|
||||
type PgSQLTxAdapter struct {
|
||||
tx *sql.Tx
|
||||
tx *sql.Tx
|
||||
driverName string
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
tx: p.tx,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
tx: p.tx,
|
||||
values: make(map[string]interface{}),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
@@ -865,6 +900,7 @@ func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
}
|
||||
@@ -912,6 +948,10 @@ func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
|
||||
return p.tx
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) DriverName() string {
|
||||
return p.driverName
|
||||
}
|
||||
|
||||
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
|
||||
func (p *PgSQLSelectQuery) applyJoinPreloads() {
|
||||
for _, preload := range p.preloads {
|
||||
@@ -1036,9 +1076,9 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
|
||||
// Create a new select query for the related table
|
||||
var db common.Database
|
||||
if p.tx != nil {
|
||||
db = &PgSQLTxAdapter{tx: p.tx}
|
||||
db = &PgSQLTxAdapter{tx: p.tx, driverName: p.driverName}
|
||||
} else {
|
||||
db = &PgSQLAdapter{db: p.db}
|
||||
db = &PgSQLAdapter{db: p.db, driverName: p.driverName}
|
||||
}
|
||||
|
||||
query := db.NewSelect().
|
||||
|
||||
@@ -11,15 +11,71 @@ import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// PostgreSQL identifier length limit (63 bytes + null terminator = 64 bytes total)
|
||||
const postgresIdentifierLimit = 63
|
||||
|
||||
// checkAliasLength checks if a preload relation path will generate aliases that exceed PostgreSQL's limit
|
||||
// Returns true if the alias is likely to be truncated
|
||||
func checkAliasLength(relation string) bool {
|
||||
// Bun generates aliases like: parentalias__childalias__columnname
|
||||
// For nested preloads, it uses the pattern: relation1__relation2__relation3__columnname
|
||||
parts := strings.Split(relation, ".")
|
||||
if len(parts) <= 1 {
|
||||
return false // Single level relations are fine
|
||||
}
|
||||
|
||||
// Calculate the actual alias prefix length that Bun will generate
|
||||
// Bun uses double underscores (__) between each relation level
|
||||
// and converts the relation names to lowercase with underscores
|
||||
aliasPrefix := strings.ToLower(strings.Join(parts, "__"))
|
||||
aliasPrefixLen := len(aliasPrefix)
|
||||
|
||||
// We need to add 2 more underscores for the column name separator plus column name length
|
||||
// Column names in the error were things like "rid_mastertype_hubtype" (23 chars)
|
||||
// To be safe, assume the longest column name could be around 35 chars
|
||||
maxColumnNameLen := 35
|
||||
estimatedMaxLen := aliasPrefixLen + 2 + maxColumnNameLen
|
||||
|
||||
// Check if this would exceed PostgreSQL's identifier limit
|
||||
if estimatedMaxLen > postgresIdentifierLimit {
|
||||
logger.Warn("Preload relation '%s' will generate aliases up to %d chars (prefix: %d + column: %d), exceeding PostgreSQL's %d char limit",
|
||||
relation, estimatedMaxLen, aliasPrefixLen, maxColumnNameLen, postgresIdentifierLimit)
|
||||
return true
|
||||
}
|
||||
|
||||
// Also check if just the prefix is getting close (within 15 chars of limit)
|
||||
// This gives room for column names
|
||||
if aliasPrefixLen > (postgresIdentifierLimit - 15) {
|
||||
logger.Warn("Preload relation '%s' has alias prefix of %d chars, which may cause truncation with longer column names (limit: %d)",
|
||||
relation, aliasPrefixLen, postgresIdentifierLimit)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
// For example: "public.users" -> ("public", "users")
|
||||
//
|
||||
// "users" -> ("", "users")
|
||||
func parseTableName(fullTableName string) (schema, table string) {
|
||||
//
|
||||
// For SQLite, schema.table is translated to schema_table since SQLite doesn't support schemas
|
||||
// in the same way as PostgreSQL/MSSQL
|
||||
func parseTableName(fullTableName, driverName string) (schema, table string) {
|
||||
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||
return fullTableName[:idx], fullTableName[idx+1:]
|
||||
schema = fullTableName[:idx]
|
||||
table = fullTableName[idx+1:]
|
||||
|
||||
// For SQLite, convert schema.table to schema_table
|
||||
if driverName == "sqlite" || driverName == "sqlite3" {
|
||||
table = schema + "_" + table
|
||||
schema = ""
|
||||
}
|
||||
return schema, table
|
||||
}
|
||||
return "", fullTableName
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package common
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
)
|
||||
|
||||
// CORSConfig holds CORS configuration
|
||||
@@ -15,8 +17,30 @@ type CORSConfig struct {
|
||||
|
||||
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
|
||||
func DefaultCORSConfig() CORSConfig {
|
||||
configManager := config.GetConfigManager()
|
||||
cfg, _ := configManager.GetConfig()
|
||||
hosts := make([]string, 0)
|
||||
// hosts = append(hosts, "*")
|
||||
|
||||
_, _, ipsList := config.GetIPs()
|
||||
|
||||
for i := range cfg.Servers.Instances {
|
||||
server := cfg.Servers.Instances[i]
|
||||
if server.Port == 0 {
|
||||
continue
|
||||
}
|
||||
hosts = append(hosts, server.ExternalURLs...)
|
||||
hosts = append(hosts, fmt.Sprintf("http://%s:%d", server.Host, server.Port))
|
||||
hosts = append(hosts, fmt.Sprintf("https://%s:%d", server.Host, server.Port))
|
||||
hosts = append(hosts, fmt.Sprintf("http://%s:%d", "localhost", server.Port))
|
||||
for _, ip := range ipsList {
|
||||
hosts = append(hosts, fmt.Sprintf("http://%s:%d", ip.String(), server.Port))
|
||||
hosts = append(hosts, fmt.Sprintf("https://%s:%d", ip.String(), server.Port))
|
||||
}
|
||||
}
|
||||
|
||||
return CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedOrigins: hosts,
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: GetHeadSpecHeaders(),
|
||||
MaxAge: 86400, // 24 hours
|
||||
@@ -90,11 +114,14 @@ func GetHeadSpecHeaders() []string {
|
||||
}
|
||||
|
||||
// SetCORSHeaders sets CORS headers on a response writer
|
||||
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||
func SetCORSHeaders(w ResponseWriter, r Request, config CORSConfig) {
|
||||
// Set allowed origins
|
||||
if len(config.AllowedOrigins) > 0 {
|
||||
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
||||
}
|
||||
// if len(config.AllowedOrigins) > 0 {
|
||||
// w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
||||
// }
|
||||
|
||||
// Todo origin list parsing
|
||||
w.SetHeader("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Set allowed methods
|
||||
if len(config.AllowedMethods) > 0 {
|
||||
@@ -102,9 +129,10 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||
}
|
||||
|
||||
// Set allowed headers
|
||||
if len(config.AllowedHeaders) > 0 {
|
||||
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||
}
|
||||
// if len(config.AllowedHeaders) > 0 {
|
||||
// w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||
// }
|
||||
w.SetHeader("Access-Control-Allow-Headers", "*")
|
||||
|
||||
// Set max age
|
||||
if config.MaxAge > 0 {
|
||||
@@ -115,5 +143,7 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||
w.SetHeader("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// Expose headers that clients can read
|
||||
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
|
||||
exposeHeaders := config.AllowedHeaders
|
||||
exposeHeaders = append(exposeHeaders, "Content-Range", "X-Api-Range-Total", "X-Api-Range-Size")
|
||||
w.SetHeader("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ", "))
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ package common
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// ValidateAndUnwrapModelResult contains the result of model validation
|
||||
@@ -45,3 +48,216 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e
|
||||
OriginalType: originalType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtractTagValue extracts the value for a given key from a struct tag string.
|
||||
// It handles both semicolon and comma-separated tag formats (e.g., GORM and BUN tags).
|
||||
// For tags like "json:name;validate:required" it will extract "name" for key "json".
|
||||
// For tags like "rel:has-many,join:table" it will extract "table" for key "join".
|
||||
func ExtractTagValue(tag, key string) string {
|
||||
// Split by both semicolons and commas to handle different tag formats
|
||||
// We need to be smart about this - commas can be part of values
|
||||
// So we'll try semicolon first, then comma if needed
|
||||
separators := []string{";", ","}
|
||||
|
||||
for _, sep := range separators {
|
||||
parts := strings.Split(tag, sep)
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, key+":") {
|
||||
return strings.TrimPrefix(part, key+":")
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetRelationshipInfo analyzes a model type and extracts relationship metadata
|
||||
// for a specific relation field identified by its JSON name.
|
||||
// Returns nil if the field is not found or is not a valid relationship.
|
||||
func GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo {
|
||||
// Ensure we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
jsonName := strings.Split(jsonTag, ",")[0]
|
||||
|
||||
if jsonName == relationName {
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
bunTag := field.Tag.Get("bun")
|
||||
info := &RelationshipInfo{
|
||||
FieldName: field.Name,
|
||||
JSONName: jsonName,
|
||||
}
|
||||
|
||||
if strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") {
|
||||
//bun:"rel:has-many,join:rid_hub=rid_hub_division"
|
||||
if strings.Contains(bunTag, "has-many") {
|
||||
info.RelationType = "hasMany"
|
||||
} else if strings.Contains(bunTag, "has-one") {
|
||||
info.RelationType = "hasOne"
|
||||
} else if strings.Contains(bunTag, "belongs-to") {
|
||||
info.RelationType = "belongsTo"
|
||||
} else if strings.Contains(bunTag, "many-to-many") {
|
||||
info.RelationType = "many2many"
|
||||
} else {
|
||||
info.RelationType = "hasOne"
|
||||
}
|
||||
|
||||
// Extract join info
|
||||
joinPart := ExtractTagValue(bunTag, "join")
|
||||
if joinPart != "" && info.RelationType == "many2many" {
|
||||
// For many2many, the join part is the join table name
|
||||
info.JoinTable = joinPart
|
||||
} else if joinPart != "" {
|
||||
// For other relations, parse foreignKey and references
|
||||
joinParts := strings.Split(joinPart, "=")
|
||||
if len(joinParts) == 2 {
|
||||
info.ForeignKey = joinParts[0]
|
||||
info.References = joinParts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// Get related model type
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||
elemType := field.Type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// Parse GORM tag to determine relationship type and keys
|
||||
if strings.Contains(gormTag, "foreignKey") {
|
||||
info.ForeignKey = ExtractTagValue(gormTag, "foreignKey")
|
||||
info.References = ExtractTagValue(gormTag, "references")
|
||||
|
||||
// Determine if it's belongsTo or hasMany/hasOne
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
info.RelationType = "hasMany"
|
||||
// Get the element type for slice
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||
info.RelationType = "belongsTo"
|
||||
elemType := field.Type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(gormTag, "many2many") {
|
||||
info.RelationType = "many2many"
|
||||
info.JoinTable = ExtractTagValue(gormTag, "many2many")
|
||||
// Get the element type for many2many (always slice)
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Field has no GORM relationship tags, so it's not a relation
|
||||
return nil
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelationPathToBunAlias converts a relation path (e.g., "Order.Customer") to a Bun alias format.
|
||||
// It converts to lowercase and replaces dots with double underscores.
|
||||
// For example: "Order.Customer" -> "order__customer"
|
||||
func RelationPathToBunAlias(relationPath string) string {
|
||||
if relationPath == "" {
|
||||
return ""
|
||||
}
|
||||
// Convert to lowercase and replace dots with double underscores
|
||||
alias := strings.ToLower(relationPath)
|
||||
alias = strings.ReplaceAll(alias, ".", "__")
|
||||
return alias
|
||||
}
|
||||
|
||||
// ReplaceTableReferencesInSQL replaces references to a base table name in a SQL expression
|
||||
// with the appropriate alias for the current preload level.
|
||||
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
|
||||
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
|
||||
func ReplaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
|
||||
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
|
||||
return sqlExpr
|
||||
}
|
||||
|
||||
// Replace both quoted and unquoted table references
|
||||
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
|
||||
|
||||
// Pattern 1: tablename.column (unquoted)
|
||||
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
|
||||
|
||||
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
|
||||
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetTableNameFromModel extracts the table name from a model.
|
||||
// It checks the bun tag first, then falls back to converting the struct name to snake_case.
|
||||
func GetTableNameFromModel(model interface{}) string {
|
||||
if model == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers
|
||||
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Look for bun tag on embedded BaseModel
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
if field.Anonymous {
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.HasPrefix(bunTag, "table:") {
|
||||
return strings.TrimPrefix(bunTag, "table:")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: convert struct name to lowercase (simple heuristic)
|
||||
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||
return strings.ToLower(modelType.Name())
|
||||
}
|
||||
|
||||
108
pkg/common/handler_utils_test.go
Normal file
108
pkg/common/handler_utils_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractTagValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Extract existing key",
|
||||
tag: "json:name;validate:required",
|
||||
key: "json",
|
||||
expected: "name",
|
||||
},
|
||||
{
|
||||
name: "Extract key with spaces",
|
||||
tag: "json:name ; validate:required",
|
||||
key: "validate",
|
||||
expected: "required",
|
||||
},
|
||||
{
|
||||
name: "Extract key at end",
|
||||
tag: "json:name;validate:required;db:column_name",
|
||||
key: "db",
|
||||
expected: "column_name",
|
||||
},
|
||||
{
|
||||
name: "Extract key at beginning",
|
||||
tag: "primary:true;json:id;db:user_id",
|
||||
key: "primary",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "Key not found",
|
||||
tag: "json:name;validate:required",
|
||||
key: "db",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty tag",
|
||||
tag: "",
|
||||
key: "json",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single key-value pair",
|
||||
tag: "json:name",
|
||||
key: "json",
|
||||
expected: "name",
|
||||
},
|
||||
{
|
||||
name: "Key with empty value",
|
||||
tag: "json:;validate:required",
|
||||
key: "json",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Key with complex value",
|
||||
tag: "json:user_name,omitempty;validate:required,min=3",
|
||||
key: "json",
|
||||
expected: "user_name,omitempty",
|
||||
},
|
||||
{
|
||||
name: "Multiple semicolons",
|
||||
tag: "json:name;;validate:required",
|
||||
key: "validate",
|
||||
expected: "required",
|
||||
},
|
||||
{
|
||||
name: "BUN Tag with comma separator",
|
||||
tag: "rel:has-many,join:rid_hub=rid_hub_child",
|
||||
key: "join",
|
||||
expected: "rid_hub=rid_hub_child",
|
||||
},
|
||||
{
|
||||
name: "Extract foreignKey",
|
||||
tag: "foreignKey:UserID;references:ID",
|
||||
key: "foreignKey",
|
||||
expected: "UserID",
|
||||
},
|
||||
{
|
||||
name: "Extract references",
|
||||
tag: "foreignKey:UserID;references:ID",
|
||||
key: "references",
|
||||
expected: "ID",
|
||||
},
|
||||
{
|
||||
name: "Extract many2many",
|
||||
tag: "many2many:user_roles",
|
||||
key: "many2many",
|
||||
expected: "user_roles",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractTagValue(tt.tag, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,12 @@ type Database interface {
|
||||
// For Bun, this returns *bun.DB
|
||||
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
|
||||
GetUnderlyingDB() interface{}
|
||||
|
||||
// DriverName returns the canonical name of the underlying database driver.
|
||||
// Possible values: "postgres", "sqlite", "mssql", "mysql".
|
||||
// All adapters normalise vendor-specific strings (e.g. Bun's "pg", GORM's
|
||||
// "sqlserver") to the values above before returning.
|
||||
DriverName() string
|
||||
}
|
||||
|
||||
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
||||
|
||||
@@ -20,17 +20,6 @@ type RelationshipInfoProvider interface {
|
||||
GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo
|
||||
}
|
||||
|
||||
// RelationshipInfo contains information about a model relationship
|
||||
type RelationshipInfo struct {
|
||||
FieldName string
|
||||
JSONName string
|
||||
RelationType string // "belongsTo", "hasMany", "hasOne", "many2many"
|
||||
ForeignKey string
|
||||
References string
|
||||
JoinTable string
|
||||
RelatedModel interface{}
|
||||
}
|
||||
|
||||
// NestedCUDProcessor handles recursive processing of nested object graphs
|
||||
type NestedCUDProcessor struct {
|
||||
db Database
|
||||
@@ -85,6 +74,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
logger.Error("Invalid model type: operation=%s, table=%s, modelType=%v, expected struct", operation, tableName, modelType)
|
||||
return nil, fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||
}
|
||||
|
||||
@@ -108,50 +98,74 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
}
|
||||
|
||||
// Filter regularData to only include fields that exist in the model
|
||||
// Use MapToStruct to validate and filter fields
|
||||
regularData = p.filterValidFields(regularData, model)
|
||||
|
||||
// Inject parent IDs for foreign key resolution
|
||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||
|
||||
// Get the primary key name for this model
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Check if we have any data to process (besides _request)
|
||||
hasData := len(regularData) > 0
|
||||
|
||||
// Process based on operation
|
||||
switch strings.ToLower(operation) {
|
||||
case "insert", "create":
|
||||
id, err := p.processInsert(ctx, regularData, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
result.ID = id
|
||||
result.AffectedRows = 1
|
||||
result.Data = regularData
|
||||
// Only perform insert if we have data to insert
|
||||
if hasData {
|
||||
id, err := p.processInsert(ctx, regularData, tableName)
|
||||
if err != nil {
|
||||
logger.Error("Insert failed for table=%s, data=%+v, error=%v", tableName, regularData, err)
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
result.ID = id
|
||||
result.AffectedRows = 1
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations after parent insert (to get parent ID)
|
||||
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
// Process child relations after parent insert (to get parent ID)
|
||||
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
||||
}
|
||||
|
||||
case "update":
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
// Only perform update if we have data to update
|
||||
if hasData {
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||
if err != nil {
|
||||
logger.Error("Update failed for table=%s, id=%v, data=%+v, error=%v", tableName, data[pkName], regularData, err)
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations for update
|
||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
// Process child relations for update
|
||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Skipping update for %s - no data columns besides _request", tableName)
|
||||
result.ID = data[pkName]
|
||||
}
|
||||
|
||||
case "delete":
|
||||
// Process child relations first (for referential integrity)
|
||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||
logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||
}
|
||||
|
||||
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||
if err != nil {
|
||||
logger.Error("Delete failed for table=%s, id=%v, error=%v", tableName, data[pkName], err)
|
||||
return nil, fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
@@ -159,6 +173,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
result.Data = regularData
|
||||
|
||||
default:
|
||||
logger.Error("Unsupported operation: %s for table=%s", operation, tableName)
|
||||
return nil, fmt.Errorf("unsupported operation: %s", operation)
|
||||
}
|
||||
|
||||
@@ -176,6 +191,115 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
|
||||
return ""
|
||||
}
|
||||
|
||||
// filterValidFields filters input data to only include fields that exist in the model
|
||||
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
|
||||
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model to use with MapToStruct
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model
|
||||
tempModel := reflect.New(modelType).Interface()
|
||||
|
||||
// Use MapToStruct to map the data - this will only map valid fields
|
||||
err := reflection.MapToStruct(data, tempModel)
|
||||
if err != nil {
|
||||
logger.Debug("Error mapping data to model: %v", err)
|
||||
return data
|
||||
}
|
||||
|
||||
// Extract the mapped fields back into a map
|
||||
// This effectively filters out any fields that don't exist in the model
|
||||
filteredData := make(map[string]interface{})
|
||||
tempModelValue := reflect.ValueOf(tempModel).Elem()
|
||||
|
||||
for key, value := range data {
|
||||
// Check if the field was successfully mapped
|
||||
if fieldWasMapped(tempModelValue, modelType, key) {
|
||||
filteredData[key] = value
|
||||
} else {
|
||||
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredData
|
||||
}
|
||||
|
||||
// fieldWasMapped checks if a field with the given key was mapped to the model
|
||||
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
|
||||
// Look for the field by JSON tag or field name
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check bun tag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check lowercase field name
|
||||
if strings.EqualFold(field.Name, key) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle embedded structs recursively
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
embeddedValue := modelValue.Field(i)
|
||||
if embeddedValue.Kind() == reflect.Ptr {
|
||||
if embeddedValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
embeddedValue = embeddedValue.Elem()
|
||||
}
|
||||
if fieldWasMapped(embeddedValue, fieldType, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// injectForeignKeys injects parent IDs into data for foreign key fields
|
||||
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
||||
if len(parentIDs) == 0 {
|
||||
@@ -218,12 +342,13 @@ func (p *NestedCUDProcessor) processInsert(
|
||||
for key, value := range data {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(tableName)
|
||||
// Add RETURNING clause to get the inserted ID
|
||||
query = query.Returning("id")
|
||||
query = query.Returning(pkName)
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err)
|
||||
return nil, fmt.Errorf("insert exec failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -231,8 +356,8 @@ func (p *NestedCUDProcessor) processInsert(
|
||||
var id interface{}
|
||||
if lastID, err := result.LastInsertId(); err == nil && lastID > 0 {
|
||||
id = lastID
|
||||
} else if data["id"] != nil {
|
||||
id = data["id"]
|
||||
} else if data[pkName] != nil {
|
||||
id = data[pkName]
|
||||
}
|
||||
|
||||
logger.Debug("Insert successful, ID: %v, rows affected: %d", id, result.RowsAffected())
|
||||
@@ -247,6 +372,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
||||
id interface{},
|
||||
) (int64, error) {
|
||||
if id == nil {
|
||||
logger.Error("Update requires an ID: table=%s, data=%+v", tableName, data)
|
||||
return 0, fmt.Errorf("update requires an ID")
|
||||
}
|
||||
|
||||
@@ -256,6 +382,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Update execution failed: table=%s, id=%v, data=%+v, error=%v", tableName, id, data, err)
|
||||
return 0, fmt.Errorf("update exec failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -267,6 +394,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
||||
// processDelete handles delete operation
|
||||
func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) {
|
||||
if id == nil {
|
||||
logger.Error("Delete requires an ID: table=%s", tableName)
|
||||
return 0, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
@@ -276,6 +404,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Delete execution failed: table=%s, id=%v, error=%v", tableName, id, err)
|
||||
return 0, fmt.Errorf("delete exec failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -292,6 +421,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
relationFields map[string]*RelationshipInfo,
|
||||
relationData map[string]interface{},
|
||||
parentModelType reflect.Type,
|
||||
incomingParentIDs map[string]interface{}, // IDs from all ancestors
|
||||
) error {
|
||||
for relationName, relInfo := range relationFields {
|
||||
relationValue, exists := relationData[relationName]
|
||||
@@ -304,7 +434,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
// Get the related model
|
||||
field, found := parentModelType.FieldByName(relInfo.FieldName)
|
||||
if !found {
|
||||
logger.Warn("Field %s not found in model", relInfo.FieldName)
|
||||
logger.Error("Field %s not found in model type %v for relation %s", relInfo.FieldName, parentModelType, relationName)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -324,20 +454,89 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName)
|
||||
|
||||
// Prepare parent IDs for foreign key injection
|
||||
// Start by copying all incoming parent IDs (from ancestors)
|
||||
parentIDs := make(map[string]interface{})
|
||||
if relInfo.ForeignKey != "" {
|
||||
for k, v := range incomingParentIDs {
|
||||
parentIDs[k] = v
|
||||
}
|
||||
logger.Debug("Inherited %d parent IDs from ancestors: %+v", len(incomingParentIDs), incomingParentIDs)
|
||||
|
||||
// Add the current parent's primary key to the parentIDs map
|
||||
// This ensures nested children have access to all ancestor IDs
|
||||
if parentID != nil && parentModelType != nil {
|
||||
// Get the parent model's primary key field name
|
||||
parentPKFieldName := reflection.GetPrimaryKeyName(parentModelType)
|
||||
if parentPKFieldName != "" {
|
||||
// Get the JSON name for the primary key field
|
||||
parentPKJSONName := reflection.GetJSONNameForField(parentModelType, parentPKFieldName)
|
||||
baseName := ""
|
||||
if len(parentPKJSONName) > 1 {
|
||||
baseName = parentPKJSONName
|
||||
} else {
|
||||
// Add parent's PK to the map using the base model name
|
||||
baseName = strings.TrimSuffix(parentPKFieldName, "ID")
|
||||
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||
if baseName == "" {
|
||||
baseName = "parent"
|
||||
}
|
||||
}
|
||||
|
||||
parentIDs[baseName] = parentID
|
||||
logger.Debug("Added current parent PK to parentIDs map: %s=%v (from field %s)", baseName, parentID, parentPKFieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Also add the foreign key reference if specified
|
||||
if relInfo.ForeignKey != "" && parentID != nil {
|
||||
// Extract the base name from foreign key (e.g., "DepartmentID" -> "Department")
|
||||
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
|
||||
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||
parentIDs[baseName] = parentID
|
||||
// Only add if different from what we already added
|
||||
if _, exists := parentIDs[baseName]; !exists {
|
||||
parentIDs[baseName] = parentID
|
||||
logger.Debug("Added foreign key to parentIDs map: %s=%v (from FK %s)", baseName, parentID, relInfo.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Final parentIDs map for relation %s: %+v", relationName, parentIDs)
|
||||
|
||||
// Determine which field name to use for setting parent ID in child data
|
||||
// Priority: Use foreign key field name if specified
|
||||
var foreignKeyFieldName string
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Get the JSON name for the foreign key field in the child model
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||
if foreignKeyFieldName == "" {
|
||||
// Fallback to lowercase field name
|
||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||
}
|
||||
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey)
|
||||
}
|
||||
|
||||
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||
childPKName := reflection.GetPrimaryKeyName(relatedModel)
|
||||
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
|
||||
if childPKFieldName == "" {
|
||||
childPKFieldName = strings.ToLower(childPKName)
|
||||
}
|
||||
|
||||
logger.Debug("Processing relation with foreignKeyField=%s, childPK=%s", foreignKeyFieldName, childPKFieldName)
|
||||
|
||||
// Process based on relation type and data structure
|
||||
switch v := relationValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Single related object
|
||||
// Single related object - directly set foreign key if specified
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
v[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
|
||||
}
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to process single relation: name=%s, table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||
relationName, relatedTableName, operation, parentID, v, err)
|
||||
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
|
||||
}
|
||||
|
||||
@@ -345,24 +544,46 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
// Multiple related objects
|
||||
for i, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
// Directly set foreign key if specified
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to process relation array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||
relationName, i, relatedTableName, operation, parentID, itemMap, err)
|
||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Relation array item is not a map: name=%s[%d], type=%T", relationName, i, item)
|
||||
}
|
||||
}
|
||||
|
||||
case []map[string]interface{}:
|
||||
// Multiple related objects (typed slice)
|
||||
for i, itemMap := range v {
|
||||
// Directly set foreign key if specified
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to process relation typed array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||
relationName, i, relatedTableName, operation, parentID, itemMap, err)
|
||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue)
|
||||
logger.Error("Unsupported relation data type: name=%s, type=%T, value=%+v", relationName, relationValue, relationValue)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
723
pkg/common/recursive_crud_test.go
Normal file
723
pkg/common/recursive_crud_test.go
Normal file
@@ -0,0 +1,723 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Mock Database for testing
|
||||
type mockDatabase struct {
|
||||
insertCalls []map[string]interface{}
|
||||
updateCalls []map[string]interface{}
|
||||
deleteCalls []interface{}
|
||||
lastID int64
|
||||
}
|
||||
|
||||
func newMockDatabase() *mockDatabase {
|
||||
return &mockDatabase{
|
||||
insertCalls: make([]map[string]interface{}, 0),
|
||||
updateCalls: make([]map[string]interface{}, 0),
|
||||
deleteCalls: make([]interface{}, 0),
|
||||
lastID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDatabase) NewSelect() SelectQuery { return &mockSelectQuery{} }
|
||||
func (m *mockDatabase) NewInsert() InsertQuery { return &mockInsertQuery{db: m} }
|
||||
func (m *mockDatabase) NewUpdate() UpdateQuery { return &mockUpdateQuery{db: m} }
|
||||
func (m *mockDatabase) NewDelete() DeleteQuery { return &mockDeleteQuery{db: m} }
|
||||
func (m *mockDatabase) RunInTransaction(ctx context.Context, fn func(Database) error) error {
|
||||
return fn(m)
|
||||
}
|
||||
func (m *mockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) {
|
||||
return &mockResult{rowsAffected: 1}, nil
|
||||
}
|
||||
func (m *mockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) BeginTx(ctx context.Context) (Database, error) {
|
||||
return m, nil
|
||||
}
|
||||
func (m *mockDatabase) CommitTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) RollbackTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) GetUnderlyingDB() interface{} {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) DriverName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
// Mock SelectQuery
|
||||
type mockSelectQuery struct{}
|
||||
|
||||
func (m *mockSelectQuery) Model(model interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Table(name string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Column(columns ...string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Where(condition string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Join(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Order(order string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Limit(n int) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Offset(n int) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Group(group string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Having(condition string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { return nil }
|
||||
func (m *mockSelectQuery) ScanModel(ctx context.Context) error { return nil }
|
||||
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { return 0, nil }
|
||||
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { return false, nil }
|
||||
|
||||
// Mock InsertQuery
|
||||
type mockInsertQuery struct {
|
||||
db *mockDatabase
|
||||
table string
|
||||
values map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockInsertQuery) Model(model interface{}) InsertQuery { return m }
|
||||
func (m *mockInsertQuery) Table(name string) InsertQuery {
|
||||
m.table = name
|
||||
return m
|
||||
}
|
||||
func (m *mockInsertQuery) Value(column string, value interface{}) InsertQuery {
|
||||
if m.values == nil {
|
||||
m.values = make(map[string]interface{})
|
||||
}
|
||||
m.values[column] = value
|
||||
return m
|
||||
}
|
||||
func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m }
|
||||
func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m }
|
||||
func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) {
|
||||
// Record the insert call
|
||||
m.db.insertCalls = append(m.db.insertCalls, m.values)
|
||||
m.db.lastID++
|
||||
return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil
|
||||
}
|
||||
|
||||
// Mock UpdateQuery
|
||||
type mockUpdateQuery struct {
|
||||
db *mockDatabase
|
||||
table string
|
||||
setValues map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockUpdateQuery) Model(model interface{}) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) Table(name string) UpdateQuery {
|
||||
m.table = name
|
||||
return m
|
||||
}
|
||||
func (m *mockUpdateQuery) Set(column string, value interface{}) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery {
|
||||
m.setValues = values
|
||||
return m
|
||||
}
|
||||
func (m *mockUpdateQuery) Where(condition string, args ...interface{}) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) Returning(columns ...string) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) Exec(ctx context.Context) (Result, error) {
|
||||
// Record the update call
|
||||
m.db.updateCalls = append(m.db.updateCalls, m.setValues)
|
||||
return &mockResult{rowsAffected: 1}, nil
|
||||
}
|
||||
|
||||
// Mock DeleteQuery
|
||||
type mockDeleteQuery struct {
|
||||
db *mockDatabase
|
||||
table string
|
||||
}
|
||||
|
||||
func (m *mockDeleteQuery) Model(model interface{}) DeleteQuery { return m }
|
||||
func (m *mockDeleteQuery) Table(name string) DeleteQuery {
|
||||
m.table = name
|
||||
return m
|
||||
}
|
||||
func (m *mockDeleteQuery) Where(condition string, args ...interface{}) DeleteQuery { return m }
|
||||
func (m *mockDeleteQuery) Exec(ctx context.Context) (Result, error) {
|
||||
// Record the delete call
|
||||
m.db.deleteCalls = append(m.db.deleteCalls, m.table)
|
||||
return &mockResult{rowsAffected: 1}, nil
|
||||
}
|
||||
|
||||
// Mock Result
|
||||
type mockResult struct {
|
||||
lastID int64
|
||||
rowsAffected int64
|
||||
}
|
||||
|
||||
func (m *mockResult) LastInsertId() (int64, error) { return m.lastID, nil }
|
||||
func (m *mockResult) RowsAffected() int64 { return m.rowsAffected }
|
||||
|
||||
// Mock ModelRegistry
|
||||
type mockModelRegistry struct{}
|
||||
|
||||
func (m *mockModelRegistry) GetModel(name string) (interface{}, error) { return nil, nil }
|
||||
func (m *mockModelRegistry) GetModelByEntity(schema, entity string) (interface{}, error) { return nil, nil }
|
||||
func (m *mockModelRegistry) RegisterModel(name string, model interface{}) error { return nil }
|
||||
func (m *mockModelRegistry) GetAllModels() map[string]interface{} { return make(map[string]interface{}) }
|
||||
|
||||
// Mock RelationshipInfoProvider
|
||||
type mockRelationshipProvider struct {
|
||||
relationships map[string]*RelationshipInfo
|
||||
}
|
||||
|
||||
func newMockRelationshipProvider() *mockRelationshipProvider {
|
||||
return &mockRelationshipProvider{
|
||||
relationships: make(map[string]*RelationshipInfo),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockRelationshipProvider) GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo {
|
||||
key := modelType.Name() + "." + relationName
|
||||
return m.relationships[key]
|
||||
}
|
||||
|
||||
func (m *mockRelationshipProvider) RegisterRelation(modelTypeName, relationName string, info *RelationshipInfo) {
|
||||
key := modelTypeName + "." + relationName
|
||||
m.relationships[key] = info
|
||||
}
|
||||
|
||||
// Test Models
|
||||
type Department struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name"`
|
||||
Employees []*Employee `json:"employees,omitempty"`
|
||||
}
|
||||
|
||||
func (d Department) TableName() string { return "departments" }
|
||||
func (d Department) GetIDName() string { return "ID" }
|
||||
|
||||
type Employee struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name"`
|
||||
DepartmentID int64 `json:"department_id"`
|
||||
Tasks []*Task `json:"tasks,omitempty"`
|
||||
}
|
||||
|
||||
func (e Employee) TableName() string { return "employees" }
|
||||
func (e Employee) GetIDName() string { return "ID" }
|
||||
|
||||
type Task struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Title string `json:"title"`
|
||||
EmployeeID int64 `json:"employee_id"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
}
|
||||
|
||||
func (t Task) TableName() string { return "tasks" }
|
||||
func (t Task) GetIDName() string { return "ID" }
|
||||
|
||||
type Comment struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Text string `json:"text"`
|
||||
TaskID int64 `json:"task_id"`
|
||||
}
|
||||
|
||||
func (c Comment) TableName() string { return "comments" }
|
||||
func (c Comment) GetIDName() string { return "ID" }
|
||||
|
||||
// Test Cases
|
||||
|
||||
func TestProcessNestedCUD_SingleLevelInsert(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// Register Department -> Employees relationship
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"name": "Jane Smith",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if result.ID == nil {
|
||||
t.Error("Expected result.ID to be set")
|
||||
}
|
||||
|
||||
// Verify department was inserted
|
||||
if len(db.insertCalls) != 3 {
|
||||
t.Errorf("Expected 3 insert calls (1 dept + 2 employees), got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
// Verify first insert is department
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||
}
|
||||
|
||||
// Verify employees were inserted with foreign key
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id set")
|
||||
}
|
||||
if db.insertCalls[2]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_MultiLevelInsert(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// Register relationships
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
|
||||
FieldName: "Tasks",
|
||||
JSONName: "tasks",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "EmployeeID",
|
||||
RelatedModel: Task{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
"tasks": []interface{}{
|
||||
map[string]interface{}{
|
||||
"title": "Task 1",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"title": "Task 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if result.ID == nil {
|
||||
t.Error("Expected result.ID to be set")
|
||||
}
|
||||
|
||||
// Verify: 1 dept + 1 employee + 2 tasks = 4 inserts
|
||||
if len(db.insertCalls) != 4 {
|
||||
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
// Verify department
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||
}
|
||||
|
||||
// Verify employee has department_id
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id set")
|
||||
}
|
||||
|
||||
// Verify tasks have employee_id
|
||||
if db.insertCalls[2]["employee_id"] == nil {
|
||||
t.Error("Expected task to have employee_id set")
|
||||
}
|
||||
if db.insertCalls[3]["employee_id"] == nil {
|
||||
t.Error("Expected task to have employee_id set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_RequestFieldOverride(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "update",
|
||||
"ID": int64(10), // Use capital ID to match struct field
|
||||
"name": "John Updated",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify department was inserted (1 insert)
|
||||
// Employee should be updated (1 update)
|
||||
if len(db.insertCalls) != 1 {
|
||||
t.Errorf("Expected 1 insert call for department, got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
if len(db.updateCalls) != 1 {
|
||||
t.Errorf("Expected 1 update call for employee, got %d", len(db.updateCalls))
|
||||
}
|
||||
|
||||
// Verify update data
|
||||
if db.updateCalls[0]["name"] != "John Updated" {
|
||||
t.Errorf("Expected employee name 'John Updated', got %v", db.updateCalls[0]["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_SkipInsertWhenOnlyRequestField(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
// Data with only _request field for nested employee
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "insert",
|
||||
// No other fields besides _request
|
||||
// Note: Foreign key will be injected, so employee WILL be inserted
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Department + Employee (with injected FK) = 2 inserts
|
||||
if len(db.insertCalls) != 2 {
|
||||
t.Errorf("Expected 2 insert calls (department + employee with FK), got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||
}
|
||||
|
||||
// Verify employee has foreign key
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id injected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_Update(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ID": int64(1), // Use capital ID to match struct field
|
||||
"name": "Engineering Updated",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "insert",
|
||||
"name": "New Employee",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"update",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if result.ID != int64(1) {
|
||||
t.Errorf("Expected result.ID to be 1, got %v", result.ID)
|
||||
}
|
||||
|
||||
// Verify department was updated
|
||||
if len(db.updateCalls) != 1 {
|
||||
t.Errorf("Expected 1 update call, got %d", len(db.updateCalls))
|
||||
}
|
||||
|
||||
// Verify new employee was inserted
|
||||
if len(db.insertCalls) != 1 {
|
||||
t.Errorf("Expected 1 insert call for new employee, got %d", len(db.insertCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_Delete(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ID": int64(1), // Use capital ID to match struct field
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "delete",
|
||||
"ID": int64(10), // Use capital ID
|
||||
},
|
||||
map[string]interface{}{
|
||||
"_request": "delete",
|
||||
"ID": int64(11), // Use capital ID
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"delete",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify employees were deleted first, then department
|
||||
// 2 employees + 1 department = 3 deletes
|
||||
if len(db.deleteCalls) != 3 {
|
||||
t.Errorf("Expected 3 delete calls, got %d", len(db.deleteCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_ParentIDPropagation(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// Register 3-level relationships
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
|
||||
FieldName: "Tasks",
|
||||
JSONName: "tasks",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "EmployeeID",
|
||||
RelatedModel: Task{},
|
||||
})
|
||||
|
||||
relProvider.RegisterRelation("Task", "comments", &RelationshipInfo{
|
||||
FieldName: "Comments",
|
||||
JSONName: "comments",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "TaskID",
|
||||
RelatedModel: Comment{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "John",
|
||||
"tasks": []interface{}{
|
||||
map[string]interface{}{
|
||||
"title": "Task 1",
|
||||
"comments": []interface{}{
|
||||
map[string]interface{}{
|
||||
"text": "Great work!",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify: 1 dept + 1 employee + 1 task + 1 comment = 4 inserts
|
||||
if len(db.insertCalls) != 4 {
|
||||
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
// Verify department
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Error("Expected department to be inserted first")
|
||||
}
|
||||
|
||||
// Verify employee has department_id
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id")
|
||||
}
|
||||
|
||||
// Verify task has employee_id
|
||||
if db.insertCalls[2]["employee_id"] == nil {
|
||||
t.Error("Expected task to have employee_id")
|
||||
}
|
||||
|
||||
// Verify comment has task_id
|
||||
if db.insertCalls[3]["task_id"] == nil {
|
||||
t.Error("Expected comment to have task_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectForeignKeys(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "John",
|
||||
}
|
||||
|
||||
parentIDs := map[string]interface{}{
|
||||
"department": int64(5),
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(Employee{})
|
||||
|
||||
processor.injectForeignKeys(data, modelType, parentIDs)
|
||||
|
||||
// Should inject department_id based on the "department" key in parentIDs
|
||||
if data["department_id"] == nil {
|
||||
t.Error("Expected department_id to be injected")
|
||||
}
|
||||
|
||||
if data["department_id"] != int64(5) {
|
||||
t.Errorf("Expected department_id to be 5, got %v", data["department_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyName(t *testing.T) {
|
||||
dept := Department{}
|
||||
pkName := reflection.GetPrimaryKeyName(dept)
|
||||
|
||||
if pkName != "ID" {
|
||||
t.Errorf("Expected primary key name 'ID', got '%s'", pkName)
|
||||
}
|
||||
|
||||
// Test with pointer
|
||||
pkName2 := reflection.GetPrimaryKeyName(&dept)
|
||||
if pkName2 != "ID" {
|
||||
t.Errorf("Expected primary key name 'ID' from pointer, got '%s'", pkName2)
|
||||
}
|
||||
}
|
||||
@@ -130,6 +130,9 @@ func validateWhereClauseSecurity(where string) error {
|
||||
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
||||
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
||||
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
||||
//
|
||||
// IMPORTANT: Outer parentheses are preserved if the clause contains top-level OR operators
|
||||
// to prevent OR logic from escaping and affecting the entire query incorrectly.
|
||||
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
||||
if where == "" {
|
||||
return ""
|
||||
@@ -143,8 +146,19 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip outer parentheses and re-trim
|
||||
where = stripOuterParentheses(where)
|
||||
// Check if the original clause has outer parentheses and contains OR operators
|
||||
// If so, we need to preserve the outer parentheses to prevent OR logic from escaping
|
||||
hasOuterParens := false
|
||||
if len(where) > 0 && where[0] == '(' && where[len(where)-1] == ')' {
|
||||
_, hasOuterParens = stripOneMatchingOuterParen(where)
|
||||
}
|
||||
|
||||
// Strip outer parentheses and re-trim for processing
|
||||
whereWithoutParens := stripOuterParentheses(where)
|
||||
shouldPreserveParens := hasOuterParens && containsTopLevelOR(whereWithoutParens)
|
||||
|
||||
// Use the stripped version for processing
|
||||
where = whereWithoutParens
|
||||
|
||||
// Get valid columns from the model if tableName is provided
|
||||
var validColumns map[string]bool
|
||||
@@ -166,6 +180,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Add join aliases as allowed prefixes
|
||||
for _, alias := range options[0].JoinAliases {
|
||||
if alias != "" {
|
||||
allowedPrefixes[alias] = true
|
||||
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split by AND to handle multiple conditions
|
||||
@@ -221,7 +243,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
|
||||
result := strings.Join(validConditions, " AND ")
|
||||
|
||||
if result != where {
|
||||
// If the original clause had outer parentheses and contains OR operators,
|
||||
// restore the outer parentheses to prevent OR logic from escaping
|
||||
if shouldPreserveParens {
|
||||
result = "(" + result + ")"
|
||||
logger.Debug("Preserved outer parentheses for OR conditions: '%s'", result)
|
||||
}
|
||||
|
||||
if result != where && !shouldPreserveParens {
|
||||
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
||||
}
|
||||
|
||||
@@ -282,6 +311,93 @@ func stripOneMatchingOuterParen(s string) (string, bool) {
|
||||
return strings.TrimSpace(s[1 : len(s)-1]), true
|
||||
}
|
||||
|
||||
// EnsureOuterParentheses ensures that a SQL clause is wrapped in parentheses
|
||||
// to prevent OR logic from escaping. It checks if the clause already has
|
||||
// matching outer parentheses and only adds them if they don't exist.
|
||||
//
|
||||
// This is particularly important for OR conditions and complex filters where
|
||||
// the absence of parentheses could cause the logic to escape and affect
|
||||
// the entire query incorrectly.
|
||||
//
|
||||
// Parameters:
|
||||
// - clause: The SQL clause to check and potentially wrap
|
||||
//
|
||||
// Returns:
|
||||
// - The clause with guaranteed outer parentheses, or empty string if input is empty
|
||||
func EnsureOuterParentheses(clause string) string {
|
||||
if clause == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
clause = strings.TrimSpace(clause)
|
||||
if clause == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if the clause already has matching outer parentheses
|
||||
_, hasOuterParens := stripOneMatchingOuterParen(clause)
|
||||
|
||||
// If it already has matching outer parentheses, return as-is
|
||||
if hasOuterParens {
|
||||
return clause
|
||||
}
|
||||
|
||||
// Otherwise, wrap it in parentheses
|
||||
return "(" + clause + ")"
|
||||
}
|
||||
|
||||
// containsTopLevelOR reports whether the clause has an " OR " operator at
// parenthesis depth zero, ignoring anything inside single or double quotes.
// Callers use it to decide whether outer parentheses must be preserved so
// that OR logic cannot escape the clause.
func containsTopLevelOR(clause string) bool {
	if clause == "" {
		return false
	}

	lowered := strings.ToLower(clause)
	parenDepth := 0
	inSingle := false
	inDouble := false

	for i := 0; i < len(clause); i++ {
		c := clause[i]

		// Toggle quote state; quote characters themselves never start an OR.
		if c == '\'' && !inDouble {
			inSingle = !inSingle
			continue
		}
		if c == '"' && !inSingle {
			inDouble = !inDouble
			continue
		}
		if inSingle || inDouble {
			continue
		}

		// Track nesting so ORs inside parentheses/subqueries are ignored.
		if c == '(' {
			parenDepth++
		} else if c == ')' {
			parenDepth--
		}

		// A case-insensitive " or " counts only outside all parentheses.
		if parenDepth == 0 && i+4 <= len(clause) && lowered[i:i+4] == " or " {
			return true
		}
	}

	return false
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
||||
func splitByAND(where string) []string {
|
||||
|
||||
103
pkg/common/sql_helpers_tablename_test.go
Normal file
103
pkg/common/sql_helpers_tablename_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSanitizeWhereClause_WithTableName tests that table prefixes in WHERE clauses
|
||||
// are correctly handled when the tableName parameter matches the prefix
|
||||
func TestSanitizeWhereClause_WithTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
options *RequestOptions
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Correct table prefix should not be changed",
|
||||
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Wrong table prefix should be fixed",
|
||||
where: "wrong_table.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Relation name should not replace correct table prefix",
|
||||
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{
|
||||
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||
TableName: "mastertaskitem",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Unqualified column should remain unqualified",
|
||||
where: "rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: "rid_parentmastertaskitem is null",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName, tt.options)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q, want %q",
|
||||
tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddTablePrefixToColumns_WithTableName tests that table prefixes
|
||||
// are correctly added to unqualified columns
|
||||
func TestAddTablePrefixToColumns_WithTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Add prefix to unqualified column",
|
||||
where: "rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Don't change already qualified column",
|
||||
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Don't change qualified column with different table",
|
||||
where: "other_table.rid_something is null",
|
||||
tableName: "mastertaskitem",
|
||||
expected: "other_table.rid_something is null",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q, want %q",
|
||||
tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -659,6 +659,179 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureOuterParentheses(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no parentheses",
|
||||
input: "status = 'active'",
|
||||
expected: "(status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "already has outer parentheses",
|
||||
input: "(status = 'active')",
|
||||
expected: "(status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "OR condition without parentheses",
|
||||
input: "status = 'active' OR status = 'pending'",
|
||||
expected: "(status = 'active' OR status = 'pending')",
|
||||
},
|
||||
{
|
||||
name: "OR condition with parentheses",
|
||||
input: "(status = 'active' OR status = 'pending')",
|
||||
expected: "(status = 'active' OR status = 'pending')",
|
||||
},
|
||||
{
|
||||
name: "complex condition with nested parentheses",
|
||||
input: "(status = 'active' OR status = 'pending') AND (age > 18)",
|
||||
expected: "((status = 'active' OR status = 'pending') AND (age > 18))",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mismatched parentheses - adds outer ones",
|
||||
input: "(status = 'active' OR status = 'pending'",
|
||||
expected: "((status = 'active' OR status = 'pending')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := EnsureOuterParentheses(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("EnsureOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsTopLevelOR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no OR operator",
|
||||
input: "status = 'active' AND age > 18",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "top-level OR",
|
||||
input: "status = 'active' OR status = 'pending'",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "OR inside parentheses",
|
||||
input: "age > 18 AND (status = 'active' OR status = 'pending')",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "OR in subquery",
|
||||
input: "id IN (SELECT id FROM users WHERE status = 'active' OR status = 'pending')",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "OR inside quotes",
|
||||
input: "comment = 'this OR that'",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "mixed - top-level OR and nested OR",
|
||||
input: "name = 'test' OR (status = 'active' OR status = 'pending')",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "lowercase or",
|
||||
input: "status = 'active' or status = 'pending'",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase OR",
|
||||
input: "status = 'active' OR status = 'pending'",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := containsTopLevelOR(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("containsTopLevelOR(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClause_PreservesParenthesesWithOR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "OR condition with outer parentheses - preserved",
|
||||
where: "(status = 'active' OR status = 'pending')",
|
||||
tableName: "users",
|
||||
expected: "(users.status = 'active' OR users.status = 'pending')",
|
||||
},
|
||||
{
|
||||
name: "AND condition with outer parentheses - stripped (no OR)",
|
||||
where: "(status = 'active' AND age > 18)",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "complex OR with nested conditions",
|
||||
where: "((status = 'active' OR status = 'pending') AND age > 18)",
|
||||
tableName: "users",
|
||||
// Outer parens are stripped, but inner parens with OR are preserved
|
||||
expected: "(users.status = 'active' OR users.status = 'pending') AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "OR without outer parentheses - no parentheses added by SanitizeWhereClause",
|
||||
where: "status = 'active' OR status = 'pending'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' OR users.status = 'pending'",
|
||||
},
|
||||
{
|
||||
name: "simple OR with parentheses - preserved",
|
||||
where: "(users.status = 'active' OR users.status = 'pending')",
|
||||
tableName: "users",
|
||||
// Already has correct prefixes, parentheses preserved
|
||||
expected: "(users.status = 'active' OR users.status = 'pending')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -23,6 +23,10 @@ type RequestOptions struct {
|
||||
CursorForward string `json:"cursor_forward"`
|
||||
CursorBackward string `json:"cursor_backward"`
|
||||
FetchRowNumber *string `json:"fetch_row_number"`
|
||||
|
||||
// Join table aliases (used for validation of prefixed columns in filters/sorts)
|
||||
// Not serialized to JSON as it's internal validation state
|
||||
JoinAliases []string `json:"-"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
@@ -33,6 +37,7 @@ type Parameter struct {
|
||||
|
||||
type PreloadOption struct {
|
||||
Relation string `json:"relation"`
|
||||
TableName string `json:"table_name"` // Actual database table name (e.g., "mastertaskitem")
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
Sort []SortOption `json:"sort"`
|
||||
@@ -45,9 +50,14 @@ type PreloadOption struct {
|
||||
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||
|
||||
// Relationship keys from XFiles - used to build proper foreign key filters
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
RecursiveChildKey string `json:"recursive_child_key"` // For recursive tables: FK column used for recursion (e.g., "rid_parentmastertaskitem")
|
||||
|
||||
// Custom SQL JOINs from XFiles - used when preload needs additional joins
|
||||
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses
|
||||
JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation
|
||||
}
|
||||
|
||||
type FilterOption struct {
|
||||
@@ -111,3 +121,14 @@ type TableMetadata struct {
|
||||
Columns []Column `json:"columns"`
|
||||
Relations []string `json:"relations"`
|
||||
}
|
||||
|
||||
// RelationshipInfo contains information about a model relationship
type RelationshipInfo struct {
	FieldName    string      `json:"field_name"`    // Go struct field that holds the related value(s), e.g. "Employees"
	JSONName     string      `json:"json_name"`     // JSON key used for the relation in request payloads
	RelationType string      `json:"relation_type"` // "belongsTo", "hasMany", "hasOne", "many2many"
	ForeignKey   string      `json:"foreign_key"`   // foreign-key field linking the two models
	References   string      `json:"references"`    // presumably the referenced column on the other side of the FK — confirm against usage
	JoinTable    string      `json:"join_table"`    // join table name (relevant for many2many relations)
	RelatedModel interface{} `json:"related_model"` // instance (zero value) of the related model type
}
|
||||
|
||||
@@ -237,15 +237,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
for _, sort := range options.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
foundJoin := false
|
||||
for _, j := range options.JoinAliases {
|
||||
if strings.Contains(sort.Column, j) {
|
||||
foundJoin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundJoin {
|
||||
validSorts = append(validSorts, sort)
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||
}
|
||||
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
}
|
||||
}
|
||||
}
|
||||
filtered.Sort = validSorts
|
||||
@@ -258,13 +272,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||
|
||||
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||
filteredPreload.SqlJoins = preload.SqlJoins
|
||||
filteredPreload.JoinAliases = preload.JoinAliases
|
||||
|
||||
// Filter preload filters
|
||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||
for _, filter := range preload.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||
// Check if the filter column references a joined table alias
|
||||
foundJoin := false
|
||||
for _, alias := range preload.JoinAliases {
|
||||
if strings.Contains(filter.Column, alias) {
|
||||
foundJoin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundJoin {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||
}
|
||||
}
|
||||
}
|
||||
filteredPreload.Filters = validPreloadFilters
|
||||
@@ -291,6 +321,9 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
}
|
||||
filtered.Preload = validPreloads
|
||||
|
||||
// Clear JoinAliases - this is an internal validation field and should not be persisted
|
||||
filtered.JoinAliases = nil
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
|
||||
@@ -362,6 +362,29 @@ func TestFilterRequestOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRequestOptions_ClearsJoinAliases(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
options := RequestOptions{
|
||||
Columns: []string{"id", "name"},
|
||||
// Set JoinAliases - this should be cleared by FilterRequestOptions
|
||||
JoinAliases: []string{"d", "u", "r"},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
// Verify that JoinAliases was cleared (internal field should not persist)
|
||||
if filtered.JoinAliases != nil {
|
||||
t.Errorf("Expected JoinAliases to be nil after filtering, got %v", filtered.JoinAliases)
|
||||
}
|
||||
|
||||
// Verify that other fields are still properly filtered
|
||||
if len(filtered.Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeSortExpression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -73,6 +73,9 @@ type ServerInstanceConfig struct {
|
||||
|
||||
// Tags for organization and filtering
|
||||
Tags map[string]string `mapstructure:"tags"`
|
||||
|
||||
// ExternalURLs are additional URLs that this server instance is accessible from (for CORS) for proxy setups
|
||||
ExternalURLs []string `mapstructure:"external_urls"`
|
||||
}
|
||||
|
||||
// TracingConfig holds OpenTelemetry tracing configuration
|
||||
|
||||
@@ -12,6 +12,16 @@ type Manager struct {
|
||||
v *viper.Viper
|
||||
}
|
||||
|
||||
var configInstance *Manager
|
||||
|
||||
// GetConfigManager returns a singleton configuration manager instance
|
||||
func GetConfigManager() *Manager {
|
||||
if configInstance == nil {
|
||||
configInstance = NewManager()
|
||||
}
|
||||
return configInstance
|
||||
}
|
||||
|
||||
// NewManager creates a new configuration manager with defaults
|
||||
func NewManager() *Manager {
|
||||
v := viper.New()
|
||||
@@ -32,7 +42,8 @@ func NewManager() *Manager {
|
||||
// Set default values
|
||||
setDefaults(v)
|
||||
|
||||
return &Manager{v: v}
|
||||
configInstance = &Manager{v: v}
|
||||
return configInstance
|
||||
}
|
||||
|
||||
// NewManagerWithOptions creates a new configuration manager with custom options
|
||||
|
||||
@@ -2,6 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ApplyGlobalDefaults applies global server defaults to this instance
|
||||
@@ -105,3 +108,42 @@ func (sc *ServersConfig) GetDefault() (*ServerInstanceConfig, error) {
|
||||
|
||||
return &instance, nil
|
||||
}
|
||||
|
||||
// GetIPs - GetIP for pc
|
||||
func GetIPs() (hostname string, ipList string, ipNetList []net.IP) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Println("Recovered in GetIPs", err)
|
||||
}
|
||||
}()
|
||||
hostname, _ = os.Hostname()
|
||||
ipaddrlist := make([]net.IP, 0)
|
||||
iplist := ""
|
||||
addrs, err := net.LookupIP(hostname)
|
||||
if err != nil {
|
||||
return hostname, iplist, ipaddrlist
|
||||
}
|
||||
|
||||
for _, a := range addrs {
|
||||
// cfg.LogInfo("\nFound IP Host Address: %s", a)
|
||||
if strings.Contains(a.String(), "127.0.0.1") {
|
||||
continue
|
||||
}
|
||||
iplist = fmt.Sprintf("%s,%s", iplist, a)
|
||||
ipaddrlist = append(ipaddrlist, a)
|
||||
}
|
||||
if iplist == "" {
|
||||
iff, _ := net.InterfaceAddrs()
|
||||
for _, a := range iff {
|
||||
// cfg.LogInfo("\nFound IP Address: %s", a)
|
||||
if strings.Contains(a.String(), "127.0.0.1") {
|
||||
continue
|
||||
}
|
||||
iplist = fmt.Sprintf("%s,%s", iplist, a)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
iplist = strings.TrimLeft(iplist, ",")
|
||||
return hostname, iplist, ipaddrlist
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ A comprehensive database connection manager for Go that provides centralized man
|
||||
- **GORM** - Popular Go ORM
|
||||
- **Native** - Standard library `*sql.DB`
|
||||
- All three share the same underlying connection pool
|
||||
- **SQLite Schema Translation**: Automatic conversion of `schema.table` to `schema_table` for SQLite compatibility
|
||||
- **Configuration-Driven**: YAML configuration with Viper integration
|
||||
- **Production-Ready Features**:
|
||||
- Automatic health checks and reconnection
|
||||
@@ -179,6 +180,35 @@ if err != nil {
|
||||
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
|
||||
```
|
||||
|
||||
#### Cross-Database Example with SQLite
|
||||
|
||||
```go
|
||||
// Same model works across all databases
|
||||
type User struct {
|
||||
ID int `bun:"id,pk"`
|
||||
Username string `bun:"username"`
|
||||
Email string `bun:"email"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "auth.users"
|
||||
}
|
||||
|
||||
// PostgreSQL connection
|
||||
pgConn, _ := mgr.Get("primary")
|
||||
pgDB, _ := pgConn.Bun()
|
||||
var pgUsers []User
|
||||
pgDB.NewSelect().Model(&pgUsers).Scan(ctx)
|
||||
// Executes: SELECT * FROM auth.users
|
||||
|
||||
// SQLite connection
|
||||
sqliteConn, _ := mgr.Get("cache-db")
|
||||
sqliteDB, _ := sqliteConn.Bun()
|
||||
var sqliteUsers []User
|
||||
sqliteDB.NewSelect().Model(&sqliteUsers).Scan(ctx)
|
||||
// Executes: SELECT * FROM auth_users (schema.table → schema_table)
|
||||
```
|
||||
|
||||
#### Use MongoDB
|
||||
|
||||
```go
|
||||
@@ -368,6 +398,37 @@ Providers handle:
|
||||
- Connection statistics
|
||||
- Connection cleanup
|
||||
|
||||
### SQLite Schema Handling
|
||||
|
||||
SQLite doesn't support schemas in the same way as PostgreSQL or MSSQL. To ensure compatibility when using models designed for multi-schema databases:
|
||||
|
||||
**Automatic Translation**: When a table name contains a schema prefix (e.g., `myschema.mytable`), it is automatically converted to `myschema_mytable` for SQLite databases.
|
||||
|
||||
```go
|
||||
// Model definition (works across all databases)
|
||||
func (User) TableName() string {
|
||||
return "auth.users" // PostgreSQL/MSSQL: "auth"."users"
|
||||
// SQLite: "auth_users"
|
||||
}
|
||||
|
||||
// Query execution
|
||||
db.NewSelect().Model(&User{}).Scan(ctx)
|
||||
// PostgreSQL/MSSQL: SELECT * FROM auth.users
|
||||
// SQLite: SELECT * FROM auth_users
|
||||
```
|
||||
|
||||
**How it Works**:
|
||||
- Bun, GORM, and Native adapters detect the driver type
|
||||
- `parseTableName()` automatically translates schema.table → schema_table for SQLite
|
||||
- Translation happens transparently in all database operations (SELECT, INSERT, UPDATE, DELETE)
|
||||
- Preload and relation queries are also handled automatically
|
||||
|
||||
**Benefits**:
|
||||
- Write database-agnostic code
|
||||
- Use the same models across PostgreSQL, MSSQL, and SQLite
|
||||
- No conditional logic needed in your application
|
||||
- Schema separation maintained through naming convention in SQLite
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Named Connections**: Be explicit about which database you're accessing
|
||||
|
||||
@@ -128,7 +128,7 @@ func DefaultManagerConfig() ManagerConfig {
|
||||
RetryAttempts: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
RetryMaxDelay: 10 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
HealthCheckInterval: 15 * time.Second,
|
||||
EnableAutoReconnect: true,
|
||||
}
|
||||
}
|
||||
@@ -161,6 +161,11 @@ func (c *ManagerConfig) ApplyDefaults() {
|
||||
if c.HealthCheckInterval == 0 {
|
||||
c.HealthCheckInterval = defaults.HealthCheckInterval
|
||||
}
|
||||
// EnableAutoReconnect defaults to true - apply if not explicitly set
|
||||
// Since this is a boolean, we apply the default unconditionally when it's false
|
||||
if !c.EnableAutoReconnect {
|
||||
c.EnableAutoReconnect = defaults.EnableAutoReconnect
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the manager configuration
|
||||
@@ -216,7 +221,10 @@ func (cc *ConnectionConfig) ApplyDefaults(global *ManagerConfig) {
|
||||
cc.ConnectTimeout = 10 * time.Second
|
||||
}
|
||||
if cc.QueryTimeout == 0 {
|
||||
cc.QueryTimeout = 30 * time.Second
|
||||
cc.QueryTimeout = 2 * time.Minute // Default to 2 minutes
|
||||
} else if cc.QueryTimeout < 2*time.Minute {
|
||||
// Enforce minimum of 2 minutes
|
||||
cc.QueryTimeout = 2 * time.Minute
|
||||
}
|
||||
|
||||
// Default ORM
|
||||
@@ -320,14 +328,29 @@ func (cc *ConnectionConfig) buildPostgresDSN() string {
|
||||
dsn += fmt.Sprintf(" search_path=%s", cc.Schema)
|
||||
}
|
||||
|
||||
// Add statement_timeout for query execution timeout (in milliseconds)
|
||||
if cc.QueryTimeout > 0 {
|
||||
timeoutMs := int(cc.QueryTimeout.Milliseconds())
|
||||
dsn += fmt.Sprintf(" statement_timeout=%d", timeoutMs)
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildSQLiteDSN() string {
|
||||
if cc.FilePath != "" {
|
||||
return cc.FilePath
|
||||
filepath := cc.FilePath
|
||||
if filepath == "" {
|
||||
filepath = ":memory:"
|
||||
}
|
||||
return ":memory:"
|
||||
|
||||
// Add query parameters for timeouts
|
||||
// Note: SQLite driver supports _timeout parameter (in milliseconds)
|
||||
if cc.QueryTimeout > 0 {
|
||||
timeoutMs := int(cc.QueryTimeout.Milliseconds())
|
||||
filepath += fmt.Sprintf("?_timeout=%d", timeoutMs)
|
||||
}
|
||||
|
||||
return filepath
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildMSSQLDSN() string {
|
||||
@@ -339,6 +362,24 @@ func (cc *ConnectionConfig) buildMSSQLDSN() string {
|
||||
dsn += fmt.Sprintf("&schema=%s", cc.Schema)
|
||||
}
|
||||
|
||||
// Add connection timeout (in seconds)
|
||||
if cc.ConnectTimeout > 0 {
|
||||
timeoutSec := int(cc.ConnectTimeout.Seconds())
|
||||
dsn += fmt.Sprintf("&connection timeout=%d", timeoutSec)
|
||||
}
|
||||
|
||||
// Add dial timeout for TCP connection (in seconds)
|
||||
if cc.ConnectTimeout > 0 {
|
||||
dialTimeoutSec := int(cc.ConnectTimeout.Seconds())
|
||||
dsn += fmt.Sprintf("&dial timeout=%d", dialTimeoutSec)
|
||||
}
|
||||
|
||||
// Add read timeout (in seconds) - enforces timeout for reading data
|
||||
if cc.QueryTimeout > 0 {
|
||||
readTimeoutSec := int(cc.QueryTimeout.Seconds())
|
||||
dsn += fmt.Sprintf("&read timeout=%d", readTimeoutSec)
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
|
||||
@@ -372,12 +372,20 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
bunDB, err := c.Bun()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Double-check bunDB exists (while already holding write lock)
|
||||
if c.bunDB == nil {
|
||||
// Get native connection first
|
||||
native, err := c.provider.GetNative()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "get bun", err)
|
||||
}
|
||||
|
||||
// Create Bun DB wrapping the same sql.DB
|
||||
dialect := c.getBunDialect()
|
||||
c.bunDB = bun.NewDB(native, dialect)
|
||||
}
|
||||
|
||||
c.bunAdapter = database.NewBunAdapter(bunDB)
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB)
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
@@ -400,12 +408,25 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
gormDB, err := c.GORM()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Double-check gormDB exists (while already holding write lock)
|
||||
if c.gormDB == nil {
|
||||
// Get native connection first
|
||||
native, err := c.provider.GetNative()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "get gorm", err)
|
||||
}
|
||||
|
||||
// Create GORM DB wrapping the same sql.DB
|
||||
dialector := c.getGORMDialector(native)
|
||||
db, err := gorm.Open(dialector, &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "initialize gorm", err)
|
||||
}
|
||||
|
||||
c.gormDB = db
|
||||
}
|
||||
|
||||
c.gormAdapter = database.NewGormAdapter(gormDB)
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB)
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
@@ -428,21 +449,29 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
||||
return c.nativeAdapter, nil
|
||||
}
|
||||
|
||||
native, err := c.Native()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Double-check nativeDB exists (while already holding write lock)
|
||||
if c.nativeDB == nil {
|
||||
if !c.connected {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
|
||||
// Get native connection from provider
|
||||
db, err := c.provider.GetNative()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "get native", err)
|
||||
}
|
||||
|
||||
c.nativeDB = db
|
||||
}
|
||||
|
||||
// Create a native adapter based on database type
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(native)
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
case DatabaseTypeSQLite:
|
||||
// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(native)
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
case DatabaseTypeMSSQL:
|
||||
// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(native)
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabase
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
@@ -49,3 +50,18 @@ func createProvider(dbType DatabaseType) (Provider, error) {
|
||||
// Provider is an alias to the providers.Provider interface
|
||||
// This allows dbmanager package consumers to use Provider without importing providers
|
||||
type Provider = providers.Provider
|
||||
|
||||
// NewConnectionFromDB creates a new Connection from an existing *sql.DB
|
||||
// This allows you to use dbmanager features (ORM wrappers, health checks, etc.)
|
||||
// with a database connection that was opened outside of dbmanager
|
||||
//
|
||||
// Parameters:
|
||||
// - name: A unique name for this connection
|
||||
// - dbType: The database type (DatabaseTypePostgreSQL, DatabaseTypeSQLite, or DatabaseTypeMSSQL)
|
||||
// - db: An existing *sql.DB connection
|
||||
//
|
||||
// Returns a Connection that wraps the existing *sql.DB
|
||||
func NewConnectionFromDB(name string, dbType DatabaseType, db *sql.DB) Connection {
|
||||
provider := providers.NewExistingDBProvider(db, name)
|
||||
return newSQLConnection(name, dbType, ConnectionConfig{Name: name, Type: dbType}, provider)
|
||||
}
|
||||
|
||||
210
pkg/dbmanager/factory_test.go
Normal file
210
pkg/dbmanager/factory_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestNewConnectionFromDB(t *testing.T) {
|
||||
// Open a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a connection from the existing database
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
if conn == nil {
|
||||
t.Fatal("Expected connection to be created")
|
||||
}
|
||||
|
||||
// Verify connection properties
|
||||
if conn.Name() != "test-connection" {
|
||||
t.Errorf("Expected name 'test-connection', got '%s'", conn.Name())
|
||||
}
|
||||
|
||||
if conn.Type() != DatabaseTypeSQLite {
|
||||
t.Errorf("Expected type DatabaseTypeSQLite, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Connect(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect should verify the existing connection works
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected Connect to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Native(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Get native DB
|
||||
nativeDB, err := conn.Native()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Native to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if nativeDB != db {
|
||||
t.Error("Expected Native to return the same database instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Bun(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Get Bun ORM
|
||||
bunDB, err := conn.Bun()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Bun to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if bunDB == nil {
|
||||
t.Error("Expected Bun to return a non-nil instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_GORM(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Get GORM
|
||||
gormDB, err := conn.GORM()
|
||||
if err != nil {
|
||||
t.Errorf("Expected GORM to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if gormDB == nil {
|
||||
t.Error("Expected GORM to return a non-nil instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_HealthCheck(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Health check should succeed
|
||||
err = conn.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Stats(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
stats := conn.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if stats.Name != "test-connection" {
|
||||
t.Errorf("Expected stats.Name to be 'test-connection', got '%s'", stats.Name)
|
||||
}
|
||||
|
||||
if stats.Type != DatabaseTypeSQLite {
|
||||
t.Errorf("Expected stats.Type to be DatabaseTypeSQLite, got '%s'", stats.Type)
|
||||
}
|
||||
|
||||
if !stats.Connected {
|
||||
t.Error("Expected stats.Connected to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
||||
// This test just verifies the factory works with PostgreSQL type
|
||||
// It won't actually connect since we're using SQLite
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-pg", DatabaseTypePostgreSQL, db)
|
||||
if conn == nil {
|
||||
t.Fatal("Expected connection to be created")
|
||||
}
|
||||
|
||||
if conn.Type() != DatabaseTypePostgreSQL {
|
||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
@@ -219,9 +219,10 @@ func (m *connectionManager) Connect(ctx context.Context) error {
|
||||
logger.Info("Database connection established: name=%s, type=%s", name, connCfg.Type)
|
||||
}
|
||||
|
||||
// Start background health checks if enabled
|
||||
if m.config.EnableAutoReconnect && m.config.HealthCheckInterval > 0 {
|
||||
// Always start background health checks
|
||||
if m.config.HealthCheckInterval > 0 {
|
||||
m.startHealthChecker()
|
||||
logger.Info("Background health checker started: interval=%v", m.config.HealthCheckInterval)
|
||||
}
|
||||
|
||||
logger.Info("Database manager initialized: connections=%d", len(m.connections))
|
||||
@@ -230,12 +231,14 @@ func (m *connectionManager) Connect(ctx context.Context) error {
|
||||
|
||||
// Close closes all database connections
|
||||
func (m *connectionManager) Close() error {
|
||||
// Stop the health checker before taking mu. performHealthCheck acquires
|
||||
// a read lock, so waiting for the goroutine while holding the write lock
|
||||
// would deadlock.
|
||||
m.stopHealthChecker()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Stop health checker
|
||||
m.stopHealthChecker()
|
||||
|
||||
// Close all connections
|
||||
var errors []error
|
||||
for name, conn := range m.connections {
|
||||
|
||||
226
pkg/dbmanager/manager_test.go
Normal file
226
pkg/dbmanager/manager_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestBackgroundHealthChecker(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create manager config with a short health check interval for testing
|
||||
cfg := ManagerConfig{
|
||||
DefaultConnection: "test",
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
HealthCheckInterval: 1 * time.Second, // Short interval for testing
|
||||
EnableAutoReconnect: true,
|
||||
}
|
||||
|
||||
// Create manager
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// Connect - this should start the background health checker
|
||||
ctx := context.Background()
|
||||
err = mgr.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer mgr.Close()
|
||||
|
||||
// Get the connection to verify it's healthy
|
||||
conn, err := mgr.Get("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get connection: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial health check
|
||||
err = conn.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Initial health check failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for a few health check cycles
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Get stats to verify the connection is still healthy
|
||||
stats := conn.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if !stats.Connected {
|
||||
t.Error("Expected connection to still be connected")
|
||||
}
|
||||
|
||||
if stats.HealthCheckStatus == "" {
|
||||
t.Error("Expected health check status to be set")
|
||||
}
|
||||
|
||||
// Verify the manager has started the health checker
|
||||
if cm, ok := mgr.(*connectionManager); ok {
|
||||
if cm.healthTicker == nil {
|
||||
t.Error("Expected health ticker to be running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultHealthCheckInterval(t *testing.T) {
|
||||
// Verify the default health check interval is 15 seconds
|
||||
defaults := DefaultManagerConfig()
|
||||
|
||||
expectedInterval := 15 * time.Second
|
||||
if defaults.HealthCheckInterval != expectedInterval {
|
||||
t.Errorf("Expected default health check interval to be %v, got %v",
|
||||
expectedInterval, defaults.HealthCheckInterval)
|
||||
}
|
||||
|
||||
if !defaults.EnableAutoReconnect {
|
||||
t.Error("Expected EnableAutoReconnect to be true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDefaultsEnablesAutoReconnect(t *testing.T) {
|
||||
// Create a config without setting EnableAutoReconnect
|
||||
cfg := ManagerConfig{
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Verify it's false initially (Go's zero value for bool)
|
||||
if cfg.EnableAutoReconnect {
|
||||
t.Error("Expected EnableAutoReconnect to be false before ApplyDefaults")
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// Verify it's now true
|
||||
if !cfg.EnableAutoReconnect {
|
||||
t.Error("Expected EnableAutoReconnect to be true after ApplyDefaults")
|
||||
}
|
||||
|
||||
// Verify health check interval is also set
|
||||
if cfg.HealthCheckInterval != 15*time.Second {
|
||||
t.Errorf("Expected health check interval to be 15s, got %v", cfg.HealthCheckInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerHealthCheck(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create manager config
|
||||
cfg := ManagerConfig{
|
||||
DefaultConnection: "test",
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
HealthCheckInterval: 15 * time.Second,
|
||||
EnableAutoReconnect: true,
|
||||
}
|
||||
|
||||
// Create and connect manager
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err = mgr.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer mgr.Close()
|
||||
|
||||
// Perform health check on all connections
|
||||
err = mgr.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Health check failed: %v", err)
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats := mgr.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if stats.TotalConnections != 1 {
|
||||
t.Errorf("Expected 1 total connection, got %d", stats.TotalConnections)
|
||||
}
|
||||
|
||||
if stats.HealthyCount != 1 {
|
||||
t.Errorf("Expected 1 healthy connection, got %d", stats.HealthyCount)
|
||||
}
|
||||
|
||||
if stats.UnhealthyCount != 0 {
|
||||
t.Errorf("Expected 0 unhealthy connections, got %d", stats.UnhealthyCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerStatsAfterClose(t *testing.T) {
|
||||
cfg := ManagerConfig{
|
||||
DefaultConnection: "test",
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
HealthCheckInterval: 15 * time.Second,
|
||||
}
|
||||
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err = mgr.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
|
||||
// Close the manager
|
||||
err = mgr.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to close manager: %v", err)
|
||||
}
|
||||
|
||||
// Stats should show no connections
|
||||
stats := mgr.Stats()
|
||||
if stats.TotalConnections != 0 {
|
||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||
}
|
||||
}
|
||||
111
pkg/dbmanager/providers/existing_db.go
Normal file
111
pkg/dbmanager/providers/existing_db.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
// ExistingDBProvider wraps an existing *sql.DB connection
|
||||
// This allows using dbmanager features with a database connection
|
||||
// that was opened outside of the dbmanager package
|
||||
type ExistingDBProvider struct {
|
||||
db *sql.DB
|
||||
name string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewExistingDBProvider creates a new provider wrapping an existing *sql.DB
|
||||
func NewExistingDBProvider(db *sql.DB, name string) *ExistingDBProvider {
|
||||
return &ExistingDBProvider{
|
||||
db: db,
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect verifies the existing database connection is valid
|
||||
// It does NOT create a new connection, but ensures the existing one works
|
||||
func (p *ExistingDBProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
// Verify the connection works
|
||||
if err := p.db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("failed to ping existing database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (p *ExistingDBProvider) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
// HealthCheck verifies the connection is alive
|
||||
func (p *ExistingDBProvider) HealthCheck(ctx context.Context) error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
return p.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// GetNative returns the wrapped *sql.DB
|
||||
func (p *ExistingDBProvider) GetNative() (*sql.DB, error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
return p.db, nil
|
||||
}
|
||||
|
||||
// GetMongo returns an error since this is a SQL database
|
||||
func (p *ExistingDBProvider) GetMongo() (*mongo.Client, error) {
|
||||
return nil, ErrNotMongoDB
|
||||
}
|
||||
|
||||
// Stats returns connection statistics
|
||||
func (p *ExistingDBProvider) Stats() *ConnectionStats {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
stats := &ConnectionStats{
|
||||
Name: p.name,
|
||||
Type: "sql", // Generic since we don't know the specific type
|
||||
Connected: p.db != nil,
|
||||
}
|
||||
|
||||
if p.db != nil {
|
||||
dbStats := p.db.Stats()
|
||||
stats.OpenConnections = dbStats.OpenConnections
|
||||
stats.InUse = dbStats.InUse
|
||||
stats.Idle = dbStats.Idle
|
||||
stats.WaitCount = dbStats.WaitCount
|
||||
stats.WaitDuration = dbStats.WaitDuration
|
||||
stats.MaxIdleClosed = dbStats.MaxIdleClosed
|
||||
stats.MaxLifetimeClosed = dbStats.MaxLifetimeClosed
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
194
pkg/dbmanager/providers/existing_db_test.go
Normal file
194
pkg/dbmanager/providers/existing_db_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestNewExistingDBProvider(t *testing.T) {
|
||||
// Open a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create provider
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
if provider == nil {
|
||||
t.Fatal("Expected provider to be created")
|
||||
}
|
||||
|
||||
if provider.name != "test-db" {
|
||||
t.Errorf("Expected name 'test-db', got '%s'", provider.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Connect(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect should verify the connection works
|
||||
err = provider.Connect(ctx, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected Connect to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Connect_NilDB(t *testing.T) {
|
||||
provider := NewExistingDBProvider(nil, "test-db")
|
||||
ctx := context.Background()
|
||||
|
||||
err := provider.Connect(ctx, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected Connect to fail with nil database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_GetNative(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
nativeDB, err := provider.GetNative()
|
||||
if err != nil {
|
||||
t.Errorf("Expected GetNative to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if nativeDB != db {
|
||||
t.Error("Expected GetNative to return the same database instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_GetNative_NilDB(t *testing.T) {
|
||||
provider := NewExistingDBProvider(nil, "test-db")
|
||||
|
||||
_, err := provider.GetNative()
|
||||
if err == nil {
|
||||
t.Error("Expected GetNative to fail with nil database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_HealthCheck(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
ctx := context.Background()
|
||||
|
||||
err = provider.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_HealthCheck_ClosedDB(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
// Close the database
|
||||
db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
err = provider.HealthCheck(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected HealthCheck to fail with closed database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_GetMongo(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
_, err = provider.GetMongo()
|
||||
if err != ErrNotMongoDB {
|
||||
t.Errorf("Expected ErrNotMongoDB, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Stats(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Set some connection pool settings to test stats
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
stats := provider.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if stats.Name != "test-db" {
|
||||
t.Errorf("Expected stats.Name to be 'test-db', got '%s'", stats.Name)
|
||||
}
|
||||
|
||||
if stats.Type != "sql" {
|
||||
t.Errorf("Expected stats.Type to be 'sql', got '%s'", stats.Type)
|
||||
}
|
||||
|
||||
if !stats.Connected {
|
||||
t.Error("Expected stats.Connected to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Close(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
err = provider.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Close to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the database is closed
|
||||
err = db.Ping()
|
||||
if err == nil {
|
||||
t.Error("Expected database to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Close_NilDB(t *testing.T) {
|
||||
provider := NewExistingDBProvider(nil, "test-db")
|
||||
|
||||
err := provider.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Close to succeed with nil database, got error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -76,8 +76,12 @@ func (p *SQLiteProvider) Connect(ctx context.Context, cfg ConnectionConfig) erro
|
||||
// Don't fail connection if WAL mode cannot be enabled
|
||||
}
|
||||
|
||||
// Set busy timeout to handle locked database
|
||||
_, err = db.ExecContext(ctx, "PRAGMA busy_timeout=5000")
|
||||
// Set busy timeout to handle locked database (minimum 2 minutes = 120000ms)
|
||||
busyTimeout := cfg.GetQueryTimeout().Milliseconds()
|
||||
if busyTimeout < 120000 {
|
||||
busyTimeout = 120000 // Enforce minimum of 2 minutes
|
||||
}
|
||||
_, err = db.ExecContext(ctx, fmt.Sprintf("PRAGMA busy_timeout=%d", busyTimeout))
|
||||
if err != nil {
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to set busy timeout for SQLite", "error", err)
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
@@ -123,27 +125,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
||||
ComplexAPI: complexAPI,
|
||||
}
|
||||
|
||||
// Execute BeforeQueryList hook
|
||||
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
|
||||
logger.Error("BeforeQueryList hook failed: %v", err)
|
||||
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if hook aborted the operation
|
||||
if hookCtx.Abort {
|
||||
if hookCtx.AbortCode == 0 {
|
||||
hookCtx.AbortCode = http.StatusBadRequest
|
||||
}
|
||||
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified SQL query and variables from hooks
|
||||
sqlquery = hookCtx.SQLQuery
|
||||
variables = hookCtx.Variables
|
||||
// complexAPI = hookCtx.ComplexAPI
|
||||
|
||||
// Extract input variables from SQL query (placeholders like [variable])
|
||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||
|
||||
@@ -203,6 +184,27 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
||||
|
||||
// Execute query within transaction
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Set transaction in hook context for hooks to use
|
||||
hookCtx.Tx = tx
|
||||
|
||||
// Execute BeforeQueryList hook (inside transaction)
|
||||
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
|
||||
logger.Error("BeforeQueryList hook failed: %v", err)
|
||||
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if hook aborted the operation
|
||||
if hookCtx.Abort {
|
||||
if hookCtx.AbortCode == 0 {
|
||||
hookCtx.AbortCode = http.StatusBadRequest
|
||||
}
|
||||
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||
return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage)
|
||||
}
|
||||
|
||||
// Use potentially modified SQL query from hook
|
||||
sqlquery = hookCtx.SQLQuery
|
||||
sqlqueryCnt := sqlquery
|
||||
|
||||
// Parse sorting and pagination parameters
|
||||
@@ -286,6 +288,21 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
||||
}
|
||||
total = hookCtx.Total
|
||||
|
||||
// Execute AfterQueryList hook (inside transaction)
|
||||
hookCtx.Result = dbobjlist
|
||||
hookCtx.Total = total
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
|
||||
logger.Error("AfterQueryList hook failed: %v", err)
|
||||
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return err
|
||||
}
|
||||
// Use potentially modified result from hook
|
||||
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
||||
dbobjlist = modifiedResult
|
||||
}
|
||||
total = hookCtx.Total
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -294,21 +311,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
||||
return
|
||||
}
|
||||
|
||||
// Execute AfterQueryList hook
|
||||
hookCtx.Result = dbobjlist
|
||||
hookCtx.Total = total
|
||||
hookCtx.Error = err
|
||||
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
|
||||
logger.Error("AfterQueryList hook failed: %v", err)
|
||||
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
// Use potentially modified result from hook
|
||||
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
||||
dbobjlist = modifiedResult
|
||||
}
|
||||
total = hookCtx.Total
|
||||
|
||||
// Set response headers
|
||||
respOffset := 0
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
@@ -459,26 +461,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
||||
ComplexAPI: complexAPI,
|
||||
}
|
||||
|
||||
// Execute BeforeQuery hook
|
||||
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
|
||||
logger.Error("BeforeQuery hook failed: %v", err)
|
||||
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if hook aborted the operation
|
||||
if hookCtx.Abort {
|
||||
if hookCtx.AbortCode == 0 {
|
||||
hookCtx.AbortCode = http.StatusBadRequest
|
||||
}
|
||||
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified SQL query and variables from hooks
|
||||
sqlquery = hookCtx.SQLQuery
|
||||
variables = hookCtx.Variables
|
||||
|
||||
// Extract input variables from SQL query
|
||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||
|
||||
@@ -554,6 +536,28 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
||||
|
||||
// Execute query within transaction
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Set transaction in hook context for hooks to use
|
||||
hookCtx.Tx = tx
|
||||
|
||||
// Execute BeforeQuery hook (inside transaction)
|
||||
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
|
||||
logger.Error("BeforeQuery hook failed: %v", err)
|
||||
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if hook aborted the operation
|
||||
if hookCtx.Abort {
|
||||
if hookCtx.AbortCode == 0 {
|
||||
hookCtx.AbortCode = http.StatusBadRequest
|
||||
}
|
||||
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||
return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage)
|
||||
}
|
||||
|
||||
// Use potentially modified SQL query from hook
|
||||
sqlquery = hookCtx.SQLQuery
|
||||
|
||||
// Execute BeforeSQLExec hook
|
||||
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
|
||||
logger.Error("BeforeSQLExec hook failed: %v", err)
|
||||
@@ -586,6 +590,19 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
||||
dbobj = modifiedResult
|
||||
}
|
||||
|
||||
// Execute AfterQuery hook (inside transaction)
|
||||
hookCtx.Result = dbobj
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
|
||||
logger.Error("AfterQuery hook failed: %v", err)
|
||||
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return err
|
||||
}
|
||||
// Use potentially modified result from hook
|
||||
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
||||
dbobj = modifiedResult
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -594,19 +611,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
||||
return
|
||||
}
|
||||
|
||||
// Execute AfterQuery hook
|
||||
hookCtx.Result = dbobj
|
||||
hookCtx.Error = err
|
||||
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
|
||||
logger.Error("AfterQuery hook failed: %v", err)
|
||||
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
// Use potentially modified result from hook
|
||||
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
||||
dbobj = modifiedResult
|
||||
}
|
||||
|
||||
// Execute BeforeResponse hook
|
||||
hookCtx.Result = dbobj
|
||||
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
|
||||
@@ -1097,9 +1101,25 @@ func normalizePostgresValue(value interface{}) interface{} {
|
||||
case map[string]interface{}:
|
||||
// Recursively normalize nested maps
|
||||
return normalizePostgresTypes(v)
|
||||
|
||||
case string:
|
||||
var jsonObj interface{}
|
||||
if err := json.Unmarshal([]byte(v), &jsonObj); err == nil {
|
||||
// It's valid JSON, return as json.RawMessage so it's not double-encoded
|
||||
return json.RawMessage(v)
|
||||
}
|
||||
return v
|
||||
case uuid.UUID:
|
||||
return v.String()
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339)
|
||||
case bool, int, int8, int16, int32, int64, float32, float64, uint, uint8, uint16, uint32, uint64:
|
||||
return v
|
||||
default:
|
||||
// For other types (int, float, string, bool, etc.), return as-is
|
||||
// For other types (int, float, bool, etc.), return as-is
|
||||
// Check stringers
|
||||
if str, ok := v.(fmt.Stringer); ok {
|
||||
return str.String()
|
||||
}
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +74,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MockDatabase) DriverName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
// MockResult implements common.Result interface for testing
|
||||
type MockResult struct {
|
||||
rows int64
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
@@ -46,6 +47,10 @@ type HookContext struct {
|
||||
// User context
|
||||
UserContext *security.UserContext
|
||||
|
||||
// Tx provides access to the database/transaction for executing additional SQL
|
||||
// This allows hooks to run custom queries in addition to the main Query chain
|
||||
Tx common.Database
|
||||
|
||||
// Pagination and filtering (for list queries)
|
||||
SortColumns string
|
||||
Limit int
|
||||
|
||||
@@ -645,11 +645,14 @@ func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
|
||||
// Database operation helpers (adapted from websocketspec)
|
||||
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
// Use entity as table name
|
||||
tableName := entity
|
||||
|
||||
if schema != "" {
|
||||
tableName = schema + "." + tableName
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
tableName = schema + "_" + tableName
|
||||
} else {
|
||||
tableName = schema + "." + tableName
|
||||
}
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package reflection
|
||||
|
||||
import "reflect"
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Len(v any) int {
|
||||
val := reflect.ValueOf(v)
|
||||
@@ -47,3 +50,58 @@ func ExtractTableNameOnly(fullName string) string {
|
||||
|
||||
return fullName[startIndex:]
|
||||
}
|
||||
|
||||
// GetPointerElement returns the element type if the provided reflect.Type is a pointer.
|
||||
// If the type is a slice of pointers, it returns the element type of the pointer within the slice.
|
||||
// If neither condition is met, it returns the original type.
|
||||
func GetPointerElement(v reflect.Type) reflect.Type {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
return v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr {
|
||||
subElem := v.Elem()
|
||||
if subElem.Elem().Kind() == reflect.Ptr {
|
||||
return subElem.Elem().Elem()
|
||||
}
|
||||
return v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GetJSONNameForField gets the JSON tag name for a struct field.
|
||||
// Returns the JSON field name from the json struct tag, or an empty string if not found.
|
||||
// Handles the "json" tag format: "name", "name,omitempty", etc.
|
||||
func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
|
||||
if modelType == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle pointer types
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the field
|
||||
field, found := modelType.FieldByName(fieldName)
|
||||
if !found {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Get the JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse the tag (format: "name,omitempty" or just "name")
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -584,11 +584,23 @@ func ExtractSourceColumn(colName string) string {
|
||||
}
|
||||
|
||||
// ToSnakeCase converts a string from CamelCase to snake_case
|
||||
// Handles consecutive uppercase letters (acronyms) correctly:
|
||||
// "HTTPServer" -> "http_server", "UserID" -> "user_id", "MyHTTPServer" -> "my_http_server"
|
||||
func ToSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
runes := []rune(s)
|
||||
|
||||
for i, r := range runes {
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
result.WriteRune('_')
|
||||
// Add underscore if:
|
||||
// 1. Previous character is lowercase, OR
|
||||
// 2. Next character is lowercase (transition from acronym to word)
|
||||
prevIsLower := runes[i-1] >= 'a' && runes[i-1] <= 'z'
|
||||
nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z'
|
||||
|
||||
if prevIsLower || nextIsLower {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
@@ -936,32 +948,38 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
||||
// Build list of possible column names for this field
|
||||
var columnNames []string
|
||||
|
||||
// 1. Bun tag
|
||||
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Gorm tag
|
||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. JSON tag
|
||||
// 1. JSON tag (primary - most common)
|
||||
jsonFound := false
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
columnNames = append(columnNames, parts[0])
|
||||
jsonFound = true
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Field name variations
|
||||
// 2. Bun tag (fallback if no JSON tag)
|
||||
if !jsonFound {
|
||||
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Gorm tag (fallback if no JSON tag)
|
||||
if !jsonFound {
|
||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Field name variations (last resort)
|
||||
columnNames = append(columnNames, field.Name)
|
||||
columnNames = append(columnNames, strings.ToLower(field.Name))
|
||||
columnNames = append(columnNames, ToSnakeCase(field.Name))
|
||||
// columnNames = append(columnNames, ToSnakeCase(field.Name))
|
||||
|
||||
// Map all column name variations to this field index
|
||||
for _, colName := range columnNames {
|
||||
@@ -1067,7 +1085,7 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
case string:
|
||||
field.SetBytes([]byte(v))
|
||||
return nil
|
||||
case map[string]interface{}, []interface{}:
|
||||
case map[string]interface{}, []interface{}, []*any, map[string]*any:
|
||||
// Marshal complex types to JSON for SqlJSONB fields
|
||||
jsonBytes, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
@@ -1077,6 +1095,17 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle slice-to-slice conversions (e.g., []interface{} to []*SomeModel)
|
||||
if valueReflect.Kind() == reflect.Slice {
|
||||
return convertSlice(field, valueReflect)
|
||||
}
|
||||
}
|
||||
|
||||
// If we can convert the type, do it
|
||||
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||
field.Set(valueReflect.Convert(field.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
|
||||
@@ -1090,9 +1119,9 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
// Call the Scan method with the value
|
||||
results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)})
|
||||
if len(results) > 0 {
|
||||
// Check if there was an error
|
||||
if err, ok := results[0].Interface().(error); ok && err != nil {
|
||||
return err
|
||||
// The Scan method returns error - check if it's nil
|
||||
if !results[0].IsNil() {
|
||||
return results[0].Interface().(error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1147,13 +1176,93 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
|
||||
}
|
||||
|
||||
// If we can convert the type, do it
|
||||
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||
field.Set(valueReflect.Convert(field.Type()))
|
||||
return nil
|
||||
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
||||
}
|
||||
|
||||
// convertSlice converts a source slice to a target slice type, handling element-wise conversions
|
||||
// Supports converting []interface{} to slices of structs or pointers to structs
|
||||
func convertSlice(targetSlice reflect.Value, sourceSlice reflect.Value) error {
|
||||
if sourceSlice.Kind() != reflect.Slice || targetSlice.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("both source and target must be slices")
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
||||
// Get the element type of the target slice
|
||||
targetElemType := targetSlice.Type().Elem()
|
||||
sourceLen := sourceSlice.Len()
|
||||
|
||||
// Create a new slice with the same length as the source
|
||||
newSlice := reflect.MakeSlice(targetSlice.Type(), sourceLen, sourceLen)
|
||||
|
||||
// Convert each element
|
||||
for i := 0; i < sourceLen; i++ {
|
||||
sourceElem := sourceSlice.Index(i)
|
||||
targetElem := newSlice.Index(i)
|
||||
|
||||
// Get the actual value from the source element
|
||||
var sourceValue interface{}
|
||||
if sourceElem.CanInterface() {
|
||||
sourceValue = sourceElem.Interface()
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle nil elements
|
||||
if sourceValue == nil {
|
||||
// For pointer types, nil is valid
|
||||
if targetElemType.Kind() == reflect.Ptr {
|
||||
targetElem.Set(reflect.Zero(targetElemType))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// If target element type is a pointer to struct, we need to create new instances
|
||||
if targetElemType.Kind() == reflect.Ptr {
|
||||
// Create a new instance of the pointed-to type
|
||||
newElemPtr := reflect.New(targetElemType.Elem())
|
||||
|
||||
// Convert the source value to the struct
|
||||
switch sv := sourceValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Source is a map, use MapToStruct to populate the new instance
|
||||
if err := MapToStruct(sv, newElemPtr.Interface()); err != nil {
|
||||
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||
}
|
||||
default:
|
||||
// Try direct conversion or setFieldValue
|
||||
if err := setFieldValue(newElemPtr.Elem(), sourceValue); err != nil {
|
||||
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
targetElem.Set(newElemPtr)
|
||||
} else if targetElemType.Kind() == reflect.Struct {
|
||||
// Target element is a struct (not a pointer)
|
||||
switch sv := sourceValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Use MapToStruct to populate the element
|
||||
elemPtr := targetElem.Addr()
|
||||
if elemPtr.CanInterface() {
|
||||
if err := MapToStruct(sv, elemPtr.Interface()); err != nil {
|
||||
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Try direct conversion
|
||||
if err := setFieldValue(targetElem, sourceValue); err != nil {
|
||||
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For other types, use setFieldValue
|
||||
if err := setFieldValue(targetElem, sourceValue); err != nil {
|
||||
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the converted slice to the target field
|
||||
targetSlice.Set(newSlice)
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertToInt64 attempts to convert various types to int64
|
||||
@@ -1261,6 +1370,63 @@ func convertToFloat64(value interface{}) (float64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetValidJSONFieldNames returns a map of valid JSON field names for a model
|
||||
// This can be used to validate input data against a model's structure
|
||||
// The map keys are the JSON field names (from json tags) that exist in the model
|
||||
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
||||
validFields := make(map[string]bool)
|
||||
|
||||
// Unwrap pointers to get to the base struct type
|
||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return validFields
|
||||
}
|
||||
|
||||
collectValidFieldNames(modelType, validFields)
|
||||
return validFields
|
||||
}
|
||||
|
||||
// collectValidFieldNames recursively collects valid JSON field names from a struct type
|
||||
func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for embedded structs
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
// Recursively add fields from embedded struct
|
||||
collectValidFieldNames(fieldType, validFields)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get the JSON tag name for this field (same logic as MapToStruct)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract the field name from the JSON tag (before any options like omitempty)
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
validFields[parts[0]] = true
|
||||
}
|
||||
} else {
|
||||
// If no JSON tag, use the field name in lowercase as a fallback
|
||||
validFields[strings.ToLower(field.Name)] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||
|
||||
120
pkg/reflection/model_utils_stdlib_sqltypes_test.go
Normal file
120
pkg/reflection/model_utils_stdlib_sqltypes_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package reflection_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
func TestMapToStruct_StandardSqlNullTypes(t *testing.T) {
|
||||
// Test model with standard library sql.Null* types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Age sql.NullInt64 `bun:"age" json:"age"`
|
||||
Name sql.NullString `bun:"name" json:"name"`
|
||||
Score sql.NullFloat64 `bun:"score" json:"score"`
|
||||
Active sql.NullBool `bun:"active" json:"active"`
|
||||
UpdatedAt sql.NullTime `bun:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
dataMap := map[string]any{
|
||||
"id": int64(100),
|
||||
"age": int64(25),
|
||||
"name": "John Doe",
|
||||
"score": 95.5,
|
||||
"active": true,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify ID
|
||||
if result.ID != 100 {
|
||||
t.Errorf("ID = %v, want 100", result.ID)
|
||||
}
|
||||
|
||||
// Verify Age (sql.NullInt64)
|
||||
if !result.Age.Valid {
|
||||
t.Error("Age.Valid = false, want true")
|
||||
}
|
||||
if result.Age.Int64 != 25 {
|
||||
t.Errorf("Age.Int64 = %v, want 25", result.Age.Int64)
|
||||
}
|
||||
|
||||
// Verify Name (sql.NullString)
|
||||
if !result.Name.Valid {
|
||||
t.Error("Name.Valid = false, want true")
|
||||
}
|
||||
if result.Name.String != "John Doe" {
|
||||
t.Errorf("Name.String = %v, want 'John Doe'", result.Name.String)
|
||||
}
|
||||
|
||||
// Verify Score (sql.NullFloat64)
|
||||
if !result.Score.Valid {
|
||||
t.Error("Score.Valid = false, want true")
|
||||
}
|
||||
if result.Score.Float64 != 95.5 {
|
||||
t.Errorf("Score.Float64 = %v, want 95.5", result.Score.Float64)
|
||||
}
|
||||
|
||||
// Verify Active (sql.NullBool)
|
||||
if !result.Active.Valid {
|
||||
t.Error("Active.Valid = false, want true")
|
||||
}
|
||||
if !result.Active.Bool {
|
||||
t.Error("Active.Bool = false, want true")
|
||||
}
|
||||
|
||||
// Verify UpdatedAt (sql.NullTime)
|
||||
if !result.UpdatedAt.Valid {
|
||||
t.Error("UpdatedAt.Valid = false, want true")
|
||||
}
|
||||
if !result.UpdatedAt.Time.Equal(now) {
|
||||
t.Errorf("UpdatedAt.Time = %v, want %v", result.UpdatedAt.Time, now)
|
||||
}
|
||||
|
||||
t.Log("All standard library sql.Null* types handled correctly!")
|
||||
}
|
||||
|
||||
func TestMapToStruct_StandardSqlNullTypes_WithNil(t *testing.T) {
|
||||
// Test nil handling for standard library sql.Null* types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Age sql.NullInt64 `bun:"age" json:"age"`
|
||||
Name sql.NullString `bun:"name" json:"name"`
|
||||
}
|
||||
|
||||
dataMap := map[string]any{
|
||||
"id": int64(200),
|
||||
"age": int64(30),
|
||||
"name": nil, // Explicitly nil
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
// Age should be valid
|
||||
if !result.Age.Valid {
|
||||
t.Error("Age.Valid = false, want true")
|
||||
}
|
||||
if result.Age.Int64 != 30 {
|
||||
t.Errorf("Age.Int64 = %v, want 30", result.Age.Int64)
|
||||
}
|
||||
|
||||
// Name should be invalid (null)
|
||||
if result.Name.Valid {
|
||||
t.Error("Name.Valid = true, want false (null)")
|
||||
}
|
||||
|
||||
t.Log("Nil handling for sql.Null* types works correctly!")
|
||||
}
|
||||
364
pkg/reflection/spectypes_integration_test.go
Normal file
364
pkg/reflection/spectypes_integration_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestModel contains all spectypes custom types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Name spectypes.SqlString `bun:"name" json:"name"`
|
||||
Age spectypes.SqlInt64 `bun:"age" json:"age"`
|
||||
Score spectypes.SqlFloat64 `bun:"score" json:"score"`
|
||||
Active spectypes.SqlBool `bun:"active" json:"active"`
|
||||
UUID spectypes.SqlUUID `bun:"uuid" json:"uuid"`
|
||||
CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"`
|
||||
BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"`
|
||||
StartTime spectypes.SqlTime `bun:"start_time" json:"start_time"`
|
||||
Metadata spectypes.SqlJSONB `bun:"metadata" json:"metadata"`
|
||||
Count16 spectypes.SqlInt16 `bun:"count16" json:"count16"`
|
||||
Count32 spectypes.SqlInt32 `bun:"count32" json:"count32"`
|
||||
}
|
||||
|
||||
// TestMapToStruct_AllSpectypes verifies that MapToStruct can convert all spectypes correctly
|
||||
func TestMapToStruct_AllSpectypes(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
testTime := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dataMap map[string]interface{}
|
||||
validator func(*testing.T, *TestModel)
|
||||
}{
|
||||
{
|
||||
name: "SqlString from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Name.Valid || m.Name.String() != "John Doe" {
|
||||
t.Errorf("expected name='John Doe', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt64 from int64",
|
||||
dataMap: map[string]interface{}{
|
||||
"age": int64(42),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Age.Valid || m.Age.Int64() != 42 {
|
||||
t.Errorf("expected age=42, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt64 from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"age": "99",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Age.Valid || m.Age.Int64() != 99 {
|
||||
t.Errorf("expected age=99, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlFloat64 from float64",
|
||||
dataMap: map[string]interface{}{
|
||||
"score": float64(98.5),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Score.Valid || m.Score.Float64() != 98.5 {
|
||||
t.Errorf("expected score=98.5, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlBool from bool",
|
||||
dataMap: map[string]interface{}{
|
||||
"active": true,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Active.Valid || !m.Active.Bool() {
|
||||
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlUUID from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"uuid": testUUID.String(),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.UUID.Valid || m.UUID.UUID() != testUUID {
|
||||
t.Errorf("expected uuid=%s, got valid=%v, value=%s", testUUID.String(), m.UUID.Valid, m.UUID.UUID().String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlTimeStamp from time.Time",
|
||||
dataMap: map[string]interface{}{
|
||||
"created_at": testTime,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.CreatedAt.Valid {
|
||||
t.Errorf("expected created_at to be valid")
|
||||
}
|
||||
// Check if times are close enough (within a second)
|
||||
diff := m.CreatedAt.Time().Sub(testTime)
|
||||
if diff < -time.Second || diff > time.Second {
|
||||
t.Errorf("time difference too large: %v", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlTimeStamp from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"created_at": "2024-01-15T10:30:00",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.CreatedAt.Valid {
|
||||
t.Errorf("expected created_at to be valid")
|
||||
}
|
||||
expected := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||
if m.CreatedAt.Time().Year() != expected.Year() ||
|
||||
m.CreatedAt.Time().Month() != expected.Month() ||
|
||||
m.CreatedAt.Time().Day() != expected.Day() {
|
||||
t.Errorf("expected date 2024-01-15, got %v", m.CreatedAt.Time())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlDate from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"birth_date": "2000-05-20",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.BirthDate.Valid {
|
||||
t.Errorf("expected birth_date to be valid")
|
||||
}
|
||||
expected := "2000-05-20"
|
||||
if m.BirthDate.String() != expected {
|
||||
t.Errorf("expected date=%s, got %s", expected, m.BirthDate.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlTime from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"start_time": "14:30:00",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.StartTime.Valid {
|
||||
t.Errorf("expected start_time to be valid")
|
||||
}
|
||||
if m.StartTime.String() != "14:30:00" {
|
||||
t.Errorf("expected time=14:30:00, got %s", m.StartTime.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlJSONB from map",
|
||||
dataMap: map[string]interface{}{
|
||||
"metadata": map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 123,
|
||||
},
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Errorf("expected metadata to have data")
|
||||
}
|
||||
asMap, err := m.Metadata.AsMap()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert metadata to map: %v", err)
|
||||
}
|
||||
if asMap["key1"] != "value1" {
|
||||
t.Errorf("expected key1=value1, got %v", asMap["key1"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlJSONB from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"metadata": `{"test":"data"}`,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Errorf("expected metadata to have data")
|
||||
}
|
||||
asMap, err := m.Metadata.AsMap()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert metadata to map: %v", err)
|
||||
}
|
||||
if asMap["test"] != "data" {
|
||||
t.Errorf("expected test=data, got %v", asMap["test"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlJSONB from []byte",
|
||||
dataMap: map[string]interface{}{
|
||||
"metadata": []byte(`{"byte":"array"}`),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Errorf("expected metadata to have data")
|
||||
}
|
||||
if string(m.Metadata) != `{"byte":"array"}` {
|
||||
t.Errorf("expected {\"byte\":\"array\"}, got %s", string(m.Metadata))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt16 from int16",
|
||||
dataMap: map[string]interface{}{
|
||||
"count16": int16(100),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Count16.Valid || m.Count16.Int64() != 100 {
|
||||
t.Errorf("expected count16=100, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt32 from int32",
|
||||
dataMap: map[string]interface{}{
|
||||
"count32": int32(5000),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Count32.Valid || m.Count32.Int64() != 5000 {
|
||||
t.Errorf("expected count32=5000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil values create invalid nulls",
|
||||
dataMap: map[string]interface{}{
|
||||
"name": nil,
|
||||
"age": nil,
|
||||
"active": nil,
|
||||
"created_at": nil,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if m.Name.Valid {
|
||||
t.Error("expected name to be invalid for nil value")
|
||||
}
|
||||
if m.Age.Valid {
|
||||
t.Error("expected age to be invalid for nil value")
|
||||
}
|
||||
if m.Active.Valid {
|
||||
t.Error("expected active to be invalid for nil value")
|
||||
}
|
||||
if m.CreatedAt.Valid {
|
||||
t.Error("expected created_at to be invalid for nil value")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all types together",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(1),
|
||||
"name": "Test User",
|
||||
"age": int64(30),
|
||||
"score": float64(95.7),
|
||||
"active": true,
|
||||
"uuid": testUUID.String(),
|
||||
"created_at": "2024-01-15T10:30:00",
|
||||
"birth_date": "1994-06-15",
|
||||
"start_time": "09:00:00",
|
||||
"metadata": map[string]interface{}{"role": "admin"},
|
||||
"count16": int16(50),
|
||||
"count32": int32(1000),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if m.ID != 1 {
|
||||
t.Errorf("expected id=1, got %d", m.ID)
|
||||
}
|
||||
if !m.Name.Valid || m.Name.String() != "Test User" {
|
||||
t.Errorf("expected name='Test User', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
|
||||
}
|
||||
if !m.Age.Valid || m.Age.Int64() != 30 {
|
||||
t.Errorf("expected age=30, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||
}
|
||||
if !m.Score.Valid || m.Score.Float64() != 95.7 {
|
||||
t.Errorf("expected score=95.7, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
|
||||
}
|
||||
if !m.Active.Valid || !m.Active.Bool() {
|
||||
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
|
||||
}
|
||||
if !m.UUID.Valid {
|
||||
t.Error("expected uuid to be valid")
|
||||
}
|
||||
if !m.CreatedAt.Valid {
|
||||
t.Error("expected created_at to be valid")
|
||||
}
|
||||
if !m.BirthDate.Valid || m.BirthDate.String() != "1994-06-15" {
|
||||
t.Errorf("expected birth_date=1994-06-15, got valid=%v, value=%s", m.BirthDate.Valid, m.BirthDate.String())
|
||||
}
|
||||
if !m.StartTime.Valid || m.StartTime.String() != "09:00:00" {
|
||||
t.Errorf("expected start_time=09:00:00, got valid=%v, value=%s", m.StartTime.Valid, m.StartTime.String())
|
||||
}
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Error("expected metadata to have data")
|
||||
}
|
||||
if !m.Count16.Valid || m.Count16.Int64() != 50 {
|
||||
t.Errorf("expected count16=50, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
|
||||
}
|
||||
if !m.Count32.Valid || m.Count32.Int64() != 1000 {
|
||||
t.Errorf("expected count32=1000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := &TestModel{}
|
||||
if err := MapToStruct(tt.dataMap, model); err != nil {
|
||||
t.Fatalf("MapToStruct failed: %v", err)
|
||||
}
|
||||
tt.validator(t, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMapToStruct_PartialUpdate tests that partial updates preserve unset fields
|
||||
func TestMapToStruct_PartialUpdate(t *testing.T) {
|
||||
// Create initial model with some values
|
||||
initial := &TestModel{
|
||||
ID: 1,
|
||||
Name: spectypes.NewSqlString("Original Name"),
|
||||
Age: spectypes.NewSqlInt64(25),
|
||||
}
|
||||
|
||||
// Update only the age field
|
||||
partialUpdate := map[string]interface{}{
|
||||
"age": int64(30),
|
||||
}
|
||||
|
||||
// Apply partial update
|
||||
if err := MapToStruct(partialUpdate, initial); err != nil {
|
||||
t.Fatalf("MapToStruct failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify age was updated
|
||||
if !initial.Age.Valid || initial.Age.Int64() != 30 {
|
||||
t.Errorf("expected age=30, got valid=%v, value=%d", initial.Age.Valid, initial.Age.Int64())
|
||||
}
|
||||
|
||||
// Verify name was preserved (not overwritten with zero value)
|
||||
if !initial.Name.Valid || initial.Name.String() != "Original Name" {
|
||||
t.Errorf("expected name='Original Name' to be preserved, got valid=%v, value=%s", initial.Name.Valid, initial.Name.String())
|
||||
}
|
||||
|
||||
// Verify ID was preserved
|
||||
if initial.ID != 1 {
|
||||
t.Errorf("expected id=1 to be preserved, got %d", initial.ID)
|
||||
}
|
||||
}
|
||||
572
pkg/resolvespec/EXAMPLES.md
Normal file
572
pkg/resolvespec/EXAMPLES.md
Normal file
@@ -0,0 +1,572 @@
|
||||
# ResolveSpec Query Features Examples
|
||||
|
||||
This document provides examples of using the advanced query features in ResolveSpec, including OR logic filters, Custom Operators, and FetchRowNumber.
|
||||
|
||||
## OR Logic in Filters (SearchOr)
|
||||
|
||||
### Basic OR Filter Example
|
||||
|
||||
Find all users with status "active" OR "pending":
|
||||
|
||||
```json
|
||||
POST /users
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Combined AND/OR Filters
|
||||
|
||||
Find users with (status="active" OR status="pending") AND age >= 18:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "age",
|
||||
"operator": "gte",
|
||||
"value": 18
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
|
||||
|
||||
**Important Notes:**
|
||||
- By default, filters use AND logic
|
||||
- Consecutive filters with `"logic_operator": "OR"` are automatically grouped with parentheses
|
||||
- This grouping ensures OR conditions don't interfere with AND conditions
|
||||
- You don't need to specify `"logic_operator": "AND"` as it's the default
|
||||
|
||||
### Multiple OR Groups
|
||||
|
||||
You can have multiple separate OR groups:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "priority",
|
||||
"operator": "eq",
|
||||
"value": "high"
|
||||
},
|
||||
{
|
||||
"column": "priority",
|
||||
"operator": "eq",
|
||||
"value": "urgent",
|
||||
"logic_operator": "OR"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND (priority = 'high' OR priority = 'urgent')`
|
||||
|
||||
## Custom Operators
|
||||
|
||||
### Simple Custom SQL Condition
|
||||
|
||||
Filter by email domain using custom SQL:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "company_emails",
|
||||
"sql": "email LIKE '%@company.com'"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Multiple Custom Operators
|
||||
|
||||
Combine multiple custom SQL conditions:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "recent_active",
|
||||
"sql": "last_login > NOW() - INTERVAL '30 days'"
|
||||
},
|
||||
{
|
||||
"name": "high_score",
|
||||
"sql": "score > 1000"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Complex Custom Operator
|
||||
|
||||
Use complex SQL expressions:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "priority_users",
|
||||
"sql": "(subscription = 'premium' AND points > 500) OR (subscription = 'enterprise')"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Combining Custom Operators with Regular Filters
|
||||
|
||||
Mix custom operators with standard filters:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "country",
|
||||
"operator": "eq",
|
||||
"value": "USA"
|
||||
}
|
||||
],
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "active_last_month",
|
||||
"sql": "last_activity > NOW() - INTERVAL '1 month'"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Row Numbers
|
||||
|
||||
### Two Ways to Get Row Numbers
|
||||
|
||||
There are two different features for row numbers:
|
||||
|
||||
1. **`fetch_row_number`** - Get the position of ONE specific record in a sorted/filtered set
|
||||
2. **`RowNumber` field in models** - Automatically number all records in the response
|
||||
|
||||
### 1. FetchRowNumber - Get Position of Specific Record
|
||||
|
||||
Get the rank/position of a specific user in a leaderboard. **Important:** When `fetch_row_number` is specified, the response contains **ONLY that specific record**, not all records.
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response - Contains ONLY the specified user:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"id": 12345,
|
||||
"name": "Alice Smith",
|
||||
"score": 9850,
|
||||
"level": 42
|
||||
},
|
||||
"metadata": {
|
||||
"total": 10000,
|
||||
"count": 1,
|
||||
"filtered": 10000,
|
||||
"row_number": 42
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Result:** User "12345" is ranked #42 out of 10,000 users. The response includes only Alice's data, not the other 9,999 users.
|
||||
|
||||
### Row Number with Filters
|
||||
|
||||
Find position within a filtered subset (e.g., "What's my rank in my country?"):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "country",
|
||||
"operator": "eq",
|
||||
"value": "USA"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"id": 12345,
|
||||
"name": "Bob Johnson",
|
||||
"country": "USA",
|
||||
"score": 7200,
|
||||
"status": "active"
|
||||
},
|
||||
"metadata": {
|
||||
"total": 2500,
|
||||
"count": 1,
|
||||
"filtered": 2500,
|
||||
"row_number": 156
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Result:** Bob is ranked #156 out of 2,500 active USA users. Only Bob's record is returned.
|
||||
|
||||
### 2. RowNumber Field - Auto-Number All Records
|
||||
|
||||
If your model has a `RowNumber int64` field, restheadspec will automatically populate it for paginated results.
|
||||
|
||||
**Model Definition:**
|
||||
```go
|
||||
type Player struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Score int64 `json:"score"`
|
||||
RowNumber int64 `json:"row_number"` // Will be auto-populated
|
||||
}
|
||||
```
|
||||
|
||||
**Request (with pagination):**
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"sort": [{"column": "score", "direction": "desc"}],
|
||||
"limit": 10,
|
||||
"offset": 20
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response - RowNumber automatically set:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [
|
||||
{
|
||||
"id": 456,
|
||||
"name": "Player21",
|
||||
"score": 8900,
|
||||
"row_number": 21
|
||||
},
|
||||
{
|
||||
"id": 789,
|
||||
"name": "Player22",
|
||||
"score": 8850,
|
||||
"row_number": 22
|
||||
},
|
||||
{
|
||||
"id": 123,
|
||||
"name": "Player23",
|
||||
"score": 8800,
|
||||
"row_number": 23
|
||||
}
|
||||
// ... records 24-30 ...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**How It Works:**
|
||||
- `row_number = offset + index + 1` (1-based)
|
||||
- With offset=20, first record gets row_number=21
|
||||
- With offset=20, second record gets row_number=22
|
||||
- Perfect for displaying "Rank" in paginated tables
|
||||
|
||||
**Use Case:** Displaying leaderboards with rank numbers:
|
||||
```
|
||||
Rank | Player | Score
|
||||
-----|-----------|-------
|
||||
21 | Player21 | 8900
|
||||
22 | Player22 | 8850
|
||||
23 | Player23 | 8800
|
||||
```
|
||||
|
||||
**Note:** This feature is available in all three packages: resolvespec, restheadspec, and websocketspec.
|
||||
|
||||
### When to Use Each Feature
|
||||
|
||||
| Feature | Use Case | Returns | Performance |
|
||||
|---------|----------|---------|-------------|
|
||||
| `fetch_row_number` | "What's my rank?" | 1 record with position | Fast - 1 record |
|
||||
| `RowNumber` field | "Show top 10 with ranks" | Many records numbered | Fast - simple math |
|
||||
|
||||
**Combined Example - Full Leaderboard UI:**
|
||||
|
||||
```javascript
|
||||
// Request 1: Get current user's rank
|
||||
const userRank = await api.read({
|
||||
fetch_row_number: currentUserId,
|
||||
sort: [{column: "score", direction: "desc"}]
|
||||
});
|
||||
// Returns: {id: 123, name: "You", score: 7500, row_number: 156}
|
||||
|
||||
// Request 2: Get top 10 with rank numbers
|
||||
const top10 = await api.read({
|
||||
sort: [{column: "score", direction: "desc"}],
|
||||
limit: 10,
|
||||
offset: 0
|
||||
});
|
||||
// Returns: [{row_number: 1, ...}, {row_number: 2, ...}, ...]
|
||||
|
||||
// Display:
|
||||
// "Your Rank: #156"
|
||||
// "Top Players:"
|
||||
// "#1 - Alice - 9999"
|
||||
// "#2 - Bob - 9876"
|
||||
// ...
|
||||
```
|
||||
|
||||
## Complete Example: Advanced Query
|
||||
|
||||
Combine all features for a complex query:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"columns": ["id", "name", "email", "score", "status"],
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "trial",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "score",
|
||||
"operator": "gte",
|
||||
"value": 100
|
||||
}
|
||||
],
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "recent_activity",
|
||||
"sql": "last_login > NOW() - INTERVAL '7 days'"
|
||||
},
|
||||
{
|
||||
"name": "verified_email",
|
||||
"sql": "email_verified = true"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
},
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "asc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345",
|
||||
"limit": 50,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This query:
|
||||
- Selects specific columns
|
||||
- Filters for users with status "active" OR "trial"
|
||||
- AND score >= 100
|
||||
- Applies custom SQL conditions for recent activity and verified emails
|
||||
- Sorts by score (descending) then creation date (ascending)
|
||||
- Returns the row number of user "12345" in this filtered/sorted set
|
||||
- Returns 50 records starting from the first one
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Leaderboards - Get Current User's Rank
|
||||
|
||||
Get the current user's position and data (returns only their record):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "game_id",
|
||||
"operator": "eq",
|
||||
"value": "game123"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "current_user_id"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Tip:** For full leaderboards, make two requests:
|
||||
1. One with `fetch_row_number` to get user's rank
|
||||
2. One with `limit` and `offset` to get top players list
|
||||
|
||||
### 2. Multi-Status Search
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "pending"
|
||||
},
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "processing",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "shipped",
|
||||
"logic_operator": "OR"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Advanced Date Filtering
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "this_month",
|
||||
"sql": "created_at >= DATE_TRUNC('month', CURRENT_DATE)"
|
||||
},
|
||||
{
|
||||
"name": "business_hours",
|
||||
"sql": "EXTRACT(HOUR FROM created_at) BETWEEN 9 AND 17"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
**Warning:** Custom operators allow raw SQL, which can be a security risk if not properly handled:
|
||||
|
||||
1. **Never** directly interpolate user input into custom operator SQL
|
||||
2. Always validate and sanitize custom operator SQL on the backend
|
||||
3. Consider using a whitelist of allowed custom operators
|
||||
4. Use prepared statements or parameterized queries when possible
|
||||
5. Implement proper authorization checks before executing queries
|
||||
|
||||
Example of safe custom operator handling in Go:
|
||||
|
||||
```go
|
||||
// Whitelist of allowed custom operators
|
||||
allowedOperators := map[string]string{
|
||||
"recent_week": "created_at > NOW() - INTERVAL '7 days'",
|
||||
"active_users": "status = 'active' AND last_login > NOW() - INTERVAL '30 days'",
|
||||
"premium_only": "subscription_level = 'premium'",
|
||||
}
|
||||
|
||||
// Validate custom operators from request
|
||||
for _, op := range req.Options.CustomOperators {
|
||||
if sql, ok := allowedOperators[op.Name]; ok {
|
||||
op.SQL = sql // Use whitelisted SQL
|
||||
} else {
|
||||
return errors.New("custom operator not allowed: " + op.Name)
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -214,6 +214,146 @@ Content-Type: application/json
|
||||
|
||||
```json
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "age",
|
||||
"operator": "gte",
|
||||
"value": 18
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Produces: `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
|
||||
|
||||
This grouping ensures OR conditions don't interfere with other AND conditions in the query.
|
||||
|
||||
### Custom Operators
|
||||
|
||||
Add custom SQL conditions when needed:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "email_domain_filter",
|
||||
"sql": "LOWER(email) LIKE '%@example.com'"
|
||||
},
|
||||
{
|
||||
"name": "recent_records",
|
||||
"sql": "created_at > NOW() - INTERVAL '7 days'"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Custom operators are applied as additional WHERE conditions to your query.
|
||||
|
||||
### Fetch Row Number
|
||||
|
||||
Get the row number (position) of a specific record in the filtered and sorted result set. **When `fetch_row_number` is specified, only that specific record is returned** (not all records).
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response - Returns ONLY the specified record with its position:**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"id": 12345,
|
||||
"name": "John Doe",
|
||||
"score": 850,
|
||||
"status": "active"
|
||||
},
|
||||
"metadata": {
|
||||
"total": 1000,
|
||||
"count": 1,
|
||||
"filtered": 1000,
|
||||
"row_number": 42
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Use Case:** Perfect for "Show me this user and their ranking" - you get just that one user with their position in the leaderboard.
|
||||
|
||||
**Note:** This is different from the `RowNumber` field feature, which automatically numbers all records in a paginated response based on offset. That feature uses simple math (`offset + index + 1`), while `fetch_row_number` uses SQL window functions to calculate the actual position in a sorted/filtered set. To use the `RowNumber` field feature, simply add a `RowNumber int64` field to your model - it will be automatically populated with the row position based on pagination.
|
||||
|
||||
## Preloading
|
||||
|
||||
Load related entities with custom configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"columns": ["id", "name", "email"],
|
||||
"preload": [
|
||||
{
|
||||
"relation": "posts",
|
||||
"columns": ["id", "title", "created_at"],
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "published"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"limit": 5
|
||||
},
|
||||
{
|
||||
"relation": "profile",
|
||||
"columns": ["bio", "website"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
Efficient pagination for large datasets:
|
||||
|
||||
### First Request (No Cursor)
|
||||
|
||||
```json
|
||||
@@ -427,7 +567,7 @@ Define virtual columns using SQL expressions:
|
||||
// Check permissions
|
||||
if !userHasPermission(ctx.Context, ctx.Entity) {
|
||||
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modify query options
|
||||
if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
|
||||
@@ -435,17 +575,24 @@ Add custom SQL conditions when needed:
|
||||
}
|
||||
|
||||
return nil
|
||||
users[i].Email = maskEmail(users[i].Email)
|
||||
}
|
||||
})
|
||||
|
||||
// Register an after-read hook (e.g., for data transformation)
|
||||
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
|
||||
})
|
||||
// Transform or filter results
|
||||
if users, ok := ctx.Result.([]User); ok {
|
||||
for i := range users {
|
||||
users[i].Email = maskEmail(users[i].Email)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register a before-create hook (e.g., for validation)
|
||||
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
|
||||
// Validate data
|
||||
if user, ok := ctx.Data.(*User); ok {
|
||||
if user.Email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
// Add timestamps
|
||||
|
||||
143
pkg/resolvespec/filter_test.go
Normal file
143
pkg/resolvespec/filter_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// TestBuildFilterCondition tests the filter condition builder
|
||||
func TestBuildFilterCondition(t *testing.T) {
|
||||
h := &Handler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filter common.FilterOption
|
||||
expectedCondition string
|
||||
expectedArgsCount int
|
||||
}{
|
||||
{
|
||||
name: "Equal operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "status",
|
||||
Operator: "eq",
|
||||
Value: "active",
|
||||
},
|
||||
expectedCondition: "status = ?",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Greater than operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "age",
|
||||
Operator: "gt",
|
||||
Value: 18,
|
||||
},
|
||||
expectedCondition: "age > ?",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
{
|
||||
name: "IN operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "status",
|
||||
Operator: "in",
|
||||
Value: []string{"active", "pending"},
|
||||
},
|
||||
expectedCondition: "status IN (?)",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
{
|
||||
name: "LIKE operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "email",
|
||||
Operator: "like",
|
||||
Value: "%@example.com",
|
||||
},
|
||||
expectedCondition: "email LIKE ?",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
condition, args := h.buildFilterCondition(tt.filter)
|
||||
|
||||
if condition != tt.expectedCondition {
|
||||
t.Errorf("Expected condition '%s', got '%s'", tt.expectedCondition, condition)
|
||||
}
|
||||
|
||||
if len(args) != tt.expectedArgsCount {
|
||||
t.Errorf("Expected %d args, got %d", tt.expectedArgsCount, len(args))
|
||||
}
|
||||
|
||||
// Note: Skip value comparison for slices as they can't be compared with ==
|
||||
// The important part is that args are populated correctly
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestORGrouping tests that consecutive OR filters are properly grouped
|
||||
func TestORGrouping(t *testing.T) {
|
||||
// This is a conceptual test - in practice we'd need a mock SelectQuery
|
||||
// to verify the actual SQL grouping behavior
|
||||
t.Run("Consecutive OR filters should be grouped", func(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
|
||||
{Column: "status", Operator: "eq", Value: "trial", LogicOperator: "OR"},
|
||||
{Column: "age", Operator: "gte", Value: 18},
|
||||
}
|
||||
|
||||
// Expected behavior: (status='active' OR status='pending' OR status='trial') AND age>=18
|
||||
// The first three filters should be grouped together
|
||||
// The fourth filter should be separate with AND
|
||||
|
||||
// Count OR groups
|
||||
orGroupCount := 0
|
||||
inORGroup := false
|
||||
|
||||
for i := 1; i < len(filters); i++ {
|
||||
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
|
||||
orGroupCount++
|
||||
inORGroup = true
|
||||
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
|
||||
inORGroup = false
|
||||
}
|
||||
}
|
||||
|
||||
// We should have detected one OR group
|
||||
if orGroupCount != 1 {
|
||||
t.Errorf("Expected 1 OR group, detected %d", orGroupCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple OR groups should be handled correctly", func(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
|
||||
{Column: "priority", Operator: "eq", Value: "high"},
|
||||
{Column: "priority", Operator: "eq", Value: "urgent", LogicOperator: "OR"},
|
||||
}
|
||||
|
||||
// Expected: (status='active' OR status='pending') AND (priority='high' OR priority='urgent')
|
||||
// Should have two OR groups
|
||||
|
||||
orGroupCount := 0
|
||||
inORGroup := false
|
||||
|
||||
for i := 1; i < len(filters); i++ {
|
||||
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
|
||||
orGroupCount++
|
||||
inORGroup = true
|
||||
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
|
||||
inORGroup = false
|
||||
}
|
||||
}
|
||||
|
||||
// We should have detected two OR groups
|
||||
if orGroupCount != 2 {
|
||||
t.Errorf("Expected 2 OR groups, detected %d", orGroupCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -280,10 +280,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
for _, filter := range options.Filters {
|
||||
logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
|
||||
query = h.applyFilter(query, filter)
|
||||
// Apply filters with proper grouping for OR logic
|
||||
query = h.applyFilters(query, options.Filters)
|
||||
|
||||
// Apply custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
logger.Debug("Applying custom operator: %s - %s", customOp.Name, customOp.SQL)
|
||||
query = query.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
@@ -318,6 +321,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor)
|
||||
if sanitizedCursor != "" {
|
||||
query = query.Where(sanitizedCursor)
|
||||
}
|
||||
@@ -379,24 +384,105 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
}
|
||||
|
||||
// Apply pagination
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
logger.Debug("Applying limit: %d", *options.Limit)
|
||||
query = query.Limit(*options.Limit)
|
||||
// Handle FetchRowNumber if requested
|
||||
var rowNumber *int64
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
logger.Debug("Fetching row number for ID: %s", *options.FetchRowNumber)
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Build ROW_NUMBER window function SQL
|
||||
rowNumberSQL := "ROW_NUMBER() OVER ("
|
||||
if len(options.Sort) > 0 {
|
||||
rowNumberSQL += "ORDER BY "
|
||||
for i, sort := range options.Sort {
|
||||
if i > 0 {
|
||||
rowNumberSQL += ", "
|
||||
}
|
||||
direction := "ASC"
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
rowNumberSQL += fmt.Sprintf("%s %s", sort.Column, direction)
|
||||
}
|
||||
}
|
||||
rowNumberSQL += ")"
|
||||
|
||||
// Create a query to fetch the row number using a subquery approach
|
||||
// We'll select the PK and row_number, then filter by the target ID
|
||||
type RowNumResult struct {
|
||||
RowNum int64 `bun:"row_num"`
|
||||
}
|
||||
|
||||
rowNumQuery := h.db.NewSelect().Table(tableName).
|
||||
ColumnExpr(fmt.Sprintf("%s AS row_num", rowNumberSQL)).
|
||||
Column(pkName)
|
||||
|
||||
// Apply the same filters as the main query
|
||||
for _, filter := range options.Filters {
|
||||
rowNumQuery = h.applyFilter(rowNumQuery, filter)
|
||||
}
|
||||
|
||||
// Apply custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
rowNumQuery = rowNumQuery.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Filter for the specific ID we want the row number for
|
||||
rowNumQuery = rowNumQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *options.FetchRowNumber)
|
||||
|
||||
// Execute query to get row number
|
||||
var result RowNumResult
|
||||
if err := rowNumQuery.Scan(ctx, &result); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
// Build filter description for error message
|
||||
filterInfo := fmt.Sprintf("filters: %d", len(options.Filters))
|
||||
if len(options.CustomOperators) > 0 {
|
||||
customOps := make([]string, 0, len(options.CustomOperators))
|
||||
for _, op := range options.CustomOperators {
|
||||
customOps = append(customOps, op.SQL)
|
||||
}
|
||||
filterInfo += fmt.Sprintf(", custom operators: [%s]", strings.Join(customOps, "; "))
|
||||
}
|
||||
logger.Warn("No row found for primary key %s=%s with %s", pkName, *options.FetchRowNumber, filterInfo)
|
||||
} else {
|
||||
logger.Warn("Error fetching row number: %v", err)
|
||||
}
|
||||
} else {
|
||||
rowNumber = &result.RowNum
|
||||
logger.Debug("Found row number: %d", *rowNumber)
|
||||
}
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
logger.Debug("Applying offset: %d", *options.Offset)
|
||||
query = query.Offset(*options.Offset)
|
||||
|
||||
// Apply pagination (skip if FetchRowNumber is set - we want only that record)
|
||||
if options.FetchRowNumber == nil || *options.FetchRowNumber == "" {
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
logger.Debug("Applying limit: %d", *options.Limit)
|
||||
query = query.Limit(*options.Limit)
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
logger.Debug("Applying offset: %d", *options.Offset)
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
var result interface{}
|
||||
if id != "" {
|
||||
logger.Debug("Querying single record with ID: %s", id)
|
||||
if id != "" || (options.FetchRowNumber != nil && *options.FetchRowNumber != "") {
|
||||
// Single record query - either by URL ID or FetchRowNumber
|
||||
var targetID string
|
||||
if id != "" {
|
||||
targetID = id
|
||||
logger.Debug("Querying single record with URL ID: %s", id)
|
||||
} else {
|
||||
targetID = *options.FetchRowNumber
|
||||
logger.Debug("Querying single record with FetchRowNumber ID: %s", targetID)
|
||||
}
|
||||
|
||||
// For single record, create a new pointer to the struct type
|
||||
singleResult := reflect.New(modelType).Interface()
|
||||
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := query.Scan(ctx, singleResult); err != nil {
|
||||
logger.Error("Error querying record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||
@@ -416,20 +502,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
|
||||
logger.Info("Successfully retrieved records")
|
||||
|
||||
// Build metadata
|
||||
limit := 0
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
offset := 0
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
count := int64(total)
|
||||
|
||||
// When FetchRowNumber is used, we only return 1 record
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
count = 1
|
||||
// Set the fetched row number on the record
|
||||
if rowNumber != nil {
|
||||
logger.Debug("FetchRowNumber: Setting row number %d on record", *rowNumber)
|
||||
h.setRowNumbersOnRecords(result, int(*rowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||
}
|
||||
} else {
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
}
|
||||
|
||||
// Set row numbers on records if RowNumber field exists
|
||||
// Only for multiple records (not when fetching single record)
|
||||
h.setRowNumbersOnRecords(result, offset)
|
||||
}
|
||||
|
||||
h.sendResponse(w, result, &common.Metadata{
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Count: count,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
RowNumber: rowNumber,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -698,37 +803,133 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
}
|
||||
|
||||
// Standard processing without nested relations
|
||||
query := h.db.NewUpdate().Table(tableName).SetMap(updates)
|
||||
// Get the primary key name
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Apply conditions
|
||||
if urlID != "" {
|
||||
logger.Debug("Updating by URL ID: %s", urlID)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
logger.Debug("Updating by request ID: %s", id)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||
case []string:
|
||||
logger.Debug("Updating by multiple IDs: %v", id)
|
||||
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||
// Wrap in transaction to ensure BeforeUpdate hook is inside transaction
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// First, read the existing record from the database
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
|
||||
|
||||
// Apply conditions to select
|
||||
if urlID != "" {
|
||||
logger.Debug("Updating by URL ID: %s", urlID)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
logger.Debug("Updating by request ID: %s", id)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
if len(id) > 0 {
|
||||
logger.Debug("Updating by multiple IDs: %v", id)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
return fmt.Errorf("error fetching existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert existing record to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||
}
|
||||
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: urlID,
|
||||
Data: updates,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
updates = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values from the incoming request into the existing record
|
||||
for key, newValue := range updates {
|
||||
// Skip if the value is nil
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if the value is an empty string
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update the existing map with the new value
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
// Build update query with merged data
|
||||
query := tx.NewUpdate().Table(tableName).SetMap(existingMap)
|
||||
|
||||
// Apply conditions
|
||||
if urlID != "" {
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
}
|
||||
}
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating record(s): %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
|
||||
// Execute AfterUpdate hooks inside transaction
|
||||
hookCtx.Result = updates
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("AfterUpdate hook failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Update error: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
|
||||
if err.Error() == "no records found to update" {
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", err)
|
||||
} else {
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
logger.Warn("No records found to update")
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated %d records", result.RowsAffected())
|
||||
logger.Info("Successfully updated record(s)")
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
@@ -782,14 +983,77 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
}
|
||||
|
||||
// Standard batch update without nested relations
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range updates {
|
||||
if itemID, ok := item["id"]; ok {
|
||||
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
}
|
||||
return fmt.Errorf("failed to fetch existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert existing record to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||
}
|
||||
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: itemIDStr,
|
||||
Data: item,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
item = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values
|
||||
for key, newValue := range item {
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute AfterUpdate hooks inside transaction
|
||||
hookCtx.Result = item
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -857,16 +1121,80 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
}
|
||||
|
||||
// Standard batch update without nested relations
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
list := make([]interface{}, 0)
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range updates {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if itemID, ok := itemMap["id"]; ok {
|
||||
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
}
|
||||
return fmt.Errorf("failed to fetch existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert existing record to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||
}
|
||||
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: itemIDStr,
|
||||
Data: itemMap,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
itemMap = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values
|
||||
for key, newValue := range itemMap {
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute AfterUpdate hooks inside transaction
|
||||
hookCtx.Result = itemMap
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
list = append(list, item)
|
||||
}
|
||||
}
|
||||
@@ -1078,29 +1406,161 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
h.sendResponse(w, recordToDelete, nil)
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
||||
// applyFilters applies all filters with proper grouping for OR logic
|
||||
// Groups consecutive OR filters together to ensure proper query precedence
|
||||
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
// Check if this starts an OR group (current or next filter has OR logic)
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
// Collect all consecutive filters that are OR'd together
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
|
||||
// Apply the OR group as a single grouped WHERE clause
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
// Single filter with AND logic (or first filter)
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// applyFilterGroup applies a group of filters that should be OR'd together
|
||||
// Always wraps them in parentheses and applies as a single WHERE clause
|
||||
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Build all conditions and collect args
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, filter := range filters {
|
||||
condition, filterArgs := h.buildFilterCondition(filter)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Single filter - no need for grouping
|
||||
if len(conditions) == 1 {
|
||||
return query.Where(conditions[0], args...)
|
||||
}
|
||||
|
||||
// Multiple conditions - group with parentheses and OR
|
||||
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
||||
return query.Where(groupedCondition, args...)
|
||||
}
|
||||
|
||||
// buildFilterCondition builds a filter condition and returns it with args
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
||||
var condition string
|
||||
var args []interface{}
|
||||
|
||||
switch filter.Operator {
|
||||
case "eq":
|
||||
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "neq":
|
||||
return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gt":
|
||||
return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gte":
|
||||
return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lt":
|
||||
return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lte":
|
||||
return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return query.Where(fmt.Sprintf("%s ILIKE ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return condition, args
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
||||
// Determine which method to use based on LogicOperator
|
||||
useOrLogic := strings.EqualFold(filter.LogicOperator, "OR")
|
||||
|
||||
var condition string
|
||||
var args []interface{}
|
||||
|
||||
switch filter.Operator {
|
||||
case "eq":
|
||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "neq":
|
||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gt":
|
||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gte":
|
||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lt":
|
||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lte":
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
default:
|
||||
return query
|
||||
}
|
||||
|
||||
// Apply filter with appropriate logic operator
|
||||
if useOrLogic {
|
||||
return query.WhereOr(condition, args...)
|
||||
}
|
||||
return query.Where(condition, args...)
|
||||
}
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
@@ -1155,10 +1615,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
|
||||
return schema, entity
|
||||
}
|
||||
|
||||
// getTableName returns the full table name including schema (schema.table)
|
||||
// getTableName returns the full table name including schema.
|
||||
// For most drivers the result is "schema.table". For SQLite, which does not
|
||||
// support schema-qualified names, the schema and table are joined with an
|
||||
// underscore: "schema_table".
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
@@ -1328,30 +1794,7 @@ func isNullable(field reflect.StructField) bool {
|
||||
|
||||
// GetRelationshipInfo implements common.RelationshipInfoProvider interface
|
||||
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
||||
info := h.getRelationshipInfo(modelType, relationName)
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
// Convert internal type to common type
|
||||
return &common.RelationshipInfo{
|
||||
FieldName: info.fieldName,
|
||||
JSONName: info.jsonName,
|
||||
RelationType: info.relationType,
|
||||
ForeignKey: info.foreignKey,
|
||||
References: info.references,
|
||||
JoinTable: info.joinTable,
|
||||
RelatedModel: info.relatedModel,
|
||||
}
|
||||
}
|
||||
|
||||
type relationshipInfo struct {
|
||||
fieldName string
|
||||
jsonName string
|
||||
relationType string // "belongsTo", "hasMany", "hasOne", "many2many"
|
||||
foreignKey string
|
||||
references string
|
||||
joinTable string
|
||||
relatedModel interface{}
|
||||
return common.GetRelationshipInfo(modelType, relationName)
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
@@ -1371,7 +1814,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
for idx := range preloads {
|
||||
preload := preloads[idx]
|
||||
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
||||
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
||||
relInfo := common.GetRelationshipInfo(modelType, preload.Relation)
|
||||
if relInfo == nil {
|
||||
logger.Warn("Relation %s not found in model", preload.Relation)
|
||||
continue
|
||||
@@ -1379,7 +1822,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
|
||||
// Use the field name (capitalized) for ORM preloading
|
||||
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
||||
relationFieldName := relInfo.fieldName
|
||||
relationFieldName := relInfo.FieldName
|
||||
|
||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||
if len(preload.Where) > 0 {
|
||||
@@ -1422,13 +1865,13 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
copy(columns, preload.Columns)
|
||||
|
||||
// Add foreign key if not already present
|
||||
if relInfo.foreignKey != "" {
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id)
|
||||
foreignKeyColumn := toSnakeCase(relInfo.foreignKey)
|
||||
foreignKeyColumn := toSnakeCase(relInfo.ForeignKey)
|
||||
|
||||
hasForeignKey := false
|
||||
for _, col := range columns {
|
||||
if col == foreignKeyColumn || col == relInfo.foreignKey {
|
||||
if col == foreignKeyColumn || col == relInfo.ForeignKey {
|
||||
hasForeignKey = true
|
||||
break
|
||||
}
|
||||
@@ -1456,6 +1899,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||
preloadOpts := &common.RequestOptions{Preload: preloads}
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -1474,58 +1919,6 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
||||
// Ensure we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
jsonName := strings.Split(jsonTag, ",")[0]
|
||||
|
||||
if jsonName == relationName {
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
info := &relationshipInfo{
|
||||
fieldName: field.Name,
|
||||
jsonName: jsonName,
|
||||
}
|
||||
|
||||
// Parse GORM tag to determine relationship type and keys
|
||||
if strings.Contains(gormTag, "foreignKey") {
|
||||
info.foreignKey = h.extractTagValue(gormTag, "foreignKey")
|
||||
info.references = h.extractTagValue(gormTag, "references")
|
||||
|
||||
// Determine if it's belongsTo or hasMany/hasOne
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
info.relationType = "hasMany"
|
||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||
info.relationType = "belongsTo"
|
||||
}
|
||||
} else if strings.Contains(gormTag, "many2many") {
|
||||
info.relationType = "many2many"
|
||||
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) extractTagValue(tag, key string) string {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, key+":") {
|
||||
return strings.TrimPrefix(part, key+":")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// toSnakeCase converts a PascalCase or camelCase string to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
@@ -1551,6 +1944,51 @@ func toSnakeCase(s string) string {
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||
// The row number is calculated as offset + index + 1 (1-based)
|
||||
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||
// Get the reflect value of the records
|
||||
recordsValue := reflect.ValueOf(records)
|
||||
if recordsValue.Kind() == reflect.Ptr {
|
||||
recordsValue = recordsValue.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a slice
|
||||
if recordsValue.Kind() != reflect.Slice {
|
||||
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
// Iterate through each record
|
||||
for i := 0; i < recordsValue.Len(); i++ {
|
||||
record := recordsValue.Index(i)
|
||||
|
||||
// Dereference if it's a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.IsNil() {
|
||||
continue
|
||||
}
|
||||
record = record.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a struct
|
||||
if record.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to find and set the RowNumber field
|
||||
rowNumberField := record.FieldByName("RowNumber")
|
||||
if rowNumberField.IsValid() && rowNumberField.CanSet() {
|
||||
// Check if the field is of type int64
|
||||
if rowNumberField.Kind() == reflect.Int64 {
|
||||
rowNum := int64(offset + i + 1)
|
||||
rowNumberField.SetInt(rowNum)
|
||||
logger.Debug("Set RowNumber=%d for record index %d", rowNum, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandleOpenAPI generates and returns the OpenAPI specification
|
||||
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
||||
if h.openAPIGenerator == nil {
|
||||
|
||||
@@ -269,8 +269,6 @@ func TestToSnakeCase(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExtractTagValue(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
@@ -311,9 +309,9 @@ func TestExtractTagValue(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.extractTagValue(tt.tag, tt.key)
|
||||
result := common.ExtractTagValue(tt.tag, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected)
|
||||
t.Errorf("ExtractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -50,8 +50,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
})
|
||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||
@@ -98,7 +99,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -106,7 +108,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -117,7 +119,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -125,7 +128,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -137,13 +140,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
corsConfig.AllowedMethods = allowedMethods
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
// Return metadata in the OPTIONS response body
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
vars["entity"] = entity
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -212,9 +216,30 @@ type BunRouterHandler interface {
|
||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||
}
|
||||
|
||||
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||
if authMiddleware == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = handler(w, req)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
wrappedHandler := authMiddleware(httpHandler)
|
||||
wrappedHandler.ServeHTTP(w, req.Request)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||
// Accepts bunrouter.Router or bunrouter.Group
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
@@ -222,15 +247,16 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// Add global /openapi route
|
||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -251,85 +277,97 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
currentEntity := entity
|
||||
|
||||
// POST route without ID
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||
|
||||
// POST route with ID
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// GET route without ID
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||
|
||||
// GET route with ID
|
||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// OPTIONS route without ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// OPTIONS route with ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
@@ -344,8 +382,8 @@ func ExampleWithBunRouter(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes with bunrouter
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup ResolveSpec routes with bunrouter without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
@@ -366,8 +404,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup ResolveSpec routes without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// This gives you the full uptrace stack: bunrouter + Bun ORM
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
@@ -385,8 +423,87 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||
apiGroup := bunRouter.NewGroup("/api")
|
||||
|
||||
// Setup ResolveSpec routes on the group - routes will be under /api
|
||||
SetupBunRouterRoutes(apiGroup, handler)
|
||||
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
}
|
||||
|
||||
// ExampleWithGORMAndAuth shows how to use ResolveSpec with GORM and authentication
|
||||
func ExampleWithGORMAndAuth(db *gorm.DB) {
|
||||
// Create handler using GORM
|
||||
_ = NewHandlerWithGORM(db)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Setup router with authentication
|
||||
_ = mux.NewRouter()
|
||||
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||
|
||||
// Register models
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", muxRouter)
|
||||
}
|
||||
|
||||
// ExampleWithBunAndAuth shows how to use ResolveSpec with Bun and authentication
|
||||
func ExampleWithBunAndAuth(bunDB *bun.DB) {
|
||||
// Create Bun adapter
|
||||
dbAdapter := database.NewBunAdapter(bunDB)
|
||||
|
||||
// Create model registry
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Create handler
|
||||
_ = NewHandler(dbAdapter, registry)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Setup routes with authentication
|
||||
_ = mux.NewRouter()
|
||||
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", muxRouter)
|
||||
}
|
||||
|
||||
// ExampleBunRouterWithBunDBAndAuth shows the full uptrace stack with authentication
|
||||
func ExampleBunRouterWithBunDBAndAuth(bunDB *bun.DB) {
|
||||
// Create Bun database adapter
|
||||
dbAdapter := database.NewBunAdapter(bunDB)
|
||||
|
||||
// Create model registry
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Create handler with Bun
|
||||
_ = NewHandler(dbAdapter, registry)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Create bunrouter
|
||||
_ = bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes with authentication
|
||||
// SetupBunRouterRoutes(bunRouter, handler, authMiddleware)
|
||||
|
||||
// This gives you the full uptrace stack: bunrouter + Bun ORM with authentication
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
}
|
||||
|
||||
@@ -214,14 +214,46 @@ x-expand: department:id,name,code
|
||||
**Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation.
|
||||
|
||||
#### `x-custom-sql-join`
|
||||
Raw SQL JOIN statement.
|
||||
Custom SQL JOIN clauses for joining tables in queries.
|
||||
|
||||
**Format:** SQL JOIN clause
|
||||
**Format:** SQL JOIN clause or multiple clauses separated by `|`
|
||||
|
||||
**Single JOIN:**
|
||||
```
|
||||
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
|
||||
```
|
||||
|
||||
⚠️ **Note:** Not yet fully implemented.
|
||||
**Multiple JOINs:**
|
||||
```
|
||||
x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Supports any type of JOIN (INNER, LEFT, RIGHT, FULL, CROSS)
|
||||
- Multiple JOINs can be specified using the pipe `|` separator
|
||||
- JOINs are sanitized for security
|
||||
- Can be specified via headers or query parameters
|
||||
- **Table aliases are automatically extracted and allowed for filtering and sorting**
|
||||
|
||||
**Using Join Aliases in Filters and Sorts:**
|
||||
|
||||
When you specify a custom SQL join with an alias, you can use that alias in your filter and sort parameters:
|
||||
|
||||
```
|
||||
# Join with alias
|
||||
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
|
||||
|
||||
# Sort by joined table column
|
||||
x-sort: d.name,employees.id
|
||||
|
||||
# Filter by joined table column
|
||||
x-searchop-eq-d.name: Engineering
|
||||
```
|
||||
|
||||
The system automatically:
|
||||
1. Extracts the alias from the JOIN clause (e.g., `d` from `departments d`)
|
||||
2. Validates that prefixed columns (like `d.name`) refer to valid join aliases
|
||||
3. Allows these prefixed columns in filters and sorts
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ type queryCacheKey struct {
|
||||
Sort []common.SortOption `json:"sort"`
|
||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||
CustomSQLJoin []string `json:"custom_sql_join,omitempty"`
|
||||
Expand []expandOptionKey `json:"expand,omitempty"`
|
||||
Distinct bool `json:"distinct,omitempty"`
|
||||
CursorForward string `json:"cursor_forward,omitempty"`
|
||||
@@ -40,7 +41,7 @@ type cachedTotal struct {
|
||||
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||
// Includes expand, distinct, and cursor pagination options
|
||||
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||
customWhere, customOr string, customJoin []string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||
|
||||
key := queryCacheKey{
|
||||
TableName: tableName,
|
||||
@@ -48,6 +49,7 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
||||
Sort: sort,
|
||||
CustomSQLWhere: customWhere,
|
||||
CustomSQLOr: customOr,
|
||||
CustomSQLJoin: customJoin,
|
||||
Distinct: distinct,
|
||||
CursorForward: cursorFwd,
|
||||
CursorBackward: cursorBwd,
|
||||
@@ -75,8 +77,8 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
||||
jsonData, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
// Fallback to simple string concatenation if JSON fails
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
|
||||
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%v_%s_%s",
|
||||
tableName, filters, sort, customWhere, customOr, customJoin, expandOpts, distinct, cursorFwd, cursorBwd))
|
||||
}
|
||||
|
||||
return hashString(string(jsonData))
|
||||
|
||||
@@ -435,9 +435,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// Apply preloading
|
||||
logger.Debug("Total preloads to apply: %d", len(options.Preload))
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
logger.Debug("Applying preload: %s", preload.Relation)
|
||||
logger.Debug("Applying preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, Where=%s",
|
||||
idx, preload.Relation, preload.Recursive, preload.RelatedKey, preload.Where)
|
||||
|
||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||
if len(preload.Where) > 0 {
|
||||
@@ -463,7 +465,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// Apply filters - validate and adjust for column types first
|
||||
for i := range options.Filters {
|
||||
// Group consecutive OR filters together to prevent OR logic from escaping
|
||||
for i := 0; i < len(options.Filters); {
|
||||
filter := &options.Filters[i]
|
||||
|
||||
// Validate and adjust filter based on column type
|
||||
@@ -475,8 +478,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
logicOp = "AND"
|
||||
}
|
||||
|
||||
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
|
||||
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
|
||||
// Check if this is the start of an OR group
|
||||
if logicOp == "OR" {
|
||||
// Collect all consecutive OR filters
|
||||
orFilters := []*common.FilterOption{filter}
|
||||
orCastInfo := []ColumnCastInfo{castInfo}
|
||||
|
||||
j := i + 1
|
||||
for j < len(options.Filters) {
|
||||
nextFilter := &options.Filters[j]
|
||||
nextLogicOp := nextFilter.LogicOperator
|
||||
if nextLogicOp == "" {
|
||||
nextLogicOp = "AND"
|
||||
}
|
||||
if nextLogicOp == "OR" {
|
||||
nextCastInfo := h.ValidateAndAdjustFilterForColumnType(nextFilter, model)
|
||||
orFilters = append(orFilters, nextFilter)
|
||||
orCastInfo = append(orCastInfo, nextCastInfo)
|
||||
j++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the OR group as a single grouped condition
|
||||
logger.Debug("Applying OR filter group with %d conditions", len(orFilters))
|
||||
query = h.applyOrFilterGroup(query, orFilters, orCastInfo, tableName)
|
||||
i = j
|
||||
} else {
|
||||
// Single AND filter - apply normally
|
||||
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
|
||||
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL WHERE clause (AND condition)
|
||||
@@ -486,6 +520,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
|
||||
if sanitizedWhere != "" {
|
||||
query = query.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -497,13 +533,46 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedOr = common.EnsureOuterParentheses(sanitizedOr)
|
||||
if sanitizedOr != "" {
|
||||
query = query.WhereOr(sanitizedOr)
|
||||
}
|
||||
}
|
||||
|
||||
// If ID is provided, filter by ID
|
||||
if id != "" {
|
||||
// Apply custom SQL JOIN clauses
|
||||
if len(options.CustomSQLJoin) > 0 {
|
||||
for _, joinClause := range options.CustomSQLJoin {
|
||||
logger.Debug("Applying custom SQL JOIN: %s", joinClause)
|
||||
// Joins are already sanitized during parsing, so we can apply them directly
|
||||
query = query.Join(joinClause)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle FetchRowNumber before applying ID filter
|
||||
// This must happen before the query to get the row position, then filter by PK
|
||||
var fetchedRowNumber *int64
|
||||
var fetchRowNumberPKValue string
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
fetchRowNumberPKValue = *options.FetchRowNumber
|
||||
|
||||
logger.Debug("FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
|
||||
|
||||
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, fetchRowNumberPKValue, options, model)
|
||||
if err != nil {
|
||||
logger.Error("Failed to fetch row number: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "fetch_rownumber_error", "Failed to fetch row number", err)
|
||||
return
|
||||
}
|
||||
|
||||
fetchedRowNumber = &rowNum
|
||||
logger.Debug("FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
|
||||
|
||||
// Now filter the main query to this specific primary key
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), fetchRowNumberPKValue)
|
||||
} else if id != "" {
|
||||
// If ID is provided (and not FetchRowNumber), filter by ID
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
logger.Debug("Filtering by ID=%s: %s", pkName, id)
|
||||
|
||||
@@ -552,6 +621,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
options.Sort,
|
||||
options.CustomSQLWhere,
|
||||
options.CustomSQLOr,
|
||||
options.CustomSQLJoin,
|
||||
expandOpts,
|
||||
options.Distinct,
|
||||
options.CursorForward,
|
||||
@@ -682,7 +752,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// Set row numbers on each record if the model has a RowNumber field
|
||||
h.setRowNumbersOnRecords(modelPtr, offset)
|
||||
// If FetchRowNumber was used, set the fetched row number instead of offset-based
|
||||
if fetchedRowNumber != nil {
|
||||
// FetchRowNumber: set the actual row position on the record
|
||||
logger.Debug("FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
|
||||
h.setRowNumbersOnRecords(modelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||
} else {
|
||||
h.setRowNumbersOnRecords(modelPtr, offset)
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
@@ -692,21 +769,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// Fetch row number for a specific record if requested
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
pkValue := *options.FetchRowNumber
|
||||
|
||||
logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)
|
||||
|
||||
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to fetch row number: %v", err)
|
||||
// Don't fail the entire request, just log the warning
|
||||
} else {
|
||||
metadata.RowNumber = &rowNum
|
||||
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
|
||||
}
|
||||
// If FetchRowNumber was used, also set it in metadata
|
||||
if fetchedRowNumber != nil {
|
||||
metadata.RowNumber = fetchedRowNumber
|
||||
logger.Debug("FetchRowNumber: Row number %d set in metadata", *fetchedRowNumber)
|
||||
}
|
||||
|
||||
// Execute AfterRead hooks
|
||||
@@ -766,7 +832,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
// Apply ComputedQL fields if any
|
||||
if len(preload.ComputedQL) > 0 {
|
||||
// Get the base table name from the related model
|
||||
baseTableName := getTableNameFromModel(relatedModel)
|
||||
baseTableName := common.GetTableNameFromModel(relatedModel)
|
||||
|
||||
// Convert the preload relation path to the appropriate alias format
|
||||
// This is ORM-specific. Currently we only support Bun's format.
|
||||
@@ -777,7 +843,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB())
|
||||
if strings.Contains(underlyingType, "bun.DB") {
|
||||
// Use Bun's alias format: lowercase with double underscores
|
||||
preloadAlias = relationPathToBunAlias(preload.Relation)
|
||||
preloadAlias = common.RelationPathToBunAlias(preload.Relation)
|
||||
}
|
||||
// For GORM: GORM doesn't use the same alias format, and this fix
|
||||
// may not be needed since GORM handles preloads differently
|
||||
@@ -792,7 +858,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
// levels of recursive/nested preloads
|
||||
adjustedExpr := colExpr
|
||||
if baseTableName != "" && preloadAlias != "" {
|
||||
adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
||||
adjustedExpr = common.ReplaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
||||
if adjustedExpr != colExpr {
|
||||
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
|
||||
colName, colExpr, adjustedExpr)
|
||||
@@ -836,6 +902,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL joins from XFiles
|
||||
if len(preload.SqlJoins) > 0 {
|
||||
logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation)
|
||||
for _, joinClause := range preload.SqlJoins {
|
||||
sq = sq.Join(joinClause)
|
||||
logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
if len(preload.Filters) > 0 {
|
||||
for _, filter := range preload.Filters {
|
||||
@@ -861,10 +936,25 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
if len(preload.Where) > 0 {
|
||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||
preloadOpts := &common.RequestOptions{Preload: allPreloads}
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
// Then sanitize and allow preload table prefixes
|
||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
|
||||
// Determine the table name to use for WHERE clause processing
|
||||
// Prefer the explicit TableName field (set by XFiles), otherwise extract from relation name
|
||||
tableName := preload.TableName
|
||||
if tableName == "" {
|
||||
tableName = reflection.ExtractTableNameOnly(preload.Relation)
|
||||
}
|
||||
|
||||
// In Bun's Relation context, table prefixes are only needed when there are JOINs
|
||||
// Without JOINs, Bun already knows which table is being queried
|
||||
whereClause := preload.Where
|
||||
if len(preload.SqlJoins) > 0 {
|
||||
// Has JOINs: add table prefixes to disambiguate columns
|
||||
whereClause = common.AddTablePrefixToColumns(preload.Where, tableName)
|
||||
logger.Debug("Added table prefix for preload with joins: '%s' -> '%s'", preload.Where, whereClause)
|
||||
}
|
||||
|
||||
// Sanitize the WHERE clause and allow preload table prefixes
|
||||
sanitizedWhere := common.SanitizeWhereClause(whereClause, tableName, preloadOpts)
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -883,91 +973,85 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
})
|
||||
|
||||
// Handle recursive preloading
|
||||
if preload.Recursive && depth < 5 {
|
||||
if preload.Recursive && depth < 8 {
|
||||
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
|
||||
|
||||
// For recursive relationships, we need to get the last part of the relation path
|
||||
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
|
||||
relationParts := strings.Split(preload.Relation, ".")
|
||||
lastRelationName := relationParts[len(relationParts)-1]
|
||||
|
||||
// Create a recursive preload with the same configuration
|
||||
// but with the relation path extended
|
||||
// Generate FK-based relation name for children
|
||||
// Use RecursiveChildKey if available, otherwise fall back to RelatedKey
|
||||
recursiveFK := preload.RecursiveChildKey
|
||||
if recursiveFK == "" {
|
||||
recursiveFK = preload.RelatedKey
|
||||
}
|
||||
|
||||
recursiveRelationName := lastRelationName
|
||||
if recursiveFK != "" {
|
||||
// Check if the last relation name already contains the FK suffix
|
||||
// (this happens when XFiles already generated the FK-based name)
|
||||
fkUpper := strings.ToUpper(recursiveFK)
|
||||
expectedSuffix := "_" + fkUpper
|
||||
|
||||
if strings.HasSuffix(lastRelationName, expectedSuffix) {
|
||||
// Already has FK suffix, just reuse the same name
|
||||
recursiveRelationName = lastRelationName
|
||||
logger.Debug("Reusing FK-based relation name for recursion: %s", recursiveRelationName)
|
||||
} else {
|
||||
// Generate FK-based name
|
||||
recursiveRelationName = lastRelationName + expectedSuffix
|
||||
keySource := "RelatedKey"
|
||||
if preload.RecursiveChildKey != "" {
|
||||
keySource = "RecursiveChildKey"
|
||||
}
|
||||
logger.Debug("Generated recursive relation name from %s: %s (from %s)",
|
||||
keySource, recursiveRelationName, recursiveFK)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Recursive preload for %s has no RecursiveChildKey or RelatedKey, falling back to %s.%s",
|
||||
preload.Relation, preload.Relation, lastRelationName)
|
||||
}
|
||||
|
||||
// Create recursive preload
|
||||
recursivePreload := preload
|
||||
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
||||
recursivePreload.Relation = preload.Relation + "." + recursiveRelationName
|
||||
recursivePreload.Recursive = false // Prevent infinite recursion at this level
|
||||
|
||||
// Recursively apply preload until we reach depth 5
|
||||
// Use the recursive FK for child relations, not the parent's RelatedKey
|
||||
if preload.RecursiveChildKey != "" {
|
||||
recursivePreload.RelatedKey = preload.RecursiveChildKey
|
||||
}
|
||||
|
||||
// CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal
|
||||
recursivePreload.Where = ""
|
||||
recursivePreload.Filters = []common.FilterOption{}
|
||||
logger.Debug("Cleared WHERE clause for recursive preload %s at depth %d",
|
||||
recursivePreload.Relation, depth+1)
|
||||
|
||||
// Apply recursively up to depth 8
|
||||
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
// ALSO: Extend any child relations (like DEF) to recursive levels
|
||||
baseRelation := preload.Relation + "."
|
||||
for i := range allPreloads {
|
||||
relatedPreload := allPreloads[i]
|
||||
if strings.HasPrefix(relatedPreload.Relation, baseRelation) &&
|
||||
!strings.Contains(strings.TrimPrefix(relatedPreload.Relation, baseRelation), ".") {
|
||||
childRelationName := strings.TrimPrefix(relatedPreload.Relation, baseRelation)
|
||||
|
||||
// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def"
|
||||
// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores
|
||||
func relationPathToBunAlias(relationPath string) string {
|
||||
if relationPath == "" {
|
||||
return ""
|
||||
}
|
||||
// Convert to lowercase and replace dots with double underscores
|
||||
alias := strings.ToLower(relationPath)
|
||||
alias = strings.ReplaceAll(alias, ".", "__")
|
||||
return alias
|
||||
}
|
||||
extendedChildPreload := relatedPreload
|
||||
extendedChildPreload.Relation = recursivePreload.Relation + "." + childRelationName
|
||||
extendedChildPreload.Recursive = false
|
||||
|
||||
// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression
|
||||
// with the appropriate alias for the current preload level
|
||||
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
|
||||
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
|
||||
func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
|
||||
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
|
||||
return sqlExpr
|
||||
}
|
||||
logger.Debug("Extending related preload '%s' to '%s' at recursive depth %d",
|
||||
relatedPreload.Relation, extendedChildPreload.Relation, depth+1)
|
||||
|
||||
// Replace both quoted and unquoted table references
|
||||
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
|
||||
|
||||
// Pattern 1: tablename.column (unquoted)
|
||||
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
|
||||
|
||||
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
|
||||
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getTableNameFromModel extracts the table name from a model
|
||||
// It checks the bun tag first, then falls back to converting the struct name to snake_case
|
||||
func getTableNameFromModel(model interface{}) string {
|
||||
if model == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers
|
||||
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Look for bun tag on embedded BaseModel
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
if field.Anonymous {
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.HasPrefix(bunTag, "table:") {
|
||||
return strings.TrimPrefix(bunTag, "table:")
|
||||
query = h.applyPreloadWithRecursion(query, extendedChildPreload, allPreloads, model, depth+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: convert struct name to lowercase (simple heuristic)
|
||||
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||
return strings.ToLower(modelType.Name())
|
||||
return query
|
||||
}
|
||||
|
||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||
@@ -1177,30 +1261,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
logger.Info("Updating record in %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeUpdate hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Tx: h.db,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: id,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
logger.Error("BeforeUpdate hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
data = hookCtx.Data
|
||||
|
||||
// Convert data to map
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
@@ -1234,11 +1294,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
// Variable to store the updated record
|
||||
var updatedRecord interface{}
|
||||
|
||||
// Declare hook context to be used inside and outside transaction
|
||||
var hookCtx *HookContext
|
||||
|
||||
// Process nested relations if present
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Create temporary nested processor with transaction
|
||||
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
|
||||
|
||||
// First, read the existing record from the database
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("record not found with ID: %v", targetID)
|
||||
}
|
||||
return fmt.Errorf("failed to fetch existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert existing record to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||
}
|
||||
|
||||
// Extract nested relations if present (but don't process them yet)
|
||||
var nestedRelations map[string]interface{}
|
||||
if h.shouldUseNestedProcessor(dataMap, model) {
|
||||
@@ -1251,15 +1334,54 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
nestedRelations = relations
|
||||
}
|
||||
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx = &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Tx: tx,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: id,
|
||||
Data: dataMap,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
dataMap = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values from the incoming request into the existing record
|
||||
for key, newValue := range dataMap {
|
||||
// Skip if the value is nil
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if the value is an empty string
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update the existing map with the new value
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
// Ensure ID is in the data map for the update
|
||||
dataMap[pkName] = targetID
|
||||
existingMap[pkName] = targetID
|
||||
dataMap = existingMap
|
||||
|
||||
// Populate model instance from dataMap to preserve custom types (like SqlJSONB)
|
||||
// Get the type of the model, handling both pointer and non-pointer types
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
modelType = reflection.GetPointerElement(modelType)
|
||||
modelInstance := reflect.New(modelType).Interface()
|
||||
if err := reflection.MapToStruct(dataMap, modelInstance); err != nil {
|
||||
return fmt.Errorf("failed to populate model from data: %w", err)
|
||||
@@ -1297,7 +1419,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
// Fetch the updated record to return the new values
|
||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||
selectQuery := tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
return fmt.Errorf("failed to fetch updated record: %w", err)
|
||||
}
|
||||
@@ -1563,9 +1685,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
// First, fetch the record that will be deleted
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
modelType = reflection.GetPointerElement(modelType)
|
||||
recordToDelete := reflect.New(modelType).Interface()
|
||||
|
||||
selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
@@ -1825,10 +1945,46 @@ func (h *Handler) processChildRelationsForField(
|
||||
parentIDs[baseName] = parentID
|
||||
}
|
||||
|
||||
// Determine which field name to use for setting parent ID in child data
|
||||
// Priority: Use foreign key field name if specified, otherwise use parent's PK name
|
||||
var foreignKeyFieldName string
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Get the JSON name for the foreign key field in the child model
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||
if foreignKeyFieldName == "" {
|
||||
// Fallback to lowercase field name
|
||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||
}
|
||||
} else {
|
||||
// Fallback: use parent's primary key name
|
||||
parentPKName := reflection.GetPrimaryKeyName(parentModelType)
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(parentModelType, parentPKName)
|
||||
if foreignKeyFieldName == "" {
|
||||
foreignKeyFieldName = strings.ToLower(parentPKName)
|
||||
}
|
||||
}
|
||||
|
||||
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||
childPKName := reflection.GetPrimaryKeyName(relatedModel)
|
||||
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
|
||||
if childPKFieldName == "" {
|
||||
childPKFieldName = strings.ToLower(childPKName)
|
||||
}
|
||||
|
||||
logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s, childPK=%s",
|
||||
foreignKeyFieldName, parentID, relInfo.ForeignKey, childPKFieldName)
|
||||
|
||||
// Process based on relation type and data structure
|
||||
switch v := relationValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Single related object
|
||||
// Single related object - add parent ID to foreign key field
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
v[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
|
||||
}
|
||||
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process single relation: %w", err)
|
||||
@@ -1838,6 +1994,14 @@ func (h *Handler) processChildRelationsForField(
|
||||
// Multiple related objects
|
||||
for i, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
// Add parent ID to foreign key field
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||
@@ -1848,6 +2012,14 @@ func (h *Handler) processChildRelationsForField(
|
||||
case []map[string]interface{}:
|
||||
// Multiple related objects (typed slice)
|
||||
for i, itemMap := range v {
|
||||
// Add parent ID to foreign key field
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||
@@ -1861,11 +2033,18 @@ func (h *Handler) processChildRelationsForField(
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTableNameForRelatedModel gets the table name for a related model
|
||||
// getTableNameForRelatedModel gets the table name for a related model.
|
||||
// If the model's TableName() is schema-qualified (e.g. "public.users") the
|
||||
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
|
||||
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
tableName := provider.TableName()
|
||||
if tableName != "" {
|
||||
if schema, table := h.parseTableName(tableName); schema != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schema, table)
|
||||
}
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
}
|
||||
@@ -1965,6 +2144,99 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
}
|
||||
}
|
||||
|
||||
// applyOrFilterGroup applies a group of OR filters as a single grouped condition
|
||||
// This ensures OR conditions are properly grouped with parentheses to prevent OR logic from escaping
|
||||
func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common.FilterOption, castInfo []ColumnCastInfo, tableName string) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Build individual filter conditions
|
||||
conditions := []string{}
|
||||
args := []interface{}{}
|
||||
|
||||
for i, filter := range filters {
|
||||
// Qualify the column name with table name if not already qualified
|
||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
|
||||
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||
if castInfo[i].NeedsCast {
|
||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
||||
}
|
||||
|
||||
// Build the condition based on operator
|
||||
condition, filterArgs := h.buildFilterCondition(qualifiedColumn, filter, tableName)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Join all conditions with OR and wrap in parentheses
|
||||
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
||||
logger.Debug("Applying grouped OR conditions: %s", groupedCondition)
|
||||
|
||||
// Apply as AND condition (the OR is already inside the parentheses)
|
||||
return query.Where(groupedCondition, args...)
|
||||
}
|
||||
|
||||
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
||||
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
||||
switch strings.ToLower(filter.Operator) {
|
||||
case "eq", "equals":
|
||||
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "neq", "not_equals", "ne":
|
||||
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "gt", "greater_than":
|
||||
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "gte", "greater_than_equals", "ge":
|
||||
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "lt", "less_than":
|
||||
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "lte", "less_than_equals", "le":
|
||||
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "in":
|
||||
return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "between":
|
||||
// Handle between operator - exclusive (> val1 AND < val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
}
|
||||
logger.Warn("Invalid BETWEEN filter value format")
|
||||
return "", nil
|
||||
case "between_inclusive":
|
||||
// Handle between inclusive operator - inclusive (>= val1 AND <= val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
}
|
||||
logger.Warn("Invalid BETWEEN INCLUSIVE filter value format")
|
||||
return "", nil
|
||||
case "is_null", "isnull":
|
||||
// Check for NULL values - don't use cast for NULL checks
|
||||
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||
return fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName), nil
|
||||
case "is_not_null", "isnotnull":
|
||||
// Check for NOT NULL values - don't use cast for NULL checks
|
||||
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName), nil
|
||||
default:
|
||||
logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator)
|
||||
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
}
|
||||
}
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
|
||||
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||
@@ -2017,10 +2289,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
|
||||
return schema, entity
|
||||
}
|
||||
|
||||
// getTableName returns the full table name including schema (schema.table)
|
||||
// getTableName returns the full table name including schema.
|
||||
// For most drivers the result is "schema.table". For SQLite, which does not
|
||||
// support schema-qualified names, the schema and table are joined with an
|
||||
// underscore: "schema_table".
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
@@ -2342,21 +2620,8 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
||||
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
|
||||
}
|
||||
|
||||
// Build WHERE clauses from filters
|
||||
whereClauses := make([]string, 0)
|
||||
for i := range options.Filters {
|
||||
filter := &options.Filters[i]
|
||||
whereClause := h.buildFilterSQL(filter, tableName)
|
||||
if whereClause != "" {
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
|
||||
}
|
||||
}
|
||||
|
||||
// Combine WHERE clauses
|
||||
whereSQL := ""
|
||||
if len(whereClauses) > 0 {
|
||||
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
|
||||
}
|
||||
// Build WHERE clause from filters with proper OR grouping
|
||||
whereSQL := h.buildWhereClauseWithORGrouping(options.Filters, tableName)
|
||||
|
||||
// Add custom SQL WHERE if provided
|
||||
if options.CustomSQLWhere != "" {
|
||||
@@ -2404,19 +2669,86 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
||||
var result []struct {
|
||||
RN int64 `bun:"rn"`
|
||||
}
|
||||
logger.Debug("[FetchRowNumber] BEFORE Query call - about to execute raw query")
|
||||
err := h.db.Query(ctx, &result, queryStr, pkValue)
|
||||
logger.Debug("[FetchRowNumber] AFTER Query call - query completed with %d results, err: %v", len(result), err)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to fetch row number: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return 0, fmt.Errorf("no row found for primary key %s", pkValue)
|
||||
whereInfo := "none"
|
||||
if whereSQL != "" {
|
||||
whereInfo = whereSQL
|
||||
}
|
||||
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
|
||||
}
|
||||
|
||||
return result[0].RN, nil
|
||||
}
|
||||
|
||||
// buildFilterSQL converts a filter to SQL WHERE clause string
|
||||
// buildWhereClauseWithORGrouping builds a WHERE clause from filters with proper OR grouping
|
||||
// Groups consecutive OR filters together to ensure proper SQL precedence
|
||||
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
|
||||
func (h *Handler) buildWhereClauseWithORGrouping(filters []common.FilterOption, tableName string) string {
|
||||
if len(filters) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var groups []string
|
||||
i := 0
|
||||
|
||||
for i < len(filters) {
|
||||
// Check if this starts an OR group (next filter has OR logic)
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
// Collect all consecutive filters that are OR'd together
|
||||
orGroup := []string{}
|
||||
|
||||
// Add current filter
|
||||
filterSQL := h.buildFilterSQL(&filters[i], tableName)
|
||||
if filterSQL != "" {
|
||||
orGroup = append(orGroup, filterSQL)
|
||||
}
|
||||
|
||||
// Collect remaining OR filters
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
filterSQL := h.buildFilterSQL(&filters[j], tableName)
|
||||
if filterSQL != "" {
|
||||
orGroup = append(orGroup, filterSQL)
|
||||
}
|
||||
j++
|
||||
}
|
||||
|
||||
// Group OR filters with parentheses
|
||||
if len(orGroup) > 0 {
|
||||
if len(orGroup) == 1 {
|
||||
groups = append(groups, orGroup[0])
|
||||
} else {
|
||||
groups = append(groups, "("+strings.Join(orGroup, " OR ")+")")
|
||||
}
|
||||
}
|
||||
i = j
|
||||
} else {
|
||||
// Single filter with AND logic (or first filter)
|
||||
filterSQL := h.buildFilterSQL(&filters[i], tableName)
|
||||
if filterSQL != "" {
|
||||
groups = append(groups, filterSQL)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
if len(groups) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(groups, " AND ")
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
|
||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
|
||||
@@ -2537,10 +2869,10 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
||||
filteredExpand := expand
|
||||
|
||||
// Get the relationship info for this expand relation
|
||||
relInfo := h.getRelationshipInfo(modelType, expand.Relation)
|
||||
if relInfo != nil && relInfo.relatedModel != nil {
|
||||
relInfo := common.GetRelationshipInfo(modelType, expand.Relation)
|
||||
if relInfo != nil && relInfo.RelatedModel != nil {
|
||||
// Create a validator for the related model
|
||||
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
||||
expandValidator := common.NewColumnValidator(relInfo.RelatedModel)
|
||||
// Filter columns using the related model's validator
|
||||
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||
|
||||
@@ -2617,110 +2949,7 @@ func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model in
|
||||
|
||||
// GetRelationshipInfo implements common.RelationshipInfoProvider interface
|
||||
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
||||
info := h.getRelationshipInfo(modelType, relationName)
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
// Convert internal type to common type
|
||||
return &common.RelationshipInfo{
|
||||
FieldName: info.fieldName,
|
||||
JSONName: info.jsonName,
|
||||
RelationType: info.relationType,
|
||||
ForeignKey: info.foreignKey,
|
||||
References: info.references,
|
||||
JoinTable: info.joinTable,
|
||||
RelatedModel: info.relatedModel,
|
||||
}
|
||||
}
|
||||
|
||||
type relationshipInfo struct {
|
||||
fieldName string
|
||||
jsonName string
|
||||
relationType string // "belongsTo", "hasMany", "hasOne", "many2many"
|
||||
foreignKey string
|
||||
references string
|
||||
joinTable string
|
||||
relatedModel interface{}
|
||||
}
|
||||
|
||||
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
||||
// Ensure we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
jsonName := strings.Split(jsonTag, ",")[0]
|
||||
|
||||
if jsonName == relationName {
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
info := &relationshipInfo{
|
||||
fieldName: field.Name,
|
||||
jsonName: jsonName,
|
||||
}
|
||||
|
||||
// Parse GORM tag to determine relationship type and keys
|
||||
if strings.Contains(gormTag, "foreignKey") {
|
||||
info.foreignKey = h.extractTagValue(gormTag, "foreignKey")
|
||||
info.references = h.extractTagValue(gormTag, "references")
|
||||
|
||||
// Determine if it's belongsTo or hasMany/hasOne
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
info.relationType = "hasMany"
|
||||
// Get the element type for slice
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||
info.relationType = "belongsTo"
|
||||
elemType := field.Type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(gormTag, "many2many") {
|
||||
info.relationType = "many2many"
|
||||
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
||||
// Get the element type for many2many (always slice)
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Field has no GORM relationship tags, so it's not a relation
|
||||
return nil
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) extractTagValue(tag, key string) string {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, key+":") {
|
||||
return strings.TrimPrefix(part, key+":")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return common.GetRelationshipInfo(modelType, relationName)
|
||||
}
|
||||
|
||||
// HandleOpenAPI generates and returns the OpenAPI specification
|
||||
|
||||
@@ -26,7 +26,9 @@ type ExtendedRequestOptions struct {
|
||||
CustomSQLOr string
|
||||
|
||||
// Joins
|
||||
Expand []ExpandOption
|
||||
Expand []ExpandOption
|
||||
CustomSQLJoin []string // Custom SQL JOIN clauses
|
||||
JoinAliases []string // Extracted table aliases from CustomSQLJoin for validation
|
||||
|
||||
// Advanced features
|
||||
AdvancedSQL map[string]string // Column -> SQL expression
|
||||
@@ -46,7 +48,8 @@ type ExtendedRequestOptions struct {
|
||||
AtomicTransaction bool
|
||||
|
||||
// X-Files configuration - comprehensive query options as a single JSON object
|
||||
XFiles *XFiles
|
||||
XFiles *XFiles
|
||||
XFilesPresent bool // Flag to indicate if X-Files header was provided
|
||||
}
|
||||
|
||||
// ExpandOption represents a relation expansion configuration
|
||||
@@ -111,6 +114,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
AdvancedSQL: make(map[string]string),
|
||||
ComputedQL: make(map[string]string),
|
||||
Expand: make([]ExpandOption, 0),
|
||||
CustomSQLJoin: make([]string, 0),
|
||||
ResponseFormat: "simple", // Default response format
|
||||
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||
}
|
||||
@@ -185,8 +189,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
case strings.HasPrefix(key, "x-expand"):
|
||||
h.parseExpand(&options, decodedValue)
|
||||
case strings.HasPrefix(key, "x-custom-sql-join"):
|
||||
// TODO: Implement custom SQL join
|
||||
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
||||
h.parseCustomSQLJoin(&options, decodedValue)
|
||||
|
||||
// Sorting & Pagination
|
||||
case strings.HasPrefix(key, "x-sort"):
|
||||
@@ -272,7 +275,8 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
}
|
||||
|
||||
// Resolve relation names (convert table names to field names) if model is provided
|
||||
if model != nil {
|
||||
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
|
||||
if model != nil && !options.XFilesPresent {
|
||||
h.resolveRelationNamesInOptions(&options, model)
|
||||
}
|
||||
|
||||
@@ -354,6 +358,12 @@ func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu
|
||||
operator := parts[0]
|
||||
colName := parts[1]
|
||||
|
||||
if strings.HasPrefix(colName, "cql") {
|
||||
// Computed column - Will not filter on it
|
||||
logger.Warn("Search operators on computed columns are not supported: %s", colName)
|
||||
return
|
||||
}
|
||||
|
||||
// Map operator names to filter operators
|
||||
filterOp := h.mapSearchOperator(colName, operator, value)
|
||||
|
||||
@@ -489,6 +499,101 @@ func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) {
|
||||
}
|
||||
}
|
||||
|
||||
// parseCustomSQLJoin parses x-custom-sql-join header
|
||||
// Format: Single JOIN clause or multiple JOIN clauses separated by |
|
||||
// Example: "LEFT JOIN departments d ON d.id = employees.department_id"
|
||||
// Example: "LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id"
|
||||
func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Split by | for multiple joins
|
||||
joins := strings.Split(value, "|")
|
||||
for _, joinStr := range joins {
|
||||
joinStr = strings.TrimSpace(joinStr)
|
||||
if joinStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Basic validation: should contain "JOIN" keyword
|
||||
upperJoin := strings.ToUpper(joinStr)
|
||||
if !strings.Contains(upperJoin, "JOIN") {
|
||||
logger.Warn("Invalid custom SQL join (missing JOIN keyword): %s", joinStr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Sanitize the join clause using common.SanitizeWhereClause
|
||||
// Note: This is basic sanitization - in production you may want stricter validation
|
||||
sanitizedJoin := common.SanitizeWhereClause(joinStr, "", nil)
|
||||
if sanitizedJoin == "" {
|
||||
logger.Warn("Custom SQL join failed sanitization: %s", joinStr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract table alias from the JOIN clause
|
||||
alias := extractJoinAlias(sanitizedJoin)
|
||||
if alias != "" {
|
||||
options.JoinAliases = append(options.JoinAliases, alias)
|
||||
// Also add to the embedded RequestOptions for validation
|
||||
options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias)
|
||||
logger.Debug("Extracted join alias: %s", alias)
|
||||
}
|
||||
|
||||
logger.Debug("Adding custom SQL join: %s", sanitizedJoin)
|
||||
options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin)
|
||||
}
|
||||
}
|
||||
|
||||
// extractJoinAlias extracts the table alias from a JOIN clause
|
||||
// Examples:
|
||||
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||
// - "JOIN roles r ON ..." -> "r"
|
||||
func extractJoinAlias(joinClause string) string {
|
||||
// Pattern: JOIN table_name [AS] alias ON ...
|
||||
// We need to extract the alias (word before ON)
|
||||
|
||||
upperJoin := strings.ToUpper(joinClause)
|
||||
|
||||
// Find the "JOIN" keyword position
|
||||
joinIdx := strings.Index(upperJoin, "JOIN")
|
||||
if joinIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the "ON" keyword position
|
||||
onIdx := strings.Index(upperJoin, " ON ")
|
||||
if onIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Extract the part between JOIN and ON
|
||||
betweenJoinAndOn := strings.TrimSpace(joinClause[joinIdx+4 : onIdx])
|
||||
|
||||
// Split by spaces to get words
|
||||
words := strings.Fields(betweenJoinAndOn)
|
||||
if len(words) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// If there's an AS keyword, the alias is after it
|
||||
for i, word := range words {
|
||||
if strings.EqualFold(word, "AS") && i+1 < len(words) {
|
||||
return words[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, the alias is the last word (if there are 2+ words)
|
||||
// Format: "table_name alias" or just "table_name"
|
||||
if len(words) >= 2 {
|
||||
return words[len(words)-1]
|
||||
}
|
||||
|
||||
// Only one word means it's just the table name, no alias
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseSorting parses x-sort header
|
||||
// Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC)
|
||||
func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
||||
@@ -590,6 +695,7 @@ func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
|
||||
|
||||
// Store the original XFiles for reference
|
||||
options.XFiles = &xfiles
|
||||
options.XFilesPresent = true // Mark that X-Files header was provided
|
||||
|
||||
// Map XFiles fields to ExtendedRequestOptions
|
||||
|
||||
@@ -881,11 +987,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
return
|
||||
}
|
||||
|
||||
// Store the table name as-is for now - it will be resolved to field name later
|
||||
// when we have the model instance available
|
||||
relationPath := xfile.TableName
|
||||
// Use the Prefix (e.g., "MAL") as the relation name, which matches the Go struct field name
|
||||
// Fall back to TableName if Prefix is not specified
|
||||
relationName := xfile.Prefix
|
||||
if relationName == "" {
|
||||
relationName = xfile.TableName
|
||||
}
|
||||
|
||||
// SPECIAL CASE: For recursive child tables, generate FK-based relation name
|
||||
// Example: If prefix is "MAL" and relatedkey is "rid_parentmastertaskitem",
|
||||
// the actual struct field is "MAL_RID_PARENTMASTERTASKITEM", not "MAL"
|
||||
if xfile.Recursive && xfile.RelatedKey != "" && basePath != "" {
|
||||
// Check if this is a self-referencing recursive relation (same table as parent)
|
||||
// by comparing the last part of basePath with the current prefix
|
||||
basePathParts := strings.Split(basePath, ".")
|
||||
lastPrefix := basePathParts[len(basePathParts)-1]
|
||||
|
||||
if lastPrefix == relationName {
|
||||
// This is a recursive self-reference, use FK-based name
|
||||
fkUpper := strings.ToUpper(xfile.RelatedKey)
|
||||
relationName = relationName + "_" + fkUpper
|
||||
logger.Debug("X-Files: Generated FK-based relation name for recursive table: %s", relationName)
|
||||
}
|
||||
}
|
||||
|
||||
relationPath := relationName
|
||||
if basePath != "" {
|
||||
relationPath = basePath + "." + xfile.TableName
|
||||
relationPath = basePath + "." + relationName
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
|
||||
@@ -893,6 +1021,7 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
// Create PreloadOption from XFiles configuration
|
||||
preloadOpt := common.PreloadOption{
|
||||
Relation: relationPath,
|
||||
TableName: xfile.TableName, // Store the actual database table name for WHERE clause processing
|
||||
Columns: xfile.Columns,
|
||||
OmitColumns: xfile.OmitColumns,
|
||||
}
|
||||
@@ -935,12 +1064,12 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
// Add WHERE clause if SQL conditions specified
|
||||
whereConditions := make([]string, 0)
|
||||
if len(xfile.SqlAnd) > 0 {
|
||||
// Process each SQL condition: add table prefixes and sanitize
|
||||
// Process each SQL condition
|
||||
// Note: We don't add table prefixes here because they're only needed for JOINs
|
||||
// The handler will add prefixes later if SqlJoins are present
|
||||
for _, sqlCond := range xfile.SqlAnd {
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName)
|
||||
// Then sanitize the condition
|
||||
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
|
||||
// Sanitize the condition without adding prefixes
|
||||
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
|
||||
if sanitizedCond != "" {
|
||||
whereConditions = append(whereConditions, sanitizedCond)
|
||||
}
|
||||
@@ -985,13 +1114,72 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||
}
|
||||
|
||||
// Transfer SqlJoins from XFiles to PreloadOption
|
||||
if len(xfile.SqlJoins) > 0 {
|
||||
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||
|
||||
for _, joinClause := range xfile.SqlJoins {
|
||||
// Sanitize the join clause
|
||||
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||
if sanitizedJoin == "" {
|
||||
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||
continue
|
||||
}
|
||||
|
||||
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||
|
||||
// Extract join alias for validation
|
||||
alias := extractJoinAlias(sanitizedJoin)
|
||||
if alias != "" {
|
||||
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||
}
|
||||
|
||||
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
||||
// and store the recursive child's RelatedKey for recursion generation
|
||||
hasRecursiveChild := false
|
||||
if len(xfile.ChildTables) > 0 {
|
||||
for _, childTable := range xfile.ChildTables {
|
||||
if childTable.Recursive && childTable.TableName == xfile.TableName {
|
||||
hasRecursiveChild = true
|
||||
preloadOpt.Recursive = true
|
||||
preloadOpt.RecursiveChildKey = childTable.RelatedKey
|
||||
logger.Debug("X-Files: Detected recursive child for %s, marking parent as recursive (recursive FK: %s)",
|
||||
relationPath, childTable.RelatedKey)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip adding this preload if it's a recursive child (it will be handled by parent's Recursive flag)
|
||||
if xfile.Recursive && basePath != "" {
|
||||
logger.Debug("X-Files: Skipping recursive child preload: %s (will be handled by parent)", relationPath)
|
||||
// Still process its parent/child tables for relations like DEF
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
return
|
||||
}
|
||||
|
||||
// Add the preload option
|
||||
options.Preload = append(options.Preload, preloadOpt)
|
||||
logger.Debug("X-Files: Added preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, RecursiveChildKey=%s, Where=%s",
|
||||
len(options.Preload)-1, preloadOpt.Relation, preloadOpt.Recursive, preloadOpt.RelatedKey, preloadOpt.RecursiveChildKey, preloadOpt.Where)
|
||||
|
||||
// Recursively process nested ParentTables and ChildTables
|
||||
if xfile.Recursive {
|
||||
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
// Skip processing child tables if we already detected and handled a recursive child
|
||||
if hasRecursiveChild {
|
||||
logger.Debug("X-Files: Skipping child table processing for %s (recursive child already handled)", relationPath)
|
||||
// But still process parent tables
|
||||
if len(xfile.ParentTables) > 0 {
|
||||
logger.Debug("X-Files: Processing %d parent tables for %s", len(xfile.ParentTables), relationPath)
|
||||
for _, parentTable := range xfile.ParentTables {
|
||||
h.addXFilesPreload(parentTable, options, relationPath)
|
||||
}
|
||||
}
|
||||
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestDecodeHeaderValue(t *testing.T) {
|
||||
@@ -37,6 +39,121 @@ func TestDecodeHeaderValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddXFilesPreload_WithSqlJoins(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
options := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Preload: make([]common.PreloadOption, 0),
|
||||
},
|
||||
}
|
||||
|
||||
// Create an XFiles with SqlJoins
|
||||
xfile := &XFiles{
|
||||
TableName: "users",
|
||||
SqlJoins: []string{
|
||||
"LEFT JOIN departments d ON d.id = users.department_id",
|
||||
"INNER JOIN roles r ON r.id = users.role_id",
|
||||
},
|
||||
FilterFields: []struct {
|
||||
Field string `json:"field"`
|
||||
Value string `json:"value"`
|
||||
Operator string `json:"operator"`
|
||||
}{
|
||||
{Field: "d.active", Value: "true", Operator: "eq"},
|
||||
{Field: "r.name", Value: "admin", Operator: "eq"},
|
||||
},
|
||||
}
|
||||
|
||||
// Add the XFiles preload
|
||||
handler.addXFilesPreload(xfile, options, "")
|
||||
|
||||
// Verify that a preload was added
|
||||
if len(options.Preload) != 1 {
|
||||
t.Fatalf("Expected 1 preload, got %d", len(options.Preload))
|
||||
}
|
||||
|
||||
preload := options.Preload[0]
|
||||
|
||||
// Verify relation name
|
||||
if preload.Relation != "users" {
|
||||
t.Errorf("Expected relation 'users', got '%s'", preload.Relation)
|
||||
}
|
||||
|
||||
// Verify SqlJoins were transferred
|
||||
if len(preload.SqlJoins) != 2 {
|
||||
t.Fatalf("Expected 2 SQL joins, got %d", len(preload.SqlJoins))
|
||||
}
|
||||
|
||||
// Verify JoinAliases were extracted
|
||||
if len(preload.JoinAliases) != 2 {
|
||||
t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases))
|
||||
}
|
||||
|
||||
// Verify the aliases are correct
|
||||
expectedAliases := []string{"d", "r"}
|
||||
for i, expected := range expectedAliases {
|
||||
if preload.JoinAliases[i] != expected {
|
||||
t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify filters were added
|
||||
if len(preload.Filters) != 2 {
|
||||
t.Fatalf("Expected 2 filters, got %d", len(preload.Filters))
|
||||
}
|
||||
|
||||
// Verify filter columns reference joined tables
|
||||
if preload.Filters[0].Column != "d.active" {
|
||||
t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column)
|
||||
}
|
||||
if preload.Filters[1].Column != "r.name" {
|
||||
t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJoinAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
joinClause string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN with alias",
|
||||
joinClause: "LEFT JOIN departments d ON d.id = users.department_id",
|
||||
expected: "d",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN with AS keyword",
|
||||
joinClause: "INNER JOIN users AS u ON u.id = orders.user_id",
|
||||
expected: "u",
|
||||
},
|
||||
{
|
||||
name: "JOIN without alias",
|
||||
joinClause: "JOIN roles ON roles.id = users.role_id",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Complex join with multiple conditions",
|
||||
joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true",
|
||||
expected: "p",
|
||||
},
|
||||
{
|
||||
name: "Invalid join (no ON clause)",
|
||||
joinClause: "LEFT JOIN departments",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractJoinAlias(tt.joinClause)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected alias '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
|
||||
// - parseSelectFields
|
||||
// - parseFieldFilter
|
||||
|
||||
110
pkg/restheadspec/preload_tablename_test.go
Normal file
110
pkg/restheadspec/preload_tablename_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// TestPreloadOption_TableName verifies that TableName field is properly used
|
||||
// when provided in PreloadOption for WHERE clause processing
|
||||
func TestPreloadOption_TableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
preload common.PreloadOption
|
||||
expectedTable string
|
||||
}{
|
||||
{
|
||||
name: "TableName provided explicitly",
|
||||
preload: common.PreloadOption{
|
||||
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||
TableName: "mastertaskitem",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
},
|
||||
expectedTable: "mastertaskitem",
|
||||
},
|
||||
{
|
||||
name: "TableName empty, should use empty string",
|
||||
preload: common.PreloadOption{
|
||||
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||
TableName: "",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
},
|
||||
expectedTable: "",
|
||||
},
|
||||
{
|
||||
name: "Simple relation without nested path",
|
||||
preload: common.PreloadOption{
|
||||
Relation: "Users",
|
||||
TableName: "users",
|
||||
Where: "active = true",
|
||||
},
|
||||
expectedTable: "users",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test that the TableName field stores the correct value
|
||||
if tt.preload.TableName != tt.expectedTable {
|
||||
t.Errorf("PreloadOption.TableName = %q, want %q", tt.preload.TableName, tt.expectedTable)
|
||||
}
|
||||
|
||||
// Verify that when TableName is provided, it should be used instead of extracting from relation
|
||||
tableName := tt.preload.TableName
|
||||
if tableName == "" {
|
||||
// This simulates the fallback logic in handler.go
|
||||
// In reality, reflection.ExtractTableNameOnly would be called
|
||||
tableName = tt.expectedTable
|
||||
}
|
||||
|
||||
if tableName != tt.expectedTable {
|
||||
t.Errorf("Resolved table name = %q, want %q", tableName, tt.expectedTable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestXFilesPreload_StoresTableName verifies that XFiles processing
|
||||
// stores the table name in PreloadOption and doesn't add table prefixes to WHERE clauses
|
||||
func TestXFilesPreload_StoresTableName(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
xfiles := &XFiles{
|
||||
TableName: "mastertaskitem",
|
||||
Prefix: "MAL",
|
||||
PrimaryKey: "rid_mastertaskitem",
|
||||
RelatedKey: "rid_mastertask", // Changed from rid_parentmastertaskitem
|
||||
Recursive: false, // Changed from true (recursive children are now skipped)
|
||||
SqlAnd: []string{"rid_parentmastertaskitem is null"},
|
||||
}
|
||||
|
||||
options := &ExtendedRequestOptions{}
|
||||
|
||||
// Process XFiles
|
||||
handler.addXFilesPreload(xfiles, options, "MTL")
|
||||
|
||||
// Verify that a preload was added
|
||||
if len(options.Preload) == 0 {
|
||||
t.Fatal("Expected at least one preload to be added")
|
||||
}
|
||||
|
||||
preload := options.Preload[0]
|
||||
|
||||
// Verify the table name is stored
|
||||
if preload.TableName != "mastertaskitem" {
|
||||
t.Errorf("PreloadOption.TableName = %q, want %q", preload.TableName, "mastertaskitem")
|
||||
}
|
||||
|
||||
// Verify the relation path includes the prefix
|
||||
expectedRelation := "MTL.MAL"
|
||||
if preload.Relation != expectedRelation {
|
||||
t.Errorf("PreloadOption.Relation = %q, want %q", preload.Relation, expectedRelation)
|
||||
}
|
||||
|
||||
// Verify WHERE clause does NOT have table prefix (prefixes only needed for JOINs)
|
||||
expectedWhere := "rid_parentmastertaskitem is null"
|
||||
if preload.Where != expectedWhere {
|
||||
t.Errorf("PreloadOption.Where = %q, want %q (no table prefix)", preload.Where, expectedWhere)
|
||||
}
|
||||
}
|
||||
91
pkg/restheadspec/preload_where_joins_test.go
Normal file
91
pkg/restheadspec/preload_where_joins_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPreloadWhereClause_WithJoins verifies that table prefixes are added
|
||||
// to WHERE clauses when SqlJoins are present
|
||||
func TestPreloadWhereClause_WithJoins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
sqlJoins []string
|
||||
expectedPrefix bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "No joins - no prefix needed",
|
||||
where: "status = 'active'",
|
||||
sqlJoins: []string{},
|
||||
expectedPrefix: false,
|
||||
description: "Without JOINs, Bun knows the table context",
|
||||
},
|
||||
{
|
||||
name: "Has joins - prefix needed",
|
||||
where: "status = 'active'",
|
||||
sqlJoins: []string{"LEFT JOIN other_table ot ON ot.id = main.other_id"},
|
||||
expectedPrefix: true,
|
||||
description: "With JOINs, table prefix disambiguates columns",
|
||||
},
|
||||
{
|
||||
name: "Already has prefix - no change",
|
||||
where: "users.status = 'active'",
|
||||
sqlJoins: []string{"LEFT JOIN roles r ON r.id = users.role_id"},
|
||||
expectedPrefix: true,
|
||||
description: "Existing prefix should be preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// This test documents the expected behavior
|
||||
// The actual logic is in handler.go lines 916-937
|
||||
|
||||
hasJoins := len(tt.sqlJoins) > 0
|
||||
if hasJoins != tt.expectedPrefix {
|
||||
t.Errorf("Test expectation mismatch: hasJoins=%v, expectedPrefix=%v",
|
||||
hasJoins, tt.expectedPrefix)
|
||||
}
|
||||
|
||||
t.Logf("%s: %s", tt.name, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestXFilesWithJoins_AddsTablePrefix verifies that XFiles with SqlJoins
|
||||
// results in table prefixes being added to WHERE clauses
|
||||
func TestXFilesWithJoins_AddsTablePrefix(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
xfiles := &XFiles{
|
||||
TableName: "users",
|
||||
Prefix: "USR",
|
||||
PrimaryKey: "id",
|
||||
SqlAnd: []string{"status = 'active'"},
|
||||
SqlJoins: []string{"LEFT JOIN departments d ON d.id = users.department_id"},
|
||||
}
|
||||
|
||||
options := &ExtendedRequestOptions{}
|
||||
handler.addXFilesPreload(xfiles, options, "")
|
||||
|
||||
if len(options.Preload) == 0 {
|
||||
t.Fatal("Expected at least one preload to be added")
|
||||
}
|
||||
|
||||
preload := options.Preload[0]
|
||||
|
||||
// Verify SqlJoins were stored
|
||||
if len(preload.SqlJoins) != 1 {
|
||||
t.Errorf("Expected 1 SqlJoin, got %d", len(preload.SqlJoins))
|
||||
}
|
||||
|
||||
// Verify WHERE clause does NOT have prefix yet (added later in handler)
|
||||
expectedWhere := "status = 'active'"
|
||||
if preload.Where != expectedWhere {
|
||||
t.Errorf("PreloadOption.Where = %q, want %q", preload.Where, expectedWhere)
|
||||
}
|
||||
|
||||
// Note: The handler will add the prefix when it sees SqlJoins
|
||||
// This is tested in the handler itself, not during XFiles parsing
|
||||
}
|
||||
@@ -301,6 +301,163 @@ func TestParseOptionsFromQueryParams(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse custom SQL JOIN from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.CustomSQLJoin) == 0 {
|
||||
t.Error("Expected CustomSQLJoin to be set")
|
||||
return
|
||||
}
|
||||
if len(options.CustomSQLJoin) != 1 {
|
||||
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
|
||||
return
|
||||
}
|
||||
expected := `LEFT JOIN departments d ON d.id = employees.department_id`
|
||||
if options.CustomSQLJoin[0] != expected {
|
||||
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse multiple custom SQL JOINs from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.CustomSQLJoin) != 2 {
|
||||
t.Errorf("Expected 2 custom SQL joins, got %d", len(options.CustomSQLJoin))
|
||||
return
|
||||
}
|
||||
expected1 := `LEFT JOIN departments d ON d.id = e.dept_id`
|
||||
expected2 := `INNER JOIN roles r ON r.id = e.role_id`
|
||||
if options.CustomSQLJoin[0] != expected1 {
|
||||
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected1, options.CustomSQLJoin[0])
|
||||
}
|
||||
if options.CustomSQLJoin[1] != expected2 {
|
||||
t.Errorf("Expected CustomSQLJoin[1]=%q, got %q", expected2, options.CustomSQLJoin[1])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse custom SQL JOIN from headers",
|
||||
headers: map[string]string{
|
||||
"X-Custom-SQL-Join": `LEFT JOIN users u ON u.id = posts.user_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.CustomSQLJoin) == 0 {
|
||||
t.Error("Expected CustomSQLJoin to be set from header")
|
||||
return
|
||||
}
|
||||
expected := `LEFT JOIN users u ON u.id = posts.user_id`
|
||||
if options.CustomSQLJoin[0] != expected {
|
||||
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Extract aliases from custom SQL JOIN",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.JoinAliases) == 0 {
|
||||
t.Error("Expected JoinAliases to be extracted")
|
||||
return
|
||||
}
|
||||
if len(options.JoinAliases) != 1 {
|
||||
t.Errorf("Expected 1 join alias, got %d", len(options.JoinAliases))
|
||||
return
|
||||
}
|
||||
if options.JoinAliases[0] != "d" {
|
||||
t.Errorf("Expected join alias 'd', got %q", options.JoinAliases[0])
|
||||
}
|
||||
// Also check that it's in the embedded RequestOptions
|
||||
if len(options.RequestOptions.JoinAliases) != 1 || options.RequestOptions.JoinAliases[0] != "d" {
|
||||
t.Error("Expected join alias to also be in RequestOptions.JoinAliases")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Extract multiple aliases from multiple custom SQL JOINs",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles AS r ON r.id = e.role_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.JoinAliases) != 2 {
|
||||
t.Errorf("Expected 2 join aliases, got %d", len(options.JoinAliases))
|
||||
return
|
||||
}
|
||||
expectedAliases := []string{"d", "r"}
|
||||
for i, expected := range expectedAliases {
|
||||
if options.JoinAliases[i] != expected {
|
||||
t.Errorf("Expected join alias[%d]=%q, got %q", i, expected, options.JoinAliases[i])
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom JOIN with sort on joined table",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
"x-sort": "d.name,employees.id",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
// Verify join was added
|
||||
if len(options.CustomSQLJoin) != 1 {
|
||||
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
|
||||
return
|
||||
}
|
||||
// Verify alias was extracted
|
||||
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
|
||||
t.Error("Expected join alias 'd' to be extracted")
|
||||
return
|
||||
}
|
||||
// Verify sort was parsed
|
||||
if len(options.Sort) != 2 {
|
||||
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
|
||||
return
|
||||
}
|
||||
if options.Sort[0].Column != "d.name" {
|
||||
t.Errorf("Expected first sort column 'd.name', got %q", options.Sort[0].Column)
|
||||
}
|
||||
if options.Sort[1].Column != "employees.id" {
|
||||
t.Errorf("Expected second sort column 'employees.id', got %q", options.Sort[1].Column)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom JOIN with filter on joined table",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
"x-searchop-eq-d.name": "Engineering",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
// Verify join was added
|
||||
if len(options.CustomSQLJoin) != 1 {
|
||||
t.Error("Expected 1 custom SQL join")
|
||||
return
|
||||
}
|
||||
// Verify alias was extracted
|
||||
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
|
||||
t.Error("Expected join alias 'd' to be extracted")
|
||||
return
|
||||
}
|
||||
// Verify filter was parsed
|
||||
if len(options.Filters) != 1 {
|
||||
t.Errorf("Expected 1 filter, got %d", len(options.Filters))
|
||||
return
|
||||
}
|
||||
if options.Filters[0].Column != "d.name" {
|
||||
t.Errorf("Expected filter column 'd.name', got %q", options.Filters[0].Column)
|
||||
}
|
||||
if options.Filters[0].Operator != "eq" {
|
||||
t.Errorf("Expected filter operator 'eq', got %q", options.Filters[0].Operator)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -395,6 +552,55 @@ func TestHeadersAndQueryParamsCombined(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomJoinAliasExtraction tests the extractJoinAlias helper function
|
||||
func TestCustomJoinAliasExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
join string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN with alias",
|
||||
join: "LEFT JOIN departments d ON d.id = employees.department_id",
|
||||
expected: "d",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN with AS keyword",
|
||||
join: "INNER JOIN users AS u ON u.id = posts.user_id",
|
||||
expected: "u",
|
||||
},
|
||||
{
|
||||
name: "Simple JOIN with alias",
|
||||
join: "JOIN roles r ON r.id = user_roles.role_id",
|
||||
expected: "r",
|
||||
},
|
||||
{
|
||||
name: "JOIN without alias (just table name)",
|
||||
join: "JOIN departments ON departments.id = employees.dept_id",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "RIGHT JOIN with alias",
|
||||
join: "RIGHT JOIN orders o ON o.customer_id = customers.id",
|
||||
expected: "o",
|
||||
},
|
||||
{
|
||||
name: "FULL OUTER JOIN with AS",
|
||||
join: "FULL OUTER JOIN products AS p ON p.id = order_items.product_id",
|
||||
expected: "p",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractJoinAlias(tt.join)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractJoinAlias(%q) = %q, want %q", tt.join, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
||||
|
||||
391
pkg/restheadspec/recursive_preload_test.go
Normal file
391
pkg/restheadspec/recursive_preload_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
//go:build !integration
|
||||
// +build !integration
|
||||
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// TestRecursivePreloadClearsWhereClause tests that recursive preloads
|
||||
// correctly clear the WHERE clause from the parent level to allow
|
||||
// Bun to use foreign key relationships for loading children
|
||||
func TestRecursivePreloadClearsWhereClause(t *testing.T) {
|
||||
// Create a mock handler
|
||||
handler := &Handler{}
|
||||
|
||||
// Create a preload option with a WHERE clause that filters root items
|
||||
// This simulates the xfiles use case where the first level has a filter
|
||||
// like "rid_parentmastertaskitem is null" to get root items
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MastertaskItems",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
Filters: []common.FilterOption{
|
||||
{
|
||||
Column: "rid_parentmastertaskitem",
|
||||
Operator: "is null",
|
||||
Value: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create a mock query that tracks operations
|
||||
mockQuery := &mockSelectQuery{
|
||||
operations: []string{},
|
||||
}
|
||||
|
||||
// Apply the recursive preload at depth 0
|
||||
// This should:
|
||||
// 1. Apply the initial preload with the WHERE clause
|
||||
// 2. Create a recursive preload without the WHERE clause
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||
|
||||
// Verify the mock query received the operations
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Check that we have at least 2 PreloadRelation calls:
|
||||
// 1. The initial "MastertaskItems" with WHERE clause
|
||||
// 2. The recursive "MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" without WHERE clause
|
||||
preloadCount := 0
|
||||
recursivePreloadFound := false
|
||||
whereAppliedToRecursive := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MastertaskItems" {
|
||||
preloadCount++
|
||||
}
|
||||
if op == "PreloadRelation:MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" {
|
||||
recursivePreloadFound = true
|
||||
}
|
||||
// Check if WHERE was applied to the recursive preload (it shouldn't be)
|
||||
if op == "Where:rid_parentmastertaskitem is null" && recursivePreloadFound {
|
||||
whereAppliedToRecursive = true
|
||||
}
|
||||
}
|
||||
|
||||
if preloadCount < 1 {
|
||||
t.Errorf("Expected at least 1 PreloadRelation call, got %d", preloadCount)
|
||||
}
|
||||
|
||||
if !recursivePreloadFound {
|
||||
t.Errorf("Expected recursive preload 'MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if whereAppliedToRecursive {
|
||||
t.Error("WHERE clause should not be applied to recursive preload levels")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecursivePreloadWithChildRelations tests that child relations
|
||||
// (like DEF in MAL.DEF) are properly extended to recursive levels
|
||||
func TestRecursivePreloadWithChildRelations(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Create the main recursive preload
|
||||
recursivePreload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
}
|
||||
|
||||
// Create a child relation that should be extended
|
||||
childPreload := common.PreloadOption{
|
||||
Relation: "MAL.DEF",
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{
|
||||
operations: []string{},
|
||||
}
|
||||
|
||||
allPreloads := []common.PreloadOption{recursivePreload, childPreload}
|
||||
|
||||
// Apply both preloads - the child preload should be extended when the recursive one processes
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, allPreloads, nil, 0)
|
||||
|
||||
// Also need to apply the child preload separately (as would happen in normal flow)
|
||||
result = handler.applyPreloadWithRecursion(result, childPreload, allPreloads, nil, 0)
|
||||
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Check that the child relation was extended to recursive levels
|
||||
// We should see:
|
||||
// - MAL (with WHERE)
|
||||
// - MAL.DEF
|
||||
// - MAL.MAL_RID_PARENTMASTERTASKITEM (without WHERE)
|
||||
// - MAL.MAL_RID_PARENTMASTERTASKITEM.DEF (extended by recursive logic)
|
||||
foundMALDEF := false
|
||||
foundRecursiveMAL := false
|
||||
foundMALMALDEF := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.DEF" {
|
||||
foundMALDEF = true
|
||||
}
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundRecursiveMAL = true
|
||||
}
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
|
||||
foundMALMALDEF = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundMALDEF {
|
||||
t.Errorf("Expected child preload 'MAL.DEF' to be applied. Operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if !foundRecursiveMAL {
|
||||
t.Errorf("Expected recursive preload 'MAL.MAL_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if !foundMALMALDEF {
|
||||
t.Errorf("Expected child preload to be extended to 'MAL.MAL_RID_PARENTMASTERTASKITEM.DEF' at recursive level. Operations: %v", mock.operations)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecursivePreloadGeneratesCorrectRelationName tests that the recursive
|
||||
// preload generates the correct FK-based relation name using RelatedKey
|
||||
func TestRecursivePreloadGeneratesCorrectRelationName(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Test case 1: With RelatedKey - should generate FK-based name
|
||||
t.Run("WithRelatedKey", func(t *testing.T) {
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Should generate MAL.MAL_RID_PARENTMASTERTASKITEM
|
||||
foundCorrectRelation := false
|
||||
foundIncorrectRelation := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundCorrectRelation = true
|
||||
}
|
||||
if op == "PreloadRelation:MAL.MAL" {
|
||||
foundIncorrectRelation = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundCorrectRelation {
|
||||
t.Errorf("Expected 'MAL.MAL_RID_PARENTMASTERTASKITEM' relation, operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if foundIncorrectRelation {
|
||||
t.Error("Should NOT generate 'MAL.MAL' relation when RelatedKey is specified")
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 2: Without RelatedKey - should fallback to old behavior
|
||||
t.Run("WithoutRelatedKey", func(t *testing.T) {
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
// No RelatedKey
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Should fallback to MAL.MAL
|
||||
foundFallback := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL" {
|
||||
foundFallback = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFallback {
|
||||
t.Errorf("Expected fallback 'MAL.MAL' relation when no RelatedKey, operations: %v", mock.operations)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 3: Depth limit of 8
|
||||
t.Run("DepthLimit", func(t *testing.T) {
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
|
||||
// Start at depth 7 - should create one more level
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
foundDepth8 := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth8 = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundDepth8 {
|
||||
t.Error("Expected to create recursive level at depth 8")
|
||||
}
|
||||
|
||||
// Start at depth 8 - should NOT create another level
|
||||
mockQuery2 := &mockSelectQuery{operations: []string{}}
|
||||
result2 := handler.applyPreloadWithRecursion(mockQuery2, preload, allPreloads, nil, 8)
|
||||
mock2 := result2.(*mockSelectQuery)
|
||||
|
||||
foundDepth9 := false
|
||||
for _, op := range mock2.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth9 = true
|
||||
}
|
||||
}
|
||||
|
||||
if foundDepth9 {
|
||||
t.Error("Should NOT create recursive level beyond depth 8")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mockSelectQuery implements common.SelectQuery for testing
|
||||
type mockSelectQuery struct {
|
||||
operations []string
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Model")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Table:"+table)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
for _, col := range columns {
|
||||
m.operations = append(m.operations, "Column:"+col)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "ColumnExpr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Where:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereOr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereIn:"+column)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Order:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "OrderExpr:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Limit")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Offset")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Join:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "LeftJoin:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Group")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Having:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Preload:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "PreloadRelation:"+relation)
|
||||
// Apply the preload modifiers
|
||||
for _, fn := range apply {
|
||||
fn(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "JoinRelation:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
m.operations = append(m.operations, "Scan")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
|
||||
m.operations = append(m.operations, "ScanModel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
m.operations = append(m.operations, "Count")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
m.operations = append(m.operations, "Exists")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetModel() interface{} {
|
||||
return nil
|
||||
}
|
||||
@@ -32,6 +32,7 @@
|
||||
// - X-Clean-JSON: Boolean to remove null/empty fields
|
||||
// - X-Custom-SQL-Where: Custom SQL WHERE clause (AND)
|
||||
// - X-Custom-SQL-Or: Custom SQL WHERE clause (OR)
|
||||
// - X-Custom-SQL-Join: Custom SQL JOIN clauses (pipe-separated for multiple)
|
||||
//
|
||||
// # Usage Example
|
||||
//
|
||||
@@ -103,8 +104,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
})
|
||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||
@@ -161,7 +163,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -169,7 +172,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -180,7 +183,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -188,7 +192,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -200,13 +204,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
corsConfig.AllowedMethods = allowedMethods
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
// Return metadata in the OPTIONS response body
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
vars["entity"] = entity
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -275,9 +280,30 @@ type BunRouterHandler interface {
|
||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||
}
|
||||
|
||||
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||
if authMiddleware == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = handler(w, req)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
wrappedHandler := authMiddleware(httpHandler)
|
||||
wrappedHandler.ServeHTTP(w, req.Request)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
||||
// Accepts bunrouter.Router or bunrouter.Group
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
@@ -285,15 +311,8 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// Add global /openapi route
|
||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -315,135 +334,155 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
currentEntity := entity
|
||||
|
||||
// GET and POST for /{schema}/{entity}
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
}
|
||||
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||
|
||||
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||
|
||||
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
putEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("PUT", entityWithIDPath, wrapBunRouterHandler(putEntityWithIDHandler, authMiddleware))
|
||||
|
||||
patchEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
}
|
||||
r.Handle("PATCH", entityWithIDPath, wrapBunRouterHandler(patchEntityWithIDHandler, authMiddleware))
|
||||
|
||||
deleteEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
}
|
||||
r.Handle("DELETE", entityWithIDPath, wrapBunRouterHandler(deleteEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// Metadata endpoint
|
||||
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
metadataHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", metadataPath, wrapBunRouterHandler(metadataHandler, authMiddleware))
|
||||
|
||||
// OPTIONS route without ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// OPTIONS route with ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
@@ -458,8 +497,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup routes
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup routes without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// Start server
|
||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||
@@ -479,7 +518,7 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||
apiGroup := bunRouter.NewGroup("/api")
|
||||
|
||||
// Setup RestHeadSpec routes on the group - routes will be under /api
|
||||
SetupBunRouterRoutes(apiGroup, handler)
|
||||
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||
|
||||
// Start server
|
||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||
|
||||
@@ -2,6 +2,8 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestParseModelName(t *testing.T) {
|
||||
@@ -112,3 +114,88 @@ func TestNewStandardBunRouter(t *testing.T) {
|
||||
t.Error("Expected router to be created, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTagValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Extract existing key",
|
||||
tag: "json:name;validate:required",
|
||||
key: "json",
|
||||
expected: "name",
|
||||
},
|
||||
{
|
||||
name: "Extract key with spaces",
|
||||
tag: "json:name ; validate:required",
|
||||
key: "validate",
|
||||
expected: "required",
|
||||
},
|
||||
{
|
||||
name: "Extract key at end",
|
||||
tag: "json:name;validate:required;db:column_name",
|
||||
key: "db",
|
||||
expected: "column_name",
|
||||
},
|
||||
{
|
||||
name: "Extract key at beginning",
|
||||
tag: "primary:true;json:id;db:user_id",
|
||||
key: "primary",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "Key not found",
|
||||
tag: "json:name;validate:required",
|
||||
key: "db",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty tag",
|
||||
tag: "",
|
||||
key: "json",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single key-value pair",
|
||||
tag: "json:name",
|
||||
key: "json",
|
||||
expected: "name",
|
||||
},
|
||||
{
|
||||
name: "Key with empty value",
|
||||
tag: "json:;validate:required",
|
||||
key: "json",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Key with complex value",
|
||||
tag: "json:user_name,omitempty;validate:required,min=3",
|
||||
key: "json",
|
||||
expected: "user_name,omitempty",
|
||||
},
|
||||
{
|
||||
name: "Multiple semicolons",
|
||||
tag: "json:name;;validate:required",
|
||||
key: "validate",
|
||||
expected: "required",
|
||||
},
|
||||
{
|
||||
name: "BUN Tag",
|
||||
tag: "rel:has-many,join:rid_hub=rid_hub_child",
|
||||
key: "join",
|
||||
expected: "rid_hub=rid_hub_child",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := common.ExtractTagValue(tt.tag, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
527
pkg/restheadspec/xfiles_integration_test.go
Normal file
527
pkg/restheadspec/xfiles_integration_test.go
Normal file
@@ -0,0 +1,527 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockSelectQuery implements common.SelectQuery for testing (integration version)
|
||||
type mockSelectQuery struct {
|
||||
operations []string
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Model")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Table:"+table)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
for _, col := range columns {
|
||||
m.operations = append(m.operations, "Column:"+col)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "ColumnExpr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Where:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereOr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereIn:"+column)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Order:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "OrderExpr:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Limit")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Offset")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Join:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "LeftJoin:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Group")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Having:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Preload:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "PreloadRelation:"+relation)
|
||||
// Apply the preload modifiers
|
||||
for _, fn := range apply {
|
||||
fn(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "JoinRelation:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
m.operations = append(m.operations, "Scan")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
|
||||
m.operations = append(m.operations, "ScanModel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
m.operations = append(m.operations, "Count")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
m.operations = append(m.operations, "Exists")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetModel() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestXFilesRecursivePreload is an integration test that validates the XFiles
|
||||
// recursive preload functionality using real test data files.
|
||||
//
|
||||
// This test ensures:
|
||||
// 1. XFiles request JSON is correctly parsed into PreloadOptions
|
||||
// 2. Recursive preload generates correct FK-based relation names (MAL_RID_PARENTMASTERTASKITEM)
|
||||
// 3. Parent WHERE clauses don't leak to child levels
|
||||
// 4. Child relations (like DEF) are extended to all recursive levels
|
||||
// 5. Hierarchical data structure matches expected output
|
||||
func TestXFilesRecursivePreload(t *testing.T) {
|
||||
// Load the XFiles request configuration
|
||||
requestPath := filepath.Join("..", "..", "tests", "data", "xfiles.request.json")
|
||||
requestData, err := os.ReadFile(requestPath)
|
||||
require.NoError(t, err, "Failed to read xfiles.request.json")
|
||||
|
||||
var xfileConfig XFiles
|
||||
err = json.Unmarshal(requestData, &xfileConfig)
|
||||
require.NoError(t, err, "Failed to parse xfiles.request.json")
|
||||
|
||||
// Create handler and parse XFiles into PreloadOptions
|
||||
handler := &Handler{}
|
||||
options := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Preload: []common.PreloadOption{},
|
||||
},
|
||||
}
|
||||
|
||||
// Process the XFiles configuration - start with the root table
|
||||
handler.processXFilesRelations(&xfileConfig, options, "")
|
||||
|
||||
// Verify that preload options were created
|
||||
require.NotEmpty(t, options.Preload, "Expected preload options to be created")
|
||||
|
||||
// Test 1: Verify mastertaskitem preload is marked as recursive with correct RelatedKey
|
||||
t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload *common.PreloadOption
|
||||
for i := range options.Preload {
|
||||
preload := &options.Preload[i]
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
// RelatedKey should be the parent relationship key (MTL -> MAL)
|
||||
assert.Equal(t, "rid_mastertask", recursivePreload.RelatedKey,
|
||||
"Recursive preload should preserve original RelatedKey for parent relationship")
|
||||
|
||||
// RecursiveChildKey should be set from the recursive child config
|
||||
assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RecursiveChildKey,
|
||||
"Recursive preload should have RecursiveChildKey set from recursive child config")
|
||||
|
||||
assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive")
|
||||
})
|
||||
|
||||
// Test 2: Verify mastertaskitem has WHERE clause for filtering root items
|
||||
t.Run("RootLevelHasWhereClause", func(t *testing.T) {
|
||||
var rootPreload *common.PreloadOption
|
||||
for i := range options.Preload {
|
||||
preload := &options.Preload[i]
|
||||
if preload.Relation == "MTL.MAL" {
|
||||
rootPreload = preload
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, rootPreload, "Expected to find mastertaskitem preload")
|
||||
assert.NotEmpty(t, rootPreload.Where, "Mastertaskitem should have WHERE clause")
|
||||
// The WHERE clause should filter for root items (rid_parentmastertaskitem is null)
|
||||
assert.True(t, rootPreload.Recursive, "Mastertaskitem preload should be marked as recursive")
|
||||
})
|
||||
|
||||
// Test 3: Verify actiondefinition relation exists for mastertaskitem
|
||||
t.Run("DEFRelationExists", func(t *testing.T) {
|
||||
var defPreload *common.PreloadOption
|
||||
for i := range options.Preload {
|
||||
preload := &options.Preload[i]
|
||||
if preload.Relation == "MTL.MAL.DEF" {
|
||||
defPreload = preload
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, defPreload, "Expected to find actiondefinition preload for mastertaskitem")
|
||||
assert.Equal(t, "rid_actiondefinition", defPreload.ForeignKey,
|
||||
"actiondefinition preload should have ForeignKey set")
|
||||
})
|
||||
|
||||
// Test 4: Verify relation name generation with mock query
|
||||
t.Run("RelationNameGeneration", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload common.PreloadOption
|
||||
found := false
|
||||
for _, preload := range options.Preload {
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
// Create mock query to track operations
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
|
||||
// Apply the recursive preload
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Verify the correct FK-based relation name was generated
|
||||
foundCorrectRelation := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
// Should generate: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM
|
||||
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundCorrectRelation = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundCorrectRelation,
|
||||
"Expected FK-based relation name 'MTL.MAL.MAL_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v",
|
||||
mock.operations)
|
||||
})
|
||||
|
||||
// Test 5: Verify WHERE clause is cleared for recursive levels
|
||||
t.Run("WhereClauseClearedForChildren", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload common.PreloadOption
|
||||
found := false
|
||||
for _, preload := range options.Preload {
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
// The root level has a WHERE clause (rid_parentmastertaskitem is null)
|
||||
// But when we apply recursion, it should be cleared
|
||||
assert.NotEmpty(t, recursivePreload.Where, "Root preload should have WHERE clause")
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// After the first level, WHERE clauses should not be reapplied
|
||||
// We check that the recursive relation was created (which means WHERE was cleared internally)
|
||||
foundRecursiveRelation := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundRecursiveRelation = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundRecursiveRelation,
|
||||
"Recursive relation should be created (WHERE clause should be cleared internally)")
|
||||
})
|
||||
|
||||
// Test 6: Verify child relations are extended to recursive levels
|
||||
t.Run("ChildRelationsExtended", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload common.PreloadOption
|
||||
foundRecursive := false
|
||||
|
||||
for _, preload := range options.Preload {
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
foundRecursive = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// actiondefinition should be extended to the recursive level
|
||||
// Expected: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF
|
||||
foundExtendedDEF := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
|
||||
foundExtendedDEF = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundExtendedDEF,
|
||||
"Expected actiondefinition relation to be extended to recursive level. Operations: %v",
|
||||
mock.operations)
|
||||
})
|
||||
}
|
||||
|
||||
// TestXFilesRecursivePreloadDepth tests that recursive preloads respect the depth limit of 8
|
||||
func TestXFilesRecursivePreloadDepth(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
}
|
||||
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
|
||||
t.Run("Depth7CreatesLevel8", func(t *testing.T) {
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
foundDepth8 := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth8 = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundDepth8, "Should create level 8 when starting at depth 7")
|
||||
})
|
||||
|
||||
t.Run("Depth8DoesNotCreateLevel9", func(t *testing.T) {
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 8)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
foundDepth9 := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth9 = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.False(t, foundDepth9, "Should NOT create level 9 (depth limit is 8)")
|
||||
})
|
||||
}
|
||||
|
||||
// TestXFilesResponseStructure validates the actual structure of the response
|
||||
// This test can be expanded when we have a full database integration test environment
|
||||
func TestXFilesResponseStructure(t *testing.T) {
|
||||
// Load the expected correct response
|
||||
correctResponsePath := filepath.Join("..", "..", "tests", "data", "xfiles.response.correct.json")
|
||||
correctData, err := os.ReadFile(correctResponsePath)
|
||||
require.NoError(t, err, "Failed to read xfiles.response.correct.json")
|
||||
|
||||
var correctResponse []map[string]interface{}
|
||||
err = json.Unmarshal(correctData, &correctResponse)
|
||||
require.NoError(t, err, "Failed to parse xfiles.response.correct.json")
|
||||
|
||||
// Test 1: Verify root level has exactly 1 masterprocess
|
||||
t.Run("RootLevelHasOneItem", func(t *testing.T) {
|
||||
assert.Len(t, correctResponse, 1, "Root level should have exactly 1 masterprocess record")
|
||||
})
|
||||
|
||||
// Test 2: Verify the root item has MTL relation
|
||||
t.Run("RootHasMTLRelation", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, exists := rootItem["MTL"]
|
||||
assert.True(t, exists, "Root item should have MTL relation")
|
||||
assert.NotNil(t, mtl, "MTL relation should not be null")
|
||||
})
|
||||
|
||||
// Test 3: Verify MTL has MAL items
|
||||
t.Run("MTLHasMALItems", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, exists := firstMTL["MAL"]
|
||||
assert.True(t, exists, "MTL item should have MAL relation")
|
||||
assert.NotNil(t, mal, "MAL relation should not be null")
|
||||
})
|
||||
|
||||
// Test 4: Verify MAL items have MAL_RID_PARENTMASTERTASKITEM relation (recursive)
|
||||
t.Run("MALHasRecursiveRelation", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, ok := firstMTL["MAL"].([]interface{})
|
||||
require.True(t, ok, "MAL should be an array")
|
||||
require.NotEmpty(t, mal, "MAL should have items")
|
||||
|
||||
firstMAL, ok := mal[0].(map[string]interface{})
|
||||
require.True(t, ok, "MAL item should be a map")
|
||||
|
||||
// The key assertion: check for FK-based relation name
|
||||
recursiveRelation, exists := firstMAL["MAL_RID_PARENTMASTERTASKITEM"]
|
||||
assert.True(t, exists,
|
||||
"MAL item should have MAL_RID_PARENTMASTERTASKITEM relation (FK-based name)")
|
||||
|
||||
// It can be null or an array, depending on whether this item has children
|
||||
if recursiveRelation != nil {
|
||||
_, isArray := recursiveRelation.([]interface{})
|
||||
assert.True(t, isArray,
|
||||
"MAL_RID_PARENTMASTERTASKITEM should be an array when not null")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 5: Verify "Receive COB Document for" appears as a child, not at root
|
||||
t.Run("ChildItemsAreNested", func(t *testing.T) {
|
||||
// This test verifies that "Receive COB Document for" doesn't appear
|
||||
// multiple times at the wrong level, but is properly nested
|
||||
|
||||
// Count how many times we find this description at the MAL level (should be 0 or 1)
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, ok := firstMTL["MAL"].([]interface{})
|
||||
require.True(t, ok, "MAL should be an array")
|
||||
|
||||
// Count root-level MAL items (before the fix, there were 12; should be 1)
|
||||
assert.Len(t, mal, 1,
|
||||
"MAL should have exactly 1 root-level item (before fix: 12 duplicates)")
|
||||
|
||||
// Verify the root item has a description
|
||||
firstMAL, ok := mal[0].(map[string]interface{})
|
||||
require.True(t, ok, "MAL item should be a map")
|
||||
|
||||
description, exists := firstMAL["description"]
|
||||
assert.True(t, exists, "MAL item should have a description")
|
||||
assert.Equal(t, "Capture COB Information", description,
|
||||
"Root MAL item should be 'Capture COB Information'")
|
||||
})
|
||||
|
||||
// Test 6: Verify DEF relation exists at MAL level
|
||||
t.Run("DEFRelationExists", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, ok := firstMTL["MAL"].([]interface{})
|
||||
require.True(t, ok, "MAL should be an array")
|
||||
require.NotEmpty(t, mal, "MAL should have items")
|
||||
|
||||
firstMAL, ok := mal[0].(map[string]interface{})
|
||||
require.True(t, ok, "MAL item should be a map")
|
||||
|
||||
// Verify DEF relation exists (child relation extension)
|
||||
def, exists := firstMAL["DEF"]
|
||||
assert.True(t, exists, "MAL item should have DEF relation")
|
||||
|
||||
// DEF can be null or an object
|
||||
if def != nil {
|
||||
_, isMap := def.(map[string]interface{})
|
||||
assert.True(t, isMap, "DEF should be an object when not null")
|
||||
}
|
||||
})
|
||||
}
|
||||
527
pkg/security/OAUTH2.md
Normal file
527
pkg/security/OAUTH2.md
Normal file
@@ -0,0 +1,527 @@
|
||||
# OAuth2 Authentication Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The security package provides OAuth2 authentication support for any OAuth2-compliant provider including Google, GitHub, Microsoft, Facebook, and custom providers.
|
||||
|
||||
## Features
|
||||
|
||||
- **Universal OAuth2 Support**: Works with any OAuth2 provider
|
||||
- **Pre-configured Providers**: Google, GitHub, Microsoft, Facebook
|
||||
- **Multi-Provider Support**: Use all OAuth2 providers simultaneously
|
||||
- **Custom Providers**: Easy configuration for any OAuth2 service
|
||||
- **Session Management**: Database-backed session storage
|
||||
- **Token Refresh**: Automatic token refresh support
|
||||
- **State Validation**: Built-in CSRF protection
|
||||
- **User Auto-Creation**: Automatically creates users on first login
|
||||
- **Unified Authentication**: OAuth2 and traditional auth share same session storage
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Database Setup
|
||||
|
||||
```sql
|
||||
-- Run the schema from database_schema.sql
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password VARCHAR(255),
|
||||
user_level INTEGER DEFAULT 0,
|
||||
roles VARCHAR(500),
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
remote_id VARCHAR(255),
|
||||
auth_provider VARCHAR(50)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
token_type VARCHAR(50) DEFAULT 'Bearer',
|
||||
auth_provider VARCHAR(50)
|
||||
);
|
||||
|
||||
-- OAuth2 stored procedures (7 functions)
|
||||
-- See database_schema.sql for full implementation
|
||||
```
|
||||
|
||||
### 2. Google OAuth2
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
// Create authenticator
|
||||
oauth2Auth := security.NewGoogleAuthenticator(
|
||||
"your-google-client-id",
|
||||
"your-google-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Login route - redirects to Google
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL(state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Callback route - handles Google response
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
```
|
||||
|
||||
### 3. GitHub OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewGitHubAuthenticator(
|
||||
"your-github-client-id",
|
||||
"your-github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Same routes pattern as Google
|
||||
router.HandleFunc("/auth/github/login", ...)
|
||||
router.HandleFunc("/auth/github/callback", ...)
|
||||
```
|
||||
|
||||
### 4. Microsoft OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewMicrosoftAuthenticator(
|
||||
"your-microsoft-client-id",
|
||||
"your-microsoft-client-secret",
|
||||
"http://localhost:8080/auth/microsoft/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Facebook OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewFacebookAuthenticator(
|
||||
"your-facebook-client-id",
|
||||
"your-facebook-client-secret",
|
||||
"http://localhost:8080/auth/facebook/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
## Custom OAuth2 Provider
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "your-client-id",
|
||||
ClientSecret: "your-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://your-provider.com/oauth/authorize",
|
||||
TokenURL: "https://your-provider.com/oauth/token",
|
||||
UserInfoURL: "https://your-provider.com/oauth/userinfo",
|
||||
DB: db,
|
||||
ProviderName: "custom",
|
||||
|
||||
// Optional: Custom user info parser
|
||||
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
|
||||
return &security.UserContext{
|
||||
UserName: userInfo["username"].(string),
|
||||
Email: userInfo["email"].(string),
|
||||
RemoteID: userInfo["id"].(string),
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
Claims: userInfo,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Protected Routes
|
||||
|
||||
```go
|
||||
// Create security provider
|
||||
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := security.NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
|
||||
// Apply middleware to protected routes
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(security.NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
```
|
||||
|
||||
## Token Refresh
|
||||
|
||||
OAuth2 access tokens expire after a period of time. Use the refresh token to obtain a new access token without requiring the user to log in again.
|
||||
|
||||
```go
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google", "github", etc.
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Default to google if not specified
|
||||
if req.Provider == "" {
|
||||
req.Provider = "google"
|
||||
}
|
||||
|
||||
// Use OAuth2-specific refresh method
|
||||
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- The refresh token is returned in the `LoginResponse.RefreshToken` field after successful OAuth2 callback
|
||||
- Store the refresh token securely on the client side
|
||||
- Each provider must be configured with the appropriate scopes to receive a refresh token (e.g., `access_type=offline` for Google)
|
||||
- The `OAuth2RefreshToken` method requires the provider name to identify which OAuth2 provider to use for refreshing
|
||||
|
||||
## Logout
|
||||
|
||||
```go
|
||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
|
||||
oauth2Auth.Logout(r.Context(), security.LogoutRequest{
|
||||
Token: userCtx.SessionID,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
```
|
||||
|
||||
## Multi-Provider Setup
|
||||
|
||||
```go
|
||||
// Single DatabaseAuthenticator with ALL OAuth2 providers
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
}).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
})
|
||||
|
||||
// Get list of configured providers
|
||||
providers := auth.OAuth2GetProviders() // ["google", "github"]
|
||||
|
||||
// Google routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google",
|
||||
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
// ... handle response
|
||||
})
|
||||
|
||||
// GitHub routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github",
|
||||
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
// ... handle response
|
||||
})
|
||||
|
||||
// Use same authenticator for protected routes - works for ALL providers
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### OAuth2Config Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| ClientID | string | OAuth2 client ID from provider |
|
||||
| ClientSecret | string | OAuth2 client secret |
|
||||
| RedirectURL | string | Callback URL registered with provider |
|
||||
| Scopes | []string | OAuth2 scopes to request |
|
||||
| AuthURL | string | Provider's authorization endpoint |
|
||||
| TokenURL | string | Provider's token endpoint |
|
||||
| UserInfoURL | string | Provider's user info endpoint |
|
||||
| DB | *sql.DB | Database connection for sessions |
|
||||
| UserInfoParser | func | Custom parser for user info (optional) |
|
||||
| StateValidator | func | Custom state validator (optional) |
|
||||
| ProviderName | string | Provider name for logging (optional) |
|
||||
|
||||
## User Info Parsing
|
||||
|
||||
The default parser extracts these standard fields:
|
||||
- `sub` → RemoteID
|
||||
- `email` → Email, UserName
|
||||
- `name` → UserName
|
||||
- `login` → UserName (GitHub)
|
||||
|
||||
Custom parser example:
|
||||
|
||||
```go
|
||||
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
|
||||
// Extract custom fields
|
||||
ctx := &security.UserContext{
|
||||
UserName: userInfo["preferred_username"].(string),
|
||||
Email: userInfo["email"].(string),
|
||||
RemoteID: userInfo["sub"].(string),
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
Claims: userInfo, // Store all claims
|
||||
}
|
||||
|
||||
// Add custom roles based on provider data
|
||||
if groups, ok := userInfo["groups"].([]interface{}); ok {
|
||||
for _, g := range groups {
|
||||
ctx.Roles = append(ctx.Roles, g.(string))
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
```
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
1. **Always use HTTPS in production**
|
||||
```go
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Secure: true, // Only send over HTTPS
|
||||
HttpOnly: true, // Prevent XSS access
|
||||
SameSite: http.SameSiteLaxMode, // CSRF protection
|
||||
})
|
||||
```
|
||||
|
||||
2. **Store secrets securely**
|
||||
```go
|
||||
clientID := os.Getenv("GOOGLE_CLIENT_ID")
|
||||
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||
```
|
||||
|
||||
3. **Validate redirect URLs**
|
||||
- Only register trusted redirect URLs with OAuth2 providers
|
||||
- Never accept redirect URL from request parameters
|
||||
|
||||
5. **Session expiration**
|
||||
- OAuth2 sessions automatically expire based on token expiry
|
||||
- Clean up expired sessions periodically:
|
||||
```sql
|
||||
DELETE FROM user_sessions WHERE expires_at < NOW();
|
||||
```
|
||||
|
||||
4. **State parameter**
|
||||
- Automatically generated with cryptographic randomness
|
||||
- One-time use and expires after 10 minutes
|
||||
- Prevents CSRF attacks
|
||||
|
||||
## Implementation Details
|
||||
|
||||
All database operations use stored procedures for consistency and security:
|
||||
- `resolvespec_oauth_getorcreateuser` - Find or create OAuth2 user
|
||||
- `resolvespec_oauth_createsession` - Create OAuth2 session
|
||||
- `resolvespec_oauth_getsession` - Validate and retrieve session
|
||||
- `resolvespec_oauth_deletesession` - Logout/delete session
|
||||
- `resolvespec_oauth_getrefreshtoken` - Get session by refresh token
|
||||
- `resolvespec_oauth_updaterefreshtoken` - Update tokens after refresh
|
||||
- `resolvespec_oauth_getuser` - Get user data by ID
|
||||
|
||||
## Provider Setup Guides
|
||||
|
||||
### Google
|
||||
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create a new project or select existing
|
||||
3. Enable Google+ API
|
||||
4. Create OAuth 2.0 credentials
|
||||
5. Add authorized redirect URI: `http://localhost:8080/auth/google/callback`
|
||||
6. Copy Client ID and Client Secret
|
||||
|
||||
### GitHub
|
||||
|
||||
1. Go to [GitHub Developer Settings](https://github.com/settings/developers)
|
||||
2. Click "New OAuth App"
|
||||
3. Set Homepage URL: `http://localhost:8080`
|
||||
4. Set Authorization callback URL: `http://localhost:8080/auth/github/callback`
|
||||
5. Copy Client ID and Client Secret
|
||||
|
||||
### Microsoft
|
||||
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Register new application in Azure AD
|
||||
3. Add redirect URI: `http://localhost:8080/auth/microsoft/callback`
|
||||
4. Create client secret
|
||||
5. Copy Application (client) ID and secret value
|
||||
|
||||
### Facebook
|
||||
|
||||
1. Go to [Facebook Developers](https://developers.facebook.com/)
|
||||
2. Create new app
|
||||
3. Add Facebook Login product
|
||||
4. Set Valid OAuth Redirect URIs: `http://localhost:8080/auth/facebook/callback`
|
||||
5. Copy App ID and App Secret
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "redirect_uri_mismatch" error
|
||||
- Ensure the redirect URL in code matches exactly with provider configuration
|
||||
- Include protocol (http/https), domain, port, and path
|
||||
|
||||
### "invalid_client" error
|
||||
- Verify Client ID and Client Secret are correct
|
||||
- Check if credentials are for the correct environment (dev/prod)
|
||||
|
||||
### "invalid_grant" error during token exchange
|
||||
- State parameter validation failed
|
||||
- Token might have expired
|
||||
- Check server time synchronization
|
||||
|
||||
### User not created after successful OAuth2 login
|
||||
- Check database constraints (username/email unique)
|
||||
- Verify UserInfoParser is extracting required fields
|
||||
- Check database logs for constraint violations
|
||||
|
||||
## Testing
|
||||
|
||||
```go
|
||||
func TestOAuth2Flow(t *testing.T) {
|
||||
// Mock database
|
||||
db, mock, _ := sqlmock.New()
|
||||
|
||||
oauth2Auth := security.NewGoogleAuthenticator(
|
||||
"test-client-id",
|
||||
"test-client-secret",
|
||||
"http://localhost/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Test state generation
|
||||
state, err := oauth2Auth.GenerateState()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, state)
|
||||
|
||||
// Test auth URL generation
|
||||
authURL := oauth2Auth.GetAuthURL(state)
|
||||
assert.Contains(t, authURL, "accounts.google.com")
|
||||
assert.Contains(t, authURL, state)
|
||||
}
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### DatabaseAuthenticator with OAuth2
|
||||
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| WithOAuth2(cfg) | Adds OAuth2 provider (can be called multiple times, returns *DatabaseAuthenticator) |
|
||||
| OAuth2GetAuthURL(provider, state) | Returns OAuth2 authorization URL for specified provider |
|
||||
| OAuth2GenerateState() | Generates random state for CSRF protection |
|
||||
| OAuth2HandleCallback(ctx, provider, code, state) | Exchanges code for token and creates session |
|
||||
| OAuth2RefreshToken(ctx, refreshToken, provider) | Refreshes expired access token using refresh token |
|
||||
| OAuth2GetProviders() | Returns list of configured OAuth2 provider names |
|
||||
| Login(ctx, req) | Standard username/password login |
|
||||
| Logout(ctx, req) | Invalidates session (works for both OAuth2 and regular sessions) |
|
||||
| Authenticate(r) | Validates session token from request (works for both OAuth2 and regular sessions) |
|
||||
|
||||
### Pre-configured Constructors
|
||||
|
||||
- `NewGoogleAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewGitHubAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewFacebookAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewMultiProviderAuthenticator(db, configs)` - Multiple providers at once
|
||||
|
||||
All return `*DatabaseAuthenticator` with OAuth2 pre-configured.
|
||||
|
||||
For multiple providers, use `WithOAuth2()` multiple times or `NewMultiProviderAuthenticator()`.
|
||||
|
||||
## Examples
|
||||
|
||||
Complete working examples available in `oauth2_examples.go`:
|
||||
- Basic Google OAuth2
|
||||
- GitHub OAuth2
|
||||
- Custom provider
|
||||
- Multi-provider setup
|
||||
- Token refresh
|
||||
- Logout flow
|
||||
- Complete integration with security middleware
|
||||
281
pkg/security/OAUTH2_REFRESH_QUICK_REFERENCE.md
Normal file
281
pkg/security/OAUTH2_REFRESH_QUICK_REFERENCE.md
Normal file
@@ -0,0 +1,281 @@
|
||||
# OAuth2 Refresh Token - Quick Reference
|
||||
|
||||
## Quick Setup (3 Steps)
|
||||
|
||||
### 1. Initialize Authenticator
|
||||
```go
|
||||
auth := security.NewGoogleAuthenticator(
|
||||
"client-id",
|
||||
"client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
### 2. OAuth2 Login Flow
|
||||
```go
|
||||
// Login - Redirect to Google
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Callback - Store tokens
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, _ := auth.OAuth2HandleCallback(
|
||||
r.Context(),
|
||||
"google",
|
||||
r.URL.Query().Get("code"),
|
||||
r.URL.Query().Get("state"),
|
||||
)
|
||||
|
||||
// Save refresh_token on client
|
||||
// loginResp.RefreshToken - Store this securely!
|
||||
// loginResp.Token - Session token for API calls
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Refresh Endpoint
|
||||
```go
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Refresh token
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Multi-Provider Example
|
||||
|
||||
```go
|
||||
// Configure multiple providers
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ProviderName: "google",
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
}).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ProviderName: "github",
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
})
|
||||
|
||||
// Refresh with provider selection
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google" or "github"
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Client-Side JavaScript
|
||||
|
||||
```javascript
|
||||
// Automatic token refresh on 401
|
||||
async function apiCall(url) {
|
||||
let response = await fetch(url, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||
}
|
||||
});
|
||||
|
||||
// Token expired - refresh it
|
||||
if (response.status === 401) {
|
||||
await refreshToken();
|
||||
|
||||
// Retry request with new token
|
||||
response = await fetch(url, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async function refreshToken() {
|
||||
const response = await fetch('/auth/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
refresh_token: localStorage.getItem('refresh_token'),
|
||||
provider: localStorage.getItem('provider')
|
||||
})
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
localStorage.setItem('access_token', data.token);
|
||||
localStorage.setItem('refresh_token', data.refresh_token);
|
||||
} else {
|
||||
// Refresh failed - redirect to login
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API Methods
|
||||
|
||||
| Method | Parameters | Returns |
|
||||
|--------|-----------|---------|
|
||||
| `OAuth2RefreshToken` | `ctx, refreshToken, provider` | `*LoginResponse, error` |
|
||||
| `OAuth2HandleCallback` | `ctx, provider, code, state` | `*LoginResponse, error` |
|
||||
| `OAuth2GetAuthURL` | `provider, state` | `string, error` |
|
||||
| `OAuth2GenerateState` | none | `string, error` |
|
||||
| `OAuth2GetProviders` | none | `[]string` |
|
||||
|
||||
---
|
||||
|
||||
## LoginResponse Structure
|
||||
|
||||
```go
|
||||
type LoginResponse struct {
|
||||
Token string // New session token for API calls
|
||||
RefreshToken string // Refresh token (store securely)
|
||||
User *UserContext // User information
|
||||
ExpiresIn int64 // Seconds until token expires
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Database Stored Procedures
|
||||
|
||||
- `resolvespec_oauth_getrefreshtoken(refresh_token)` - Get session by refresh token
|
||||
- `resolvespec_oauth_updaterefreshtoken(update_data)` - Update tokens after refresh
|
||||
- `resolvespec_oauth_getuser(user_id)` - Get user data
|
||||
|
||||
All procedures return: `{p_success bool, p_error text, p_data jsonb}`
|
||||
|
||||
---
|
||||
|
||||
## Common Errors
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `invalid or expired refresh token` | Token revoked/expired | Re-authenticate user |
|
||||
| `OAuth2 provider 'xxx' not found` | Provider not configured | Add with `WithOAuth2()` |
|
||||
| `failed to refresh token with provider` | Provider rejected request | Check credentials, re-auth user |
|
||||
|
||||
---
|
||||
|
||||
## Security Checklist
|
||||
|
||||
- [ ] Use HTTPS for all OAuth2 endpoints
|
||||
- [ ] Store refresh tokens securely (HttpOnly cookies or encrypted storage)
|
||||
- [ ] Set cookie flags: `HttpOnly`, `Secure`, `SameSite=Strict`
|
||||
- [ ] Implement rate limiting on refresh endpoint
|
||||
- [ ] Log refresh attempts for audit
|
||||
- [ ] Rotate tokens on refresh
|
||||
- [ ] Revoke old sessions after successful refresh
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# 1. Login and get refresh token
|
||||
curl http://localhost:8080/auth/google/login
|
||||
# Follow OAuth2 flow, get refresh_token from callback response
|
||||
|
||||
# 2. Refresh token
|
||||
curl -X POST http://localhost:8080/auth/refresh \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"refresh_token":"ya29.xxx","provider":"google"}'
|
||||
|
||||
# 3. Use new token
|
||||
curl http://localhost:8080/api/protected \
|
||||
-H "Authorization: Bearer sess_abc123..."
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pre-configured Providers
|
||||
|
||||
```go
|
||||
// Google
|
||||
auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// GitHub
|
||||
auth := security.NewGitHubAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// Microsoft
|
||||
auth := security.NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// Facebook
|
||||
auth := security.NewFacebookAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// All providers at once
|
||||
auth := security.NewMultiProviderAuthenticator(db, map[string]security.OAuth2Config{
|
||||
"google": {...},
|
||||
"github": {...},
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Provider-Specific Notes
|
||||
|
||||
### Google
|
||||
- Add `access_type=offline` to get refresh token
|
||||
- Add `prompt=consent` to force consent screen
|
||||
```go
|
||||
authURL += "&access_type=offline&prompt=consent"
|
||||
```
|
||||
|
||||
### GitHub
|
||||
- Refresh tokens not always provided
|
||||
- May need to request `offline_access` scope
|
||||
|
||||
### Microsoft
|
||||
- Use `offline_access` scope for refresh token
|
||||
|
||||
### Facebook
|
||||
- Tokens expire after 60 days by default
|
||||
- Check app settings for token expiration policy
|
||||
|
||||
---
|
||||
|
||||
## Complete Example
|
||||
|
||||
See `/pkg/security/oauth2_examples.go` line 250 for full working example.
|
||||
|
||||
For detailed documentation see `/pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md`.
|
||||
495
pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md
Normal file
495
pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,495 @@
|
||||
# OAuth2 Refresh Token Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
OAuth2 refresh token functionality is **fully implemented** in the ResolveSpec security package. This allows refreshing expired access tokens without requiring users to re-authenticate.
|
||||
|
||||
## Implementation Status: ✅ COMPLETE
|
||||
|
||||
### Components Implemented
|
||||
|
||||
1. **✅ Database Schema** - Tables and stored procedures
|
||||
2. **✅ Go Methods** - OAuth2RefreshToken implementation
|
||||
3. **✅ Thread Safety** - Mutex protection for provider map
|
||||
4. **✅ Examples** - Working code examples
|
||||
5. **✅ Documentation** - Complete API reference
|
||||
|
||||
---
|
||||
|
||||
## 1. Database Schema
|
||||
|
||||
### Tables Modified
|
||||
|
||||
```sql
|
||||
-- user_sessions table with OAuth2 token fields
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT, -- OAuth2 access token
|
||||
refresh_token TEXT, -- OAuth2 refresh token
|
||||
token_type VARCHAR(50), -- "Bearer", etc.
|
||||
auth_provider VARCHAR(50) -- "google", "github", etc.
|
||||
);
|
||||
```
|
||||
|
||||
### Stored Procedures
|
||||
|
||||
**`resolvespec_oauth_getrefreshtoken(p_refresh_token)`**
|
||||
- Gets OAuth2 session data by refresh token
|
||||
- Returns: `{user_id, access_token, token_type, expiry}`
|
||||
- Location: `database_schema.sql:714`
|
||||
|
||||
**`resolvespec_oauth_updaterefreshtoken(p_update_data)`**
|
||||
- Updates session with new tokens after refresh
|
||||
- Input: `{user_id, old_refresh_token, new_session_token, new_access_token, new_refresh_token, expires_at}`
|
||||
- Location: `database_schema.sql:752`
|
||||
|
||||
**`resolvespec_oauth_getuser(p_user_id)`**
|
||||
- Gets user data by ID for building UserContext
|
||||
- Location: `database_schema.sql:791`
|
||||
|
||||
---
|
||||
|
||||
## 2. Go Implementation
|
||||
|
||||
### Method Signature
|
||||
|
||||
```go
|
||||
func (a *DatabaseAuthenticator) OAuth2RefreshToken(
|
||||
ctx context.Context,
|
||||
refreshToken string,
|
||||
providerName string,
|
||||
) (*LoginResponse, error)
|
||||
```
|
||||
|
||||
**Location:** `pkg/security/oauth2_methods.go:375`
|
||||
|
||||
### Implementation Flow
|
||||
|
||||
```
|
||||
1. Validate provider exists
|
||||
├─ getOAuth2Provider(providerName) with RLock
|
||||
└─ Return error if provider not configured
|
||||
|
||||
2. Get session from database
|
||||
├─ Call resolvespec_oauth_getrefreshtoken(refreshToken)
|
||||
└─ Parse session data {user_id, access_token, token_type, expiry}
|
||||
|
||||
3. Refresh token with OAuth2 provider
|
||||
├─ Create oauth2.Token from stored data
|
||||
├─ Use provider.config.TokenSource(ctx, oldToken)
|
||||
└─ Call tokenSource.Token() to get new token
|
||||
|
||||
4. Generate new session token
|
||||
└─ Use OAuth2GenerateState() for secure random token
|
||||
|
||||
5. Update database
|
||||
├─ Call resolvespec_oauth_updaterefreshtoken()
|
||||
└─ Store new session_token, access_token, refresh_token
|
||||
|
||||
6. Get user data
|
||||
├─ Call resolvespec_oauth_getuser(user_id)
|
||||
└─ Build UserContext
|
||||
|
||||
7. Return LoginResponse
|
||||
└─ {Token, RefreshToken, User, ExpiresIn}
|
||||
```
|
||||
|
||||
### Thread Safety
|
||||
|
||||
**Mutex Protection:** All access to `oauth2Providers` map is protected with `sync.RWMutex`
|
||||
|
||||
```go
|
||||
type DatabaseAuthenticator struct {
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
oauth2ProvidersMutex sync.RWMutex // Thread-safe access
|
||||
}
|
||||
|
||||
// Read operations use RLock
|
||||
func (a *DatabaseAuthenticator) getOAuth2Provider(name string) {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
// ... access map
|
||||
}
|
||||
|
||||
// Write operations use Lock
|
||||
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) {
|
||||
a.oauth2ProvidersMutex.Lock()
|
||||
defer a.oauth2ProvidersMutex.Unlock()
|
||||
// ... modify map
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Usage Examples
|
||||
|
||||
### Single Provider (Google)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create Google OAuth2 authenticator
|
||||
auth := security.NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Token refresh endpoint
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Refresh token (provider name defaults to "google")
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
```
|
||||
|
||||
### Multi-Provider Setup
|
||||
|
||||
```go
|
||||
// Single authenticator with multiple OAuth2 providers
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
}).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
})
|
||||
|
||||
// Refresh endpoint with provider selection
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google" or "github"
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Refresh with specific provider
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
### Client-Side Usage
|
||||
|
||||
```javascript
|
||||
// JavaScript client example
|
||||
async function refreshAccessToken() {
|
||||
const refreshToken = localStorage.getItem('refresh_token');
|
||||
const provider = localStorage.getItem('auth_provider'); // "google", "github", etc.
|
||||
|
||||
const response = await fetch('/auth/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
refresh_token: refreshToken,
|
||||
provider: provider
|
||||
})
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
|
||||
// Store new tokens
|
||||
localStorage.setItem('access_token', data.token);
|
||||
localStorage.setItem('refresh_token', data.refresh_token);
|
||||
|
||||
console.log('Token refreshed successfully');
|
||||
return data.token;
|
||||
} else {
|
||||
// Refresh failed - redirect to login
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}
|
||||
|
||||
// Automatically refresh token when API returns 401
|
||||
async function apiCall(endpoint) {
|
||||
let response = await fetch(endpoint, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||
}
|
||||
});
|
||||
|
||||
if (response.status === 401) {
|
||||
// Token expired - try refresh
|
||||
const newToken = await refreshAccessToken();
|
||||
|
||||
// Retry with new token
|
||||
response = await fetch(endpoint, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + newToken
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. API Reference
|
||||
|
||||
### DatabaseAuthenticator Methods
|
||||
|
||||
| Method | Signature | Description |
|
||||
|--------|-----------|-------------|
|
||||
| `OAuth2RefreshToken` | `(ctx, refreshToken, provider) (*LoginResponse, error)` | Refreshes expired OAuth2 access token |
|
||||
| `WithOAuth2` | `(cfg OAuth2Config) *DatabaseAuthenticator` | Adds OAuth2 provider (chainable) |
|
||||
| `OAuth2GetAuthURL` | `(provider, state) (string, error)` | Gets authorization URL |
|
||||
| `OAuth2HandleCallback` | `(ctx, provider, code, state) (*LoginResponse, error)` | Handles OAuth2 callback |
|
||||
| `OAuth2GenerateState` | `() (string, error)` | Generates CSRF state token |
|
||||
| `OAuth2GetProviders` | `() []string` | Lists configured providers |
|
||||
|
||||
### LoginResponse Structure
|
||||
|
||||
```go
|
||||
type LoginResponse struct {
|
||||
Token string // New session token
|
||||
RefreshToken string // New refresh token (may be same as input)
|
||||
User *UserContext // User information
|
||||
ExpiresIn int64 // Seconds until expiration
|
||||
}
|
||||
|
||||
type UserContext struct {
|
||||
UserID int // Database user ID
|
||||
UserName string // Username
|
||||
Email string // Email address
|
||||
UserLevel int // Permission level
|
||||
SessionID string // Session token
|
||||
RemoteID string // OAuth2 provider user ID
|
||||
Roles []string // User roles
|
||||
Claims map[string]any // Additional claims
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Important Notes
|
||||
|
||||
### Provider Configuration
|
||||
|
||||
**For Google:** Add `access_type=offline` to get refresh token on first login:
|
||||
|
||||
```go
|
||||
auth := security.NewGoogleAuthenticator(clientID, clientSecret, redirectURL, db)
|
||||
// When generating auth URL, add access_type parameter
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
authURL += "&access_type=offline&prompt=consent"
|
||||
```
|
||||
|
||||
**For GitHub:** Refresh tokens are not always provided. Check provider documentation.
|
||||
|
||||
### Token Storage
|
||||
|
||||
- Store refresh tokens securely on client (localStorage, secure cookie, etc.)
|
||||
- Never log refresh tokens
|
||||
- Refresh tokens are long-lived (days/months depending on provider)
|
||||
- Access tokens are short-lived (minutes/hours)
|
||||
|
||||
### Error Handling
|
||||
|
||||
Common errors:
|
||||
- `"invalid or expired refresh token"` - Token expired or revoked
|
||||
- `"OAuth2 provider 'xxx' not found"` - Provider not configured
|
||||
- `"failed to refresh token with provider"` - Provider rejected refresh request
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Always use HTTPS** for token transmission
|
||||
2. **Store refresh tokens securely** on client
|
||||
3. **Set appropriate cookie flags**: `HttpOnly`, `Secure`, `SameSite`
|
||||
4. **Implement token rotation** - issue new refresh token on each refresh
|
||||
5. **Revoke old tokens** after successful refresh
|
||||
6. **Rate limit** refresh endpoints
|
||||
7. **Log refresh attempts** for audit trail
|
||||
|
||||
---
|
||||
|
||||
## 6. Testing
|
||||
|
||||
### Manual Test Flow
|
||||
|
||||
1. **Initial Login:**
|
||||
```bash
|
||||
curl http://localhost:8080/auth/google/login
|
||||
# Follow redirect to Google
|
||||
# Returns to callback with LoginResponse containing refresh_token
|
||||
```
|
||||
|
||||
2. **Wait for Token Expiry (or manually expire in DB)**
|
||||
|
||||
3. **Refresh Token:**
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/auth/refresh \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"refresh_token": "ya29.a0AfH6SMB...",
|
||||
"provider": "google"
|
||||
}'
|
||||
|
||||
# Response:
|
||||
{
|
||||
"token": "sess_abc123...",
|
||||
"refresh_token": "ya29.a0AfH6SMB...",
|
||||
"user": {
|
||||
"user_id": 1,
|
||||
"user_name": "john_doe",
|
||||
"email": "john@example.com",
|
||||
"session_id": "sess_abc123..."
|
||||
},
|
||||
"expires_in": 3600
|
||||
}
|
||||
```
|
||||
|
||||
4. **Use New Token:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/protected \
|
||||
-H "Authorization: Bearer sess_abc123..."
|
||||
```
|
||||
|
||||
### Database Verification
|
||||
|
||||
```sql
|
||||
-- Check session with refresh token
|
||||
SELECT session_token, user_id, expires_at, refresh_token, auth_provider
|
||||
FROM user_sessions
|
||||
WHERE refresh_token = 'ya29.a0AfH6SMB...';
|
||||
|
||||
-- Verify token was updated after refresh
|
||||
SELECT session_token, access_token, refresh_token,
|
||||
expires_at, last_activity_at
|
||||
FROM user_sessions
|
||||
WHERE user_id = 1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Troubleshooting
|
||||
|
||||
### "Refresh token not found or expired"
|
||||
|
||||
**Cause:** Refresh token doesn't exist in database or session expired
|
||||
|
||||
**Solution:**
|
||||
- Check if initial OAuth2 login stored refresh token
|
||||
- Verify provider returns refresh token (some require `access_type=offline`)
|
||||
- Check session hasn't been deleted from database
|
||||
|
||||
### "Failed to refresh token with provider"
|
||||
|
||||
**Cause:** OAuth2 provider rejected the refresh request
|
||||
|
||||
**Possible reasons:**
|
||||
- Refresh token was revoked by user
|
||||
- OAuth2 app credentials changed
|
||||
- Network connectivity issues
|
||||
- Provider rate limiting
|
||||
|
||||
**Solution:**
|
||||
- Re-authenticate user (full OAuth2 flow)
|
||||
- Check provider dashboard for app status
|
||||
- Verify client credentials are correct
|
||||
|
||||
### "OAuth2 provider 'xxx' not found"
|
||||
|
||||
**Cause:** Provider not registered with `WithOAuth2()`
|
||||
|
||||
**Solution:**
|
||||
```go
|
||||
// Make sure provider is configured
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ProviderName: "google", // This name must match refresh call
|
||||
// ... other config
|
||||
})
|
||||
|
||||
// Then use same name in refresh
|
||||
auth.OAuth2RefreshToken(ctx, token, "google") // Must match ProviderName
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Complete Working Example
|
||||
|
||||
See `pkg/security/oauth2_examples.go:250` for full working example with token refresh.
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
OAuth2 refresh token functionality is **production-ready** with:
|
||||
|
||||
- ✅ Complete database schema with stored procedures
|
||||
- ✅ Thread-safe Go implementation with mutex protection
|
||||
- ✅ Multi-provider support (Google, GitHub, Microsoft, Facebook, custom)
|
||||
- ✅ Comprehensive error handling
|
||||
- ✅ Working code examples
|
||||
- ✅ Full API documentation
|
||||
- ✅ Security best practices implemented
|
||||
|
||||
**No additional implementation needed - feature is complete and functional.**
|
||||
208
pkg/security/PASSKEY_QUICK_REFERENCE.md
Normal file
208
pkg/security/PASSKEY_QUICK_REFERENCE.md
Normal file
@@ -0,0 +1,208 @@
|
||||
# Passkey Authentication Quick Reference
|
||||
|
||||
## Overview
|
||||
Passkey authentication (WebAuthn/FIDO2) is now integrated into the DatabaseAuthenticator. This provides passwordless authentication using biometrics, security keys, or device credentials.
|
||||
|
||||
## Setup
|
||||
|
||||
### Database Schema
|
||||
Run the passkey SQL schema (in database_schema.sql):
|
||||
- Creates `user_passkey_credentials` table
|
||||
- Adds stored procedures for passkey operations
|
||||
|
||||
### Go Code
|
||||
```go
|
||||
// Create passkey provider
|
||||
passkeyProvider := security.NewDatabasePasskeyProvider(db,
|
||||
security.DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
RPOrigin: "https://example.com",
|
||||
Timeout: 60000,
|
||||
})
|
||||
|
||||
// Create authenticator with passkey support
|
||||
auth := security.NewDatabaseAuthenticatorWithOptions(db,
|
||||
security.DatabaseAuthenticatorOptions{
|
||||
PasskeyProvider: passkeyProvider,
|
||||
})
|
||||
|
||||
// Or add passkey to existing authenticator
|
||||
auth = security.NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
|
||||
```
|
||||
|
||||
## Registration Flow
|
||||
|
||||
### Backend - Step 1: Begin Registration
|
||||
```go
|
||||
options, err := auth.BeginPasskeyRegistration(ctx,
|
||||
security.PasskeyBeginRegistrationRequest{
|
||||
UserID: 1,
|
||||
Username: "alice",
|
||||
DisplayName: "Alice Smith",
|
||||
})
|
||||
// Send options to client as JSON
|
||||
```
|
||||
|
||||
### Frontend - Step 2: Create Credential
|
||||
```javascript
|
||||
// Convert options from server
|
||||
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||
options.user.id = base64ToArrayBuffer(options.user.id);
|
||||
|
||||
// Create credential
|
||||
const credential = await navigator.credentials.create({
|
||||
publicKey: options
|
||||
});
|
||||
|
||||
// Send credential back to server
|
||||
```
|
||||
|
||||
### Backend - Step 3: Complete Registration
|
||||
```go
|
||||
credential, err := auth.CompletePasskeyRegistration(ctx,
|
||||
security.PasskeyRegisterRequest{
|
||||
UserID: 1,
|
||||
Response: clientResponse,
|
||||
ExpectedChallenge: storedChallenge,
|
||||
CredentialName: "My iPhone",
|
||||
})
|
||||
```
|
||||
|
||||
## Authentication Flow
|
||||
|
||||
### Backend - Step 1: Begin Authentication
|
||||
```go
|
||||
options, err := auth.BeginPasskeyAuthentication(ctx,
|
||||
security.PasskeyBeginAuthenticationRequest{
|
||||
Username: "alice", // Optional for resident key
|
||||
})
|
||||
// Send options to client as JSON
|
||||
```
|
||||
|
||||
### Frontend - Step 2: Get Credential
|
||||
```javascript
|
||||
// Convert options from server
|
||||
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||
|
||||
// Get credential
|
||||
const credential = await navigator.credentials.get({
|
||||
publicKey: options
|
||||
});
|
||||
|
||||
// Send assertion back to server
|
||||
```
|
||||
|
||||
### Backend - Step 3: Complete Authentication
|
||||
```go
|
||||
loginResponse, err := auth.LoginWithPasskey(ctx,
|
||||
security.PasskeyLoginRequest{
|
||||
Response: clientAssertion,
|
||||
ExpectedChallenge: storedChallenge,
|
||||
Claims: map[string]any{
|
||||
"ip_address": "192.168.1.1",
|
||||
"user_agent": "Mozilla/5.0...",
|
||||
},
|
||||
})
|
||||
// Returns session token and user info
|
||||
```
|
||||
|
||||
## Credential Management
|
||||
|
||||
### List Credentials
|
||||
```go
|
||||
credentials, err := auth.GetPasskeyCredentials(ctx, userID)
|
||||
```
|
||||
|
||||
### Update Credential Name
|
||||
```go
|
||||
err := auth.UpdatePasskeyCredentialName(ctx, userID, credentialID, "New Name")
|
||||
```
|
||||
|
||||
### Delete Credential
|
||||
```go
|
||||
err := auth.DeletePasskeyCredential(ctx, userID, credentialID)
|
||||
```
|
||||
|
||||
## HTTP Endpoints Example
|
||||
|
||||
### POST /api/passkey/register/begin
|
||||
Request: `{user_id, username, display_name}`
|
||||
Response: PasskeyRegistrationOptions
|
||||
|
||||
### POST /api/passkey/register/complete
|
||||
Request: `{user_id, response, credential_name}`
|
||||
Response: PasskeyCredential
|
||||
|
||||
### POST /api/passkey/login/begin
|
||||
Request: `{username}` (optional)
|
||||
Response: PasskeyAuthenticationOptions
|
||||
|
||||
### POST /api/passkey/login/complete
|
||||
Request: `{response}`
|
||||
Response: LoginResponse with session token
|
||||
|
||||
### GET /api/passkey/credentials
|
||||
Response: Array of PasskeyCredential
|
||||
|
||||
### DELETE /api/passkey/credentials/{id}
|
||||
Request: `{credential_id}`
|
||||
Response: 204 No Content
|
||||
|
||||
## Database Stored Procedures
|
||||
|
||||
- `resolvespec_passkey_store_credential` - Store new credential
|
||||
- `resolvespec_passkey_get_credential` - Get credential by ID
|
||||
- `resolvespec_passkey_get_user_credentials` - Get all user credentials
|
||||
- `resolvespec_passkey_update_counter` - Update sign counter (clone detection)
|
||||
- `resolvespec_passkey_delete_credential` - Delete credential
|
||||
- `resolvespec_passkey_update_name` - Update credential name
|
||||
- `resolvespec_passkey_get_credentials_by_username` - Get credentials for login
|
||||
|
||||
## Security Features
|
||||
|
||||
- **Clone Detection**: Sign counter validation detects credential cloning
|
||||
- **Attestation Support**: Stores attestation type (none, indirect, direct)
|
||||
- **Transport Options**: Tracks authenticator transports (usb, nfc, ble, internal)
|
||||
- **Backup State**: Tracks if credential is backed up/synced
|
||||
- **User Verification**: Supports preferred/required user verification
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **WebAuthn Library**: Current implementation is simplified. For production, use a proper WebAuthn library like `github.com/go-webauthn/webauthn` for full verification.
|
||||
|
||||
2. **Challenge Storage**: Store challenges securely in session/cache. Never expose challenges to client beyond initial request.
|
||||
|
||||
3. **HTTPS Required**: Passkeys only work over HTTPS (except localhost).
|
||||
|
||||
4. **Browser Support**: Check browser compatibility for WebAuthn API.
|
||||
|
||||
5. **Relying Party ID**: Must match your domain exactly.
|
||||
|
||||
## Client-Side Helper Functions
|
||||
|
||||
```javascript
|
||||
function base64ToArrayBuffer(base64) {
|
||||
const binary = atob(base64);
|
||||
const bytes = new Uint8Array(binary.length);
|
||||
for (let i = 0; i < binary.length; i++) {
|
||||
bytes[i] = binary.charCodeAt(i);
|
||||
}
|
||||
return bytes.buffer;
|
||||
}
|
||||
|
||||
function arrayBufferToBase64(buffer) {
|
||||
const bytes = new Uint8Array(buffer);
|
||||
let binary = '';
|
||||
for (let i = 0; i < bytes.length; i++) {
|
||||
binary += String.fromCharCode(bytes[i]);
|
||||
}
|
||||
return btoa(binary);
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run tests: `go test -v ./pkg/security -run Passkey`
|
||||
|
||||
All passkey functionality includes comprehensive tests using sqlmock.
|
||||
@@ -7,15 +7,16 @@
|
||||
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
|
||||
// OR: auth := security.NewJWTAuthenticator("secret-key", db)
|
||||
// OR: auth := security.NewHeaderAuthenticator()
|
||||
// OR: auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db) // OAuth2
|
||||
|
||||
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||
|
||||
// Step 2: Combine providers
|
||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
|
||||
// Step 3: Setup and apply middleware
|
||||
securityList := security.SetupSecurityProvider(handler, provider)
|
||||
securityList, _ := security.SetupSecurityProvider(handler, provider)
|
||||
router.Use(security.NewAuthMiddleware(securityList))
|
||||
router.Use(security.SetSecurityMiddleware(securityList))
|
||||
```
|
||||
@@ -30,6 +31,7 @@ router.Use(security.SetSecurityMiddleware(securityList))
|
||||
```go
|
||||
// DatabaseAuthenticator uses these stored procedures:
|
||||
resolvespec_login(jsonb) // Login with credentials
|
||||
resolvespec_register(jsonb) // Register new user
|
||||
resolvespec_logout(jsonb) // Invalidate session
|
||||
resolvespec_session(text, text) // Validate session token
|
||||
resolvespec_session_update(text, jsonb) // Update activity timestamp
|
||||
@@ -502,10 +504,31 @@ func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema,
|
||||
|
||||
---
|
||||
|
||||
## Login/Logout Endpoints
|
||||
## Login/Logout/Register Endpoints
|
||||
|
||||
```go
|
||||
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
|
||||
// Register
|
||||
router.HandleFunc("/auth/register", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req security.RegisterRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Check if provider supports registration
|
||||
registrable, ok := securityList.Provider().(security.Registrable)
|
||||
if !ok {
|
||||
http.Error(w, "Registration not supported", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := registrable.Register(r.Context(), req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}).Methods("POST")
|
||||
|
||||
// Login
|
||||
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req security.LoginRequest
|
||||
@@ -707,6 +730,7 @@ meta, ok := security.GetUserMeta(ctx)
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
|
||||
| `OAUTH2.md` | **OAuth2 Guide** - Google, GitHub, Microsoft, Facebook, custom providers |
|
||||
| `examples.go` | Working provider implementations to copy |
|
||||
| `setup_example.go` | 6 complete integration examples |
|
||||
| `README.md` | Architecture overview and migration guide |
|
||||
|
||||
@@ -6,6 +6,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
|
||||
- ✅ **Interface-Based** - Type-safe providers instead of callbacks
|
||||
- ✅ **Login/Logout Support** - Built-in authentication lifecycle
|
||||
- ✅ **Two-Factor Authentication (2FA)** - Optional TOTP support for enhanced security
|
||||
- ✅ **Composable** - Mix and match different providers
|
||||
- ✅ **No Global State** - Each handler has its own security configuration
|
||||
- ✅ **Testable** - Easy to mock and test
|
||||
@@ -212,6 +213,23 @@ auth := security.NewJWTAuthenticator("secret-key", db)
|
||||
// Note: Requires JWT library installation for token signing/verification
|
||||
```
|
||||
|
||||
**TwoFactorAuthenticator** - Wraps any authenticator with TOTP 2FA:
|
||||
```go
|
||||
baseAuth := security.NewDatabaseAuthenticator(db)
|
||||
|
||||
// Use in-memory provider (for testing)
|
||||
tfaProvider := security.NewMemoryTwoFactorProvider(nil)
|
||||
|
||||
// Or use database provider (for production)
|
||||
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
// Requires: users table with totp fields, user_totp_backup_codes table
|
||||
// Requires: resolvespec_totp_* stored procedures (see totp_database_schema.sql)
|
||||
|
||||
auth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||
// Supports: TOTP codes, backup codes, QR code generation
|
||||
// Compatible with Google Authenticator, Microsoft Authenticator, Authy, etc.
|
||||
```
|
||||
|
||||
### Column Security Providers
|
||||
|
||||
**DatabaseColumnSecurityProvider** - Loads rules from database:
|
||||
@@ -334,7 +352,182 @@ func handleRefresh(securityList *security.SecurityList) http.HandlerFunc {
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Two-Factor Authentication (2FA)
|
||||
|
||||
### Overview
|
||||
|
||||
- **Optional per-user** - Enable/disable 2FA individually
|
||||
- **TOTP standard** - Compatible with Google Authenticator, Microsoft Authenticator, Authy, 1Password, etc.
|
||||
- **Configurable** - SHA1/SHA256/SHA512, 6/8 digits, custom time periods
|
||||
- **Backup codes** - One-time recovery codes with secure hashing
|
||||
- **Clock skew** - Handles time differences between client/server
|
||||
|
||||
### Setup
|
||||
|
||||
```go
|
||||
// 1. Wrap existing authenticator with 2FA support
|
||||
baseAuth := security.NewDatabaseAuthenticator(db)
|
||||
tfaProvider := security.NewMemoryTwoFactorProvider(nil) // Use custom DB implementation in production
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||
|
||||
// 2. Use as normal authenticator
|
||||
provider := security.NewCompositeSecurityProvider(tfaAuth, colSec, rowSec)
|
||||
securityList := security.NewSecurityList(provider)
|
||||
```
|
||||
|
||||
### Enable 2FA for User
|
||||
|
||||
```go
|
||||
// 1. Initiate 2FA setup
|
||||
secret, err := tfaAuth.Setup2FA(userID, "MyApp", "user@example.com")
|
||||
// Returns: secret.Secret, secret.QRCodeURL, secret.BackupCodes
|
||||
|
||||
// 2. User scans QR code with authenticator app
|
||||
// Display secret.QRCodeURL as QR code image
|
||||
|
||||
// 3. User enters verification code from app
|
||||
code := "123456" // From authenticator app
|
||||
err = tfaAuth.Enable2FA(userID, secret.Secret, code)
|
||||
// 2FA is now enabled for this user
|
||||
|
||||
// 4. Store backup codes securely and show to user once
|
||||
// Display: secret.BackupCodes (10 codes)
|
||||
```
|
||||
|
||||
### Login Flow with 2FA
|
||||
|
||||
```go
|
||||
// 1. User provides credentials
|
||||
req := security.LoginRequest{
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(ctx, req)
|
||||
|
||||
// 2. Check if 2FA required
|
||||
if resp.Requires2FA {
|
||||
// Prompt user for 2FA code
|
||||
code := getUserInput() // From authenticator app or backup code
|
||||
|
||||
// 3. Login again with 2FA code
|
||||
req.TwoFactorCode = code
|
||||
resp, err = tfaAuth.Login(ctx, req)
|
||||
|
||||
// 4. Success - token is returned
|
||||
token := resp.Token
|
||||
}
|
||||
```
|
||||
|
||||
### Manage 2FA
|
||||
|
||||
```go
|
||||
// Disable 2FA
|
||||
err := tfaAuth.Disable2FA(userID)
|
||||
|
||||
// Regenerate backup codes
|
||||
newCodes, err := tfaAuth.RegenerateBackupCodes(userID, 10)
|
||||
|
||||
// Check status
|
||||
has2FA, err := tfaProvider.Get2FAStatus(userID)
|
||||
```
|
||||
|
||||
### Custom 2FA Storage
|
||||
|
||||
**Option 1: Use DatabaseTwoFactorProvider (Recommended)**
|
||||
|
||||
```go
|
||||
// Uses PostgreSQL stored procedures for all operations
|
||||
db := setupDatabase()
|
||||
|
||||
// Run migrations from totp_database_schema.sql
|
||||
// - Add totp_secret, totp_enabled, totp_enabled_at to users table
|
||||
// - Create user_totp_backup_codes table
|
||||
// - Create resolvespec_totp_* stored procedures
|
||||
|
||||
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||
```
|
||||
|
||||
**Option 2: Implement Custom Provider**
|
||||
|
||||
Implement `TwoFactorAuthProvider` for custom storage:
|
||||
|
||||
```go
|
||||
type DBTwoFactorProvider struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func (p *DBTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
|
||||
// Store secret and hashed backup codes in database
|
||||
return p.db.Exec("UPDATE users SET totp_secret = ?, backup_codes = ? WHERE id = ?",
|
||||
secret, hashCodes(backupCodes), userID).Error
|
||||
}
|
||||
|
||||
func (p *DBTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
var secret string
|
||||
err := p.db.Raw("SELECT totp_secret FROM users WHERE id = ?", userID).Scan(&secret).Error
|
||||
return secret, err
|
||||
}
|
||||
|
||||
// Implement remaining methods: Generate2FASecret, Validate2FACode, Disable2FA,
|
||||
// Get2FAStatus, GenerateBackupCodes, ValidateBackupCode
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```go
|
||||
config := &security.TwoFactorConfig{
|
||||
Algorithm: "SHA256", // SHA1, SHA256, SHA512
|
||||
Digits: 8, // 6 or 8
|
||||
Period: 30, // Seconds per code
|
||||
SkewWindow: 2, // Accept codes ±2 periods
|
||||
}
|
||||
|
||||
totp := security.NewTOTPGenerator(config)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, config)
|
||||
```
|
||||
|
||||
### API Response Structure
|
||||
|
||||
```go
|
||||
// LoginResponse with 2FA
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
Requires2FA bool `json:"requires_2fa"`
|
||||
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"`
|
||||
User *UserContext `json:"user"`
|
||||
}
|
||||
|
||||
// TwoFactorSecret for setup
|
||||
type TwoFactorSecret struct {
|
||||
Secret string `json:"secret"` // Base32 encoded
|
||||
QRCodeURL string `json:"qr_code_url"` // otpauth://totp/...
|
||||
BackupCodes []string `json:"backup_codes"` // 10 recovery codes
|
||||
}
|
||||
|
||||
// UserContext includes 2FA status
|
||||
type UserContext struct {
|
||||
UserID int `json:"user_id"`
|
||||
TwoFactorEnabled bool `json:"two_factor_enabled"`
|
||||
// ... other fields
|
||||
}
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
- **Store secrets encrypted** - Never store TOTP secrets in plain text
|
||||
- **Hash backup codes** - Use SHA-256 before storing
|
||||
- **Rate limit** - Limit 2FA verification attempts
|
||||
- **Require password** - Always verify password before disabling 2FA
|
||||
- **Show backup codes once** - Display only during setup/regeneration
|
||||
- **Log 2FA events** - Track enable/disable/failed attempts
|
||||
- **Mark codes as used** - Backup codes are single-use only
|
||||
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
http.Error(w, "Refresh not supported", http.StatusNotImplemented)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,33 +7,48 @@ import (
|
||||
|
||||
// UserContext holds authenticated user information
|
||||
type UserContext struct {
|
||||
UserID int `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserLevel int `json:"user_level"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionRID int64 `json:"session_rid"`
|
||||
RemoteID string `json:"remote_id"`
|
||||
Roles []string `json:"roles"`
|
||||
Email string `json:"email"`
|
||||
Claims map[string]any `json:"claims"`
|
||||
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
||||
UserID int `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserLevel int `json:"user_level"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionRID int64 `json:"session_rid"`
|
||||
RemoteID string `json:"remote_id"`
|
||||
Roles []string `json:"roles"`
|
||||
Email string `json:"email"`
|
||||
Claims map[string]any `json:"claims"`
|
||||
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
||||
TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
|
||||
}
|
||||
|
||||
// LoginRequest contains credentials for login
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Claims map[string]any `json:"claims"` // Additional login data
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
TwoFactorCode string `json:"two_factor_code,omitempty"` // TOTP or backup code
|
||||
Claims map[string]any `json:"claims"` // Additional login data
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
}
|
||||
|
||||
// RegisterRequest contains information for new user registration
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Email string `json:"email"`
|
||||
UserLevel int `json:"user_level"`
|
||||
Roles []string `json:"roles"`
|
||||
Claims map[string]any `json:"claims"` // Additional registration data
|
||||
Meta map[string]any `json:"meta"` // Additional metadata
|
||||
}
|
||||
|
||||
// LoginResponse contains the result of a login attempt
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserContext `json:"user"`
|
||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserContext `json:"user"`
|
||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||
Requires2FA bool `json:"requires_2fa"` // True if 2FA code is required
|
||||
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"` // Present when setting up 2FA
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
}
|
||||
|
||||
// LogoutRequest contains information for logout
|
||||
@@ -55,6 +70,12 @@ type Authenticator interface {
|
||||
Authenticate(r *http.Request) (*UserContext, error)
|
||||
}
|
||||
|
||||
// Registrable allows providers to support user registration
|
||||
type Registrable interface {
|
||||
// Register creates a new user account
|
||||
Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error)
|
||||
}
|
||||
|
||||
// ColumnSecurityProvider handles column-level security (masking/hiding)
|
||||
type ColumnSecurityProvider interface {
|
||||
// GetColumnSecurity loads column security rules for a user and entity
|
||||
|
||||
615
pkg/security/oauth2_examples.go
Normal file
615
pkg/security/oauth2_examples.go
Normal file
@@ -0,0 +1,615 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// Example: OAuth2 Authentication with Google
|
||||
func ExampleOAuth2Google() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create OAuth2 authenticator for Google
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Login endpoint - redirects to Google
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Callback endpoint - handles Google response
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
// Return user info as JSON
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 Authentication with GitHub
|
||||
func ExampleOAuth2GitHub() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGitHubAuthenticator(
|
||||
"your-github-client-id",
|
||||
"your-github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: Custom OAuth2 Provider
|
||||
func ExampleOAuth2Custom() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Custom OAuth2 provider configuration
|
||||
oauth2Auth := NewDatabaseAuthenticator(db).WithOAuth2(OAuth2Config{
|
||||
ClientID: "your-client-id",
|
||||
ClientSecret: "your-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://your-provider.com/oauth/authorize",
|
||||
TokenURL: "https://your-provider.com/oauth/token",
|
||||
UserInfoURL: "https://your-provider.com/oauth/userinfo",
|
||||
ProviderName: "custom-provider",
|
||||
|
||||
// Custom user info parser
|
||||
UserInfoParser: func(userInfo map[string]any) (*UserContext, error) {
|
||||
// Extract custom fields from your provider
|
||||
return &UserContext{
|
||||
UserName: userInfo["username"].(string),
|
||||
Email: userInfo["email"].(string),
|
||||
RemoteID: userInfo["id"].(string),
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
Claims: userInfo,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("custom-provider", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "custom-provider", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: Multi-Provider OAuth2 with Security Integration
|
||||
func ExampleOAuth2MultiProvider() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create OAuth2 authenticators for multiple providers
|
||||
googleAuth := NewGoogleAuthenticator(
|
||||
"google-client-id",
|
||||
"google-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
githubAuth := NewGitHubAuthenticator(
|
||||
"github-client-id",
|
||||
"github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Create column and row security providers
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Google OAuth2 routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := googleAuth.OAuth2GenerateState()
|
||||
authURL, _ := googleAuth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := googleAuth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// GitHub OAuth2 routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := githubAuth.OAuth2GenerateState()
|
||||
authURL, _ := githubAuth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := githubAuth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Use Google auth for protected routes (or GitHub - both work)
|
||||
provider, _ := NewCompositeSecurityProvider(googleAuth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
// Protected route with authentication
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 with Token Refresh
|
||||
func ExampleOAuth2TokenRefresh() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Refresh token endpoint
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google", "github", etc.
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Default to google if not specified
|
||||
if req.Provider == "" {
|
||||
req.Provider = "google"
|
||||
}
|
||||
|
||||
// Use OAuth2-specific refresh method
|
||||
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 Logout
|
||||
func ExampleOAuth2Logout() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("Authorization")
|
||||
if token == "" {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err == nil {
|
||||
token = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
// Get user ID from session
|
||||
userCtx, err := oauth2Auth.Authenticate(r)
|
||||
if err == nil {
|
||||
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||
Token: token,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Clear cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Logged out successfully"))
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: Complete OAuth2 Integration with Database Setup
|
||||
func ExampleOAuth2Complete() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create tables (run once)
|
||||
setupOAuth2Tables(db)
|
||||
|
||||
// Create OAuth2 authenticator
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Create security providers
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Public routes
|
||||
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("Welcome! <a href='/auth/google/login'>Login with Google</a>"))
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Protected routes
|
||||
protectedRouter := router.PathPrefix("/").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/dashboard", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_, _ = fmt.Fprintf(w, "Welcome, %s! Your email: %s", userCtx.UserName, userCtx.Email)
|
||||
})
|
||||
|
||||
protectedRouter.HandleFunc("/api/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
protectedRouter.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||
Token: userCtx.SessionID,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
func setupOAuth2Tables(db *sql.DB) {
|
||||
// Create tables from database_schema.sql
|
||||
// This is a helper function - in production, use migrations
|
||||
ctx := context.Background()
|
||||
|
||||
// Create users table if not exists
|
||||
_, _ = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password VARCHAR(255),
|
||||
user_level INTEGER DEFAULT 0,
|
||||
roles VARCHAR(500),
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
remote_id VARCHAR(255),
|
||||
auth_provider VARCHAR(50)
|
||||
)
|
||||
`)
|
||||
|
||||
// Create user_sessions table (used for both regular and OAuth2 sessions)
|
||||
_, _ = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
token_type VARCHAR(50) DEFAULT 'Bearer',
|
||||
auth_provider VARCHAR(50)
|
||||
)
|
||||
`)
|
||||
}
|
||||
|
||||
// Example: All OAuth2 Providers at Once
|
||||
func ExampleOAuth2AllProviders() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create authenticator with ALL OAuth2 providers
|
||||
auth := NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "microsoft-client-id",
|
||||
ClientSecret: "microsoft-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/microsoft/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||
ProviderName: "microsoft",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "facebook-client-id",
|
||||
ClientSecret: "facebook-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/facebook/callback",
|
||||
Scopes: []string{"email"},
|
||||
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
|
||||
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
|
||||
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
|
||||
ProviderName: "facebook",
|
||||
})
|
||||
|
||||
// Get list of configured providers
|
||||
providers := auth.OAuth2GetProviders()
|
||||
fmt.Printf("Configured OAuth2 providers: %v\n", providers)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Google routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// GitHub routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Microsoft routes
|
||||
router.HandleFunc("/auth/microsoft/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("microsoft", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/microsoft/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "microsoft", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Facebook routes
|
||||
router.HandleFunc("/auth/facebook/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("facebook", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/facebook/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "facebook", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Create security list for protected routes
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
// Protected routes work for ALL OAuth2 providers + regular sessions
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
579
pkg/security/oauth2_methods.go
Normal file
579
pkg/security/oauth2_methods.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth2Config contains configuration for OAuth2 authentication
|
||||
type OAuth2Config struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
RedirectURL string
|
||||
Scopes []string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
ProviderName string
|
||||
|
||||
// Optional: Custom user info parser
|
||||
// If not provided, will use standard claims (sub, email, name)
|
||||
UserInfoParser func(userInfo map[string]any) (*UserContext, error)
|
||||
}
|
||||
|
||||
// OAuth2Provider holds configuration and state for a single OAuth2 provider
|
||||
type OAuth2Provider struct {
|
||||
config *oauth2.Config
|
||||
userInfoURL string
|
||||
userInfoParser func(userInfo map[string]any) (*UserContext, error)
|
||||
providerName string
|
||||
states map[string]time.Time // state -> expiry time
|
||||
statesMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// WithOAuth2 configures OAuth2 support for the DatabaseAuthenticator.
// Can be called multiple times to add multiple OAuth2 providers.
// Returns the same DatabaseAuthenticator instance for method chaining.
//
// NOTE(review): each call starts a cleanupStates goroutine with no stop
// channel, so it runs for the life of the process; re-registering an existing
// ProviderName replaces the map entry but the old provider's goroutine keeps
// running. Confirm this is acceptable for the intended (startup-only) usage.
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) *DatabaseAuthenticator {
	// Default registry key for single-provider setups.
	if cfg.ProviderName == "" {
		cfg.ProviderName = "oauth2"
	}

	// Fall back to the standard-claims parser (sub, email, name, login).
	if cfg.UserInfoParser == nil {
		cfg.UserInfoParser = defaultOAuth2UserInfoParser
	}

	provider := &OAuth2Provider{
		config: &oauth2.Config{
			ClientID:     cfg.ClientID,
			ClientSecret: cfg.ClientSecret,
			RedirectURL:  cfg.RedirectURL,
			Scopes:       cfg.Scopes,
			Endpoint: oauth2.Endpoint{
				AuthURL:  cfg.AuthURL,
				TokenURL: cfg.TokenURL,
			},
		},
		userInfoURL:    cfg.UserInfoURL,
		userInfoParser: cfg.UserInfoParser,
		providerName:   cfg.ProviderName,
		states:         make(map[string]time.Time),
	}

	// Initialize providers map if needed (lazy — the zero value of
	// DatabaseAuthenticator has a nil map).
	a.oauth2ProvidersMutex.Lock()
	if a.oauth2Providers == nil {
		a.oauth2Providers = make(map[string]*OAuth2Provider)
	}

	// Register provider
	a.oauth2Providers[cfg.ProviderName] = provider
	a.oauth2ProvidersMutex.Unlock()

	// Start state cleanup goroutine for this provider
	go provider.cleanupStates()

	return a
}
|
||||
|
||||
// OAuth2GetAuthURL returns the OAuth2 authorization URL for redirecting users
|
||||
func (a *DatabaseAuthenticator) OAuth2GetAuthURL(providerName, state string) (string, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Store state for validation
|
||||
provider.statesMutex.Lock()
|
||||
provider.states[state] = time.Now().Add(10 * time.Minute)
|
||||
provider.statesMutex.Unlock()
|
||||
|
||||
return provider.config.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// OAuth2GenerateState generates a random state string for CSRF protection
|
||||
func (a *DatabaseAuthenticator) OAuth2GenerateState() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// OAuth2HandleCallback handles the OAuth2 callback and exchanges code for token
|
||||
func (a *DatabaseAuthenticator) OAuth2HandleCallback(ctx context.Context, providerName, code, state string) (*LoginResponse, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate state
|
||||
if !provider.validateState(state) {
|
||||
return nil, fmt.Errorf("invalid state parameter")
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := provider.config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Fetch user info
|
||||
client := provider.config.Client(ctx, token)
|
||||
resp, err := client.Get(provider.userInfoURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read user info: %w", err)
|
||||
}
|
||||
|
||||
var userInfo map[string]any
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info: %w", err)
|
||||
}
|
||||
|
||||
// Parse user info
|
||||
userCtx, err := provider.userInfoParser(userInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||
}
|
||||
|
||||
// Get or create user in database
|
||||
userID, err := a.oauth2GetOrCreateUser(ctx, userCtx, providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get or create user: %w", err)
|
||||
}
|
||||
userCtx.UserID = userID
|
||||
|
||||
// Create session token
|
||||
sessionToken, err := a.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session token: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
if token.Expiry.After(time.Now()) {
|
||||
expiresAt = token.Expiry
|
||||
}
|
||||
|
||||
// Store session in database
|
||||
err = a.oauth2CreateSession(ctx, sessionToken, userCtx.UserID, token, expiresAt, providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
userCtx.SessionID = sessionToken
|
||||
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
User: userCtx,
|
||||
ExpiresIn: int64(time.Until(expiresAt).Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OAuth2GetProviders returns list of configured OAuth2 provider names
|
||||
func (a *DatabaseAuthenticator) OAuth2GetProviders() []string {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
|
||||
if a.oauth2Providers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers := make([]string, 0, len(a.oauth2Providers))
|
||||
for name := range a.oauth2Providers {
|
||||
providers = append(providers, name)
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// getOAuth2Provider retrieves a registered OAuth2 provider by name
|
||||
func (a *DatabaseAuthenticator) getOAuth2Provider(providerName string) (*OAuth2Provider, error) {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
|
||||
if a.oauth2Providers == nil {
|
||||
return nil, fmt.Errorf("OAuth2 not configured - call WithOAuth2() first")
|
||||
}
|
||||
|
||||
provider, ok := a.oauth2Providers[providerName]
|
||||
if !ok {
|
||||
// Build provider list without calling OAuth2GetProviders to avoid recursion
|
||||
providerNames := make([]string, 0, len(a.oauth2Providers))
|
||||
for name := range a.oauth2Providers {
|
||||
providerNames = append(providerNames, name)
|
||||
}
|
||||
return nil, fmt.Errorf("OAuth2 provider '%s' not found - available providers: %v", providerName, providerNames)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// oauth2GetOrCreateUser finds or creates a user based on OAuth2 info using the
// resolvespec_oauth_getorcreateuser stored procedure.
//
// The user's identifying fields plus the provider name are packed into a JSONB
// payload; the procedure returns (p_success, p_error, p_user_id). Returns the
// resolved user ID, or an error when the query fails, the procedure reports
// failure, or no ID is returned.
func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userCtx *UserContext, providerName string) (int, error) {
	userData := map[string]interface{}{
		"username":      userCtx.UserName,
		"email":         userCtx.Email,
		"remote_id":     userCtx.RemoteID,
		"user_level":    userCtx.UserLevel,
		"roles":         userCtx.Roles,
		"auth_provider": providerName,
	}

	userJSON, err := json.Marshal(userData)
	if err != nil {
		return 0, fmt.Errorf("failed to marshal user data: %w", err)
	}

	// Out parameters of the stored procedure; p_error and p_user_id are
	// nullable, hence the pointer scan targets.
	var success bool
	var errMsg *string
	var userID *int

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error, p_user_id
		FROM resolvespec_oauth_getorcreateuser($1::jsonb)
	`, userJSON).Scan(&success, &errMsg, &userID)

	if err != nil {
		return 0, fmt.Errorf("failed to get or create user: %w", err)
	}

	if !success {
		// Prefer the procedure's own error text when it supplied one.
		if errMsg != nil {
			return 0, fmt.Errorf("%s", *errMsg)
		}
		return 0, fmt.Errorf("failed to get or create user")
	}

	if userID == nil {
		return 0, fmt.Errorf("user ID not returned")
	}

	return *userID, nil
}
|
||||
|
||||
// oauth2CreateSession creates a new OAuth2 session using the
// resolvespec_oauth_createsession stored procedure.
//
// The session token, user ID, provider tokens and expiry are packed into a
// JSONB payload; the procedure returns (p_success, p_error). Returns nil on
// success, otherwise an error (preferring the procedure's own message).
func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, sessionToken string, userID int, token *oauth2.Token, expiresAt time.Time, providerName string) error {
	sessionData := map[string]interface{}{
		"session_token": sessionToken,
		"user_id":       userID,
		"access_token":  token.AccessToken,
		"refresh_token": token.RefreshToken,
		"token_type":    token.TokenType,
		"expires_at":    expiresAt,
		"auth_provider": providerName,
	}

	sessionJSON, err := json.Marshal(sessionData)
	if err != nil {
		return fmt.Errorf("failed to marshal session data: %w", err)
	}

	// p_error is nullable, hence the pointer scan target.
	var success bool
	var errMsg *string

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error
		FROM resolvespec_oauth_createsession($1::jsonb)
	`, sessionJSON).Scan(&success, &errMsg)

	if err != nil {
		return fmt.Errorf("failed to create session: %w", err)
	}

	if !success {
		if errMsg != nil {
			return fmt.Errorf("%s", *errMsg)
		}
		return fmt.Errorf("failed to create session")
	}

	return nil
}
|
||||
|
||||
// validateState validates state using in-memory storage
|
||||
func (p *OAuth2Provider) validateState(state string) bool {
|
||||
p.statesMutex.Lock()
|
||||
defer p.statesMutex.Unlock()
|
||||
|
||||
expiry, ok := p.states[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(expiry) {
|
||||
delete(p.states, state)
|
||||
return false
|
||||
}
|
||||
|
||||
delete(p.states, state) // One-time use
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupStates removes expired states periodically
|
||||
func (p *OAuth2Provider) cleanupStates() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
p.statesMutex.Lock()
|
||||
now := time.Now()
|
||||
for state, expiry := range p.states {
|
||||
if now.After(expiry) {
|
||||
delete(p.states, state)
|
||||
}
|
||||
}
|
||||
p.statesMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// defaultOAuth2UserInfoParser parses standard OAuth2 user info claims
|
||||
func defaultOAuth2UserInfoParser(userInfo map[string]any) (*UserContext, error) {
|
||||
ctx := &UserContext{
|
||||
Claims: userInfo,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
// Extract standard claims
|
||||
if sub, ok := userInfo["sub"].(string); ok {
|
||||
ctx.RemoteID = sub
|
||||
}
|
||||
if email, ok := userInfo["email"].(string); ok {
|
||||
ctx.Email = email
|
||||
// Use email as username if name not available
|
||||
ctx.UserName = strings.Split(email, "@")[0]
|
||||
}
|
||||
if name, ok := userInfo["name"].(string); ok {
|
||||
ctx.UserName = name
|
||||
}
|
||||
if login, ok := userInfo["login"].(string); ok {
|
||||
ctx.UserName = login // GitHub uses "login"
|
||||
}
|
||||
|
||||
if ctx.UserName == "" {
|
||||
return nil, fmt.Errorf("could not extract username from user info")
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// OAuth2RefreshToken refreshes an expired OAuth2 access token using the
// refresh token. Takes the refresh token and returns a new LoginResponse
// with updated tokens.
//
// Steps, in order:
//  1. look up the stored session by refresh token
//     (resolvespec_oauth_getrefreshtoken);
//  2. rebuild an oauth2.Token and let the provider's TokenSource refresh it;
//  3. mint a new local session token and persist the rotation
//     (resolvespec_oauth_updaterefreshtoken);
//  4. load the user record (resolvespec_oauth_getuser) and return a
//     LoginResponse.
//
// Any failure aborts the whole flow with a wrapped error; when a stored
// procedure reports !p_success, its own p_error text is preferred.
func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshToken, providerName string) (*LoginResponse, error) {
	provider, err := a.getOAuth2Provider(providerName)
	if err != nil {
		return nil, err
	}

	// Get session by refresh token from database
	var success bool
	var errMsg *string
	var sessionData []byte

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error, p_data::text
		FROM resolvespec_oauth_getrefreshtoken($1)
	`, refreshToken).Scan(&success, &errMsg, &sessionData)

	if err != nil {
		return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
	}

	if !success {
		if errMsg != nil {
			return nil, fmt.Errorf("%s", *errMsg)
		}
		return nil, fmt.Errorf("invalid or expired refresh token")
	}

	// Parse session data (only the fields needed to rebuild the token).
	var session struct {
		UserID      int       `json:"user_id"`
		AccessToken string    `json:"access_token"`
		TokenType   string    `json:"token_type"`
		Expiry      time.Time `json:"expiry"`
	}
	if err := json.Unmarshal(sessionData, &session); err != nil {
		return nil, fmt.Errorf("failed to parse session data: %w", err)
	}

	// Create oauth2.Token from stored data
	oldToken := &oauth2.Token{
		AccessToken:  session.AccessToken,
		TokenType:    session.TokenType,
		RefreshToken: refreshToken,
		Expiry:       session.Expiry,
	}

	// Use OAuth2 provider to refresh the token. TokenSource only hits the
	// network when the stored token is actually expired.
	tokenSource := provider.config.TokenSource(ctx, oldToken)
	newToken, err := tokenSource.Token()
	if err != nil {
		return nil, fmt.Errorf("failed to refresh token with provider: %w", err)
	}

	// Generate new session token (session tokens rotate on refresh).
	newSessionToken, err := a.OAuth2GenerateState()
	if err != nil {
		return nil, fmt.Errorf("failed to generate new session token: %w", err)
	}

	// Update session in database with new tokens
	updateData := map[string]interface{}{
		"user_id":           session.UserID,
		"old_refresh_token": refreshToken,
		"new_session_token": newSessionToken,
		"new_access_token":  newToken.AccessToken,
		"new_refresh_token": newToken.RefreshToken,
		"expires_at":        newToken.Expiry,
	}

	updateJSON, err := json.Marshal(updateData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal update data: %w", err)
	}

	var updateSuccess bool
	var updateErrMsg *string

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error
		FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
	`, updateJSON).Scan(&updateSuccess, &updateErrMsg)

	if err != nil {
		return nil, fmt.Errorf("failed to update session: %w", err)
	}

	if !updateSuccess {
		if updateErrMsg != nil {
			return nil, fmt.Errorf("%s", *updateErrMsg)
		}
		return nil, fmt.Errorf("failed to update session")
	}

	// Get user data
	var userSuccess bool
	var userErrMsg *string
	var userData []byte

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error, p_data::text
		FROM resolvespec_oauth_getuser($1)
	`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)

	if err != nil {
		return nil, fmt.Errorf("failed to get user data: %w", err)
	}

	if !userSuccess {
		if userErrMsg != nil {
			return nil, fmt.Errorf("%s", *userErrMsg)
		}
		return nil, fmt.Errorf("failed to get user data")
	}

	// Parse user context
	var userCtx UserContext
	if err := json.Unmarshal(userData, &userCtx); err != nil {
		return nil, fmt.Errorf("failed to parse user context: %w", err)
	}

	userCtx.SessionID = newSessionToken

	return &LoginResponse{
		Token:        newSessionToken,
		RefreshToken: newToken.RefreshToken,
		User:         &userCtx,
		ExpiresIn:    int64(time.Until(newToken.Expiry).Seconds()),
	}, nil
}
|
||||
|
||||
// Pre-configured OAuth2 factory methods
|
||||
|
||||
// NewGoogleAuthenticator creates a DatabaseAuthenticator configured for Google OAuth2
|
||||
func NewGoogleAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
})
|
||||
}
|
||||
|
||||
// NewGitHubAuthenticator creates a DatabaseAuthenticator configured for GitHub OAuth2
|
||||
func NewGitHubAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
})
|
||||
}
|
||||
|
||||
// NewMicrosoftAuthenticator creates a DatabaseAuthenticator configured for Microsoft OAuth2
|
||||
func NewMicrosoftAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||
ProviderName: "microsoft",
|
||||
})
|
||||
}
|
||||
|
||||
// NewFacebookAuthenticator creates a DatabaseAuthenticator configured for Facebook OAuth2
|
||||
func NewFacebookAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"email"},
|
||||
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
|
||||
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
|
||||
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
|
||||
ProviderName: "facebook",
|
||||
})
|
||||
}
|
||||
|
||||
// NewMultiProviderAuthenticator creates a DatabaseAuthenticator with all major OAuth2 providers configured
|
||||
func NewMultiProviderAuthenticator(db *sql.DB, configs map[string]OAuth2Config) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
|
||||
//nolint:gocritic // OAuth2Config is copied but kept for API simplicity
|
||||
for _, cfg := range configs {
|
||||
auth.WithOAuth2(cfg)
|
||||
}
|
||||
|
||||
return auth
|
||||
}
|
||||
185
pkg/security/passkey.go
Normal file
185
pkg/security/passkey.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PasskeyCredential represents a stored WebAuthn/FIDO2 credential.
// One row per registered authenticator per user; SignCount and CloneWarning
// track the WebAuthn signature-counter clone-detection mechanism.
type PasskeyCredential struct {
	ID              string    `json:"id"`
	UserID          int       `json:"user_id"`
	CredentialID    []byte    `json:"credential_id"`    // Raw credential ID from authenticator
	PublicKey       []byte    `json:"public_key"`       // COSE public key
	AttestationType string    `json:"attestation_type"` // none, indirect, direct
	AAGUID          []byte    `json:"aaguid"`           // Authenticator AAGUID
	SignCount       uint32    `json:"sign_count"`       // Signature counter
	CloneWarning    bool      `json:"clone_warning"`    // True if cloning detected
	Transports      []string  `json:"transports,omitempty"` // usb, nfc, ble, internal
	BackupEligible  bool      `json:"backup_eligible"`  // Credential can be backed up
	BackupState     bool      `json:"backup_state"`     // Credential is currently backed up
	Name            string    `json:"name,omitempty"`   // User-friendly name
	CreatedAt       time.Time `json:"created_at"`
	LastUsedAt      time.Time `json:"last_used_at"`
}
|
||||
|
||||
// PasskeyRegistrationOptions contains options for beginning passkey
// registration. The JSON shape mirrors the WebAuthn
// PublicKeyCredentialCreationOptions dictionary passed to
// navigator.credentials.create() on the client.
type PasskeyRegistrationOptions struct {
	Challenge              []byte                         `json:"challenge"`
	RelyingParty           PasskeyRelyingParty            `json:"rp"`
	User                   PasskeyUser                    `json:"user"`
	PubKeyCredParams       []PasskeyCredentialParam       `json:"pubKeyCredParams"`
	Timeout                int64                          `json:"timeout,omitempty"` // Milliseconds
	ExcludeCredentials     []PasskeyCredentialDescriptor  `json:"excludeCredentials,omitempty"`
	AuthenticatorSelection *PasskeyAuthenticatorSelection `json:"authenticatorSelection,omitempty"`
	Attestation            string                         `json:"attestation,omitempty"` // none, indirect, direct, enterprise
	Extensions             map[string]any                 `json:"extensions,omitempty"`
}

// PasskeyAuthenticationOptions contains options for beginning passkey
// authentication. The JSON shape mirrors the WebAuthn
// PublicKeyCredentialRequestOptions dictionary passed to
// navigator.credentials.get() on the client.
type PasskeyAuthenticationOptions struct {
	Challenge        []byte                        `json:"challenge"`
	Timeout          int64                         `json:"timeout,omitempty"`
	RelyingPartyID   string                        `json:"rpId,omitempty"`
	AllowCredentials []PasskeyCredentialDescriptor `json:"allowCredentials,omitempty"`
	UserVerification string                        `json:"userVerification,omitempty"` // required, preferred, discouraged
	Extensions       map[string]any                `json:"extensions,omitempty"`
}
|
||||
|
||||
// PasskeyRelyingParty identifies the relying party.
type PasskeyRelyingParty struct {
	ID   string `json:"id"`   // Domain (e.g., "example.com")
	Name string `json:"name"` // Display name
}

// PasskeyUser identifies the user being registered.
type PasskeyUser struct {
	ID          []byte `json:"id"`          // User handle (unique, persistent)
	Name        string `json:"name"`        // Username
	DisplayName string `json:"displayName"` // Display name
}

// PasskeyCredentialParam specifies a supported public key algorithm.
type PasskeyCredentialParam struct {
	Type string `json:"type"` // "public-key"
	Alg  int    `json:"alg"`  // COSE algorithm identifier (e.g., -7 for ES256, -257 for RS256)
}

// PasskeyCredentialDescriptor describes a credential (used in
// excludeCredentials / allowCredentials lists).
type PasskeyCredentialDescriptor struct {
	Type       string   `json:"type"`                 // "public-key"
	ID         []byte   `json:"id"`                   // Credential ID
	Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
}

// PasskeyAuthenticatorSelection specifies authenticator requirements.
type PasskeyAuthenticatorSelection struct {
	AuthenticatorAttachment string `json:"authenticatorAttachment,omitempty"` // platform, cross-platform
	RequireResidentKey      bool   `json:"requireResidentKey,omitempty"`
	ResidentKey             string `json:"residentKey,omitempty"`      // discouraged, preferred, required
	UserVerification        string `json:"userVerification,omitempty"` // required, preferred, discouraged
}
|
||||
|
||||
// PasskeyRegistrationResponse contains the client's registration response
// (the result of navigator.credentials.create() serialized to JSON).
type PasskeyRegistrationResponse struct {
	ID                     string                                  `json:"id"`    // Base64URL encoded credential ID
	RawID                  []byte                                  `json:"rawId"` // Raw credential ID
	Type                   string                                  `json:"type"`  // "public-key"
	Response               PasskeyAuthenticatorAttestationResponse `json:"response"`
	ClientExtensionResults map[string]any                          `json:"clientExtensionResults,omitempty"`
	Transports             []string                                `json:"transports,omitempty"`
}

// PasskeyAuthenticatorAttestationResponse contains attestation data produced
// during registration.
type PasskeyAuthenticatorAttestationResponse struct {
	ClientDataJSON    []byte   `json:"clientDataJSON"`
	AttestationObject []byte   `json:"attestationObject"`
	Transports        []string `json:"transports,omitempty"`
}

// PasskeyAuthenticationResponse contains the client's authentication response
// (the result of navigator.credentials.get() serialized to JSON).
type PasskeyAuthenticationResponse struct {
	ID                     string                                `json:"id"`    // Base64URL encoded credential ID
	RawID                  []byte                                `json:"rawId"` // Raw credential ID
	Type                   string                                `json:"type"`  // "public-key"
	Response               PasskeyAuthenticatorAssertionResponse `json:"response"`
	ClientExtensionResults map[string]any                        `json:"clientExtensionResults,omitempty"`
}

// PasskeyAuthenticatorAssertionResponse contains assertion data produced
// during authentication.
type PasskeyAuthenticatorAssertionResponse struct {
	ClientDataJSON    []byte `json:"clientDataJSON"`
	AuthenticatorData []byte `json:"authenticatorData"`
	Signature         []byte `json:"signature"`
	UserHandle        []byte `json:"userHandle,omitempty"`
}
|
||||
|
||||
// PasskeyProvider handles passkey registration and authentication.
// Implementations own credential storage and WebAuthn verification; the
// challenge returned by a Begin* call must be passed back to the matching
// Complete* call by the caller (e.g. via session storage).
type PasskeyProvider interface {
	// BeginRegistration creates registration options for a new passkey
	BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error)

	// CompleteRegistration verifies and stores a new passkey credential
	CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error)

	// BeginAuthentication creates authentication options for passkey login
	BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error)

	// CompleteAuthentication verifies a passkey assertion and returns the
	// authenticated user's ID
	CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error)

	// GetCredentials returns all passkey credentials for a user
	GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error)

	// DeleteCredential removes a passkey credential
	DeleteCredential(ctx context.Context, userID int, credentialID string) error

	// UpdateCredentialName updates the friendly name of a credential
	UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error
}
|
||||
|
||||
// PasskeyLoginRequest contains passkey authentication data submitted when
// completing a login. ExpectedChallenge must be the challenge issued by the
// corresponding BeginAuthentication call.
type PasskeyLoginRequest struct {
	Response          PasskeyAuthenticationResponse `json:"response"`
	ExpectedChallenge []byte                        `json:"expected_challenge"`
	Claims            map[string]any                `json:"claims"` // Additional login data
}

// PasskeyRegisterRequest contains passkey registration data submitted when
// completing a registration. ExpectedChallenge must be the challenge issued
// by the corresponding BeginRegistration call.
type PasskeyRegisterRequest struct {
	UserID            int                         `json:"user_id"`
	Response          PasskeyRegistrationResponse `json:"response"`
	ExpectedChallenge []byte                      `json:"expected_challenge"`
	CredentialName    string                      `json:"credential_name,omitempty"`
}

// PasskeyBeginRegistrationRequest contains options for starting passkey
// registration.
type PasskeyBeginRegistrationRequest struct {
	UserID      int    `json:"user_id"`
	Username    string `json:"username"`
	DisplayName string `json:"display_name"`
}

// PasskeyBeginAuthenticationRequest contains options for starting passkey
// authentication.
type PasskeyBeginAuthenticationRequest struct {
	Username string `json:"username,omitempty"` // Optional for resident key flow
}
|
||||
|
||||
// ParsePasskeyRegistrationResponse parses a JSON passkey registration response
|
||||
func ParsePasskeyRegistrationResponse(data []byte) (*PasskeyRegistrationResponse, error) {
|
||||
var response PasskeyRegistrationResponse
|
||||
if err := json.Unmarshal(data, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// ParsePasskeyAuthenticationResponse parses a JSON passkey authentication response
|
||||
func ParsePasskeyAuthenticationResponse(data []byte) (*PasskeyAuthenticationResponse, error) {
|
||||
var response PasskeyAuthenticationResponse
|
||||
if err := json.Unmarshal(data, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
432
pkg/security/passkey_examples.go
Normal file
432
pkg/security/passkey_examples.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PasskeyAuthenticationExample demonstrates passkey (WebAuthn/FIDO2)
// authentication end to end: provider setup, the registration flow, the
// authentication flow, and credential management.
//
// NOTE: this is illustrative code — errors are deliberately ignored with `_`
// and the attestation/assertion payloads are placeholders; a real client
// produces them via navigator.credentials.create()/get().
func PasskeyAuthenticationExample() {
	// Setup database connection (error ignored: example only).
	db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")

	// Create passkey provider
	passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:     "example.com",         // Your domain
		RPName:   "Example Application", // Display name
		RPOrigin: "https://example.com", // Expected origin
		Timeout:  60000,                 // 60 seconds
	})

	// Create authenticator with passkey support
	// Option 1: Pass during creation
	_ = NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
		PasskeyProvider: passkeyProvider,
	})

	// Option 2: Use WithPasskey method
	auth := NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)

	ctx := context.Background()

	// === REGISTRATION FLOW ===

	// Step 1: Begin registration
	regOptions, _ := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
		UserID:      1,
		Username:    "alice",
		DisplayName: "Alice Smith",
	})

	// Send regOptions to client as JSON
	// Client will call navigator.credentials.create() with these options
	_ = regOptions

	// Step 2: Complete registration (after client returns credential)
	// This would come from the client's navigator.credentials.create() response
	clientResponse := PasskeyRegistrationResponse{
		ID:    "base64-credential-id",
		RawID: []byte("raw-credential-id"),
		Type:  "public-key",
		Response: PasskeyAuthenticatorAttestationResponse{
			ClientDataJSON:    []byte("..."),
			AttestationObject: []byte("..."),
		},
		Transports: []string{"internal"},
	}

	// The challenge from step 1 must round-trip into the completion call.
	credential, _ := auth.CompletePasskeyRegistration(ctx, PasskeyRegisterRequest{
		UserID:            1,
		Response:          clientResponse,
		ExpectedChallenge: regOptions.Challenge,
		CredentialName:    "My iPhone",
	})

	fmt.Printf("Registered credential: %s\n", credential.ID)

	// === AUTHENTICATION FLOW ===

	// Step 1: Begin authentication
	authOptions, _ := auth.BeginPasskeyAuthentication(ctx, PasskeyBeginAuthenticationRequest{
		Username: "alice", // Optional - omit for resident key flow
	})

	// Send authOptions to client as JSON
	// Client will call navigator.credentials.get() with these options
	_ = authOptions

	// Step 2: Complete authentication (after client returns assertion)
	// This would come from the client's navigator.credentials.get() response
	clientAssertion := PasskeyAuthenticationResponse{
		ID:    "base64-credential-id",
		RawID: []byte("raw-credential-id"),
		Type:  "public-key",
		Response: PasskeyAuthenticatorAssertionResponse{
			ClientDataJSON:    []byte("..."),
			AuthenticatorData: []byte("..."),
			Signature:         []byte("..."),
		},
	}

	loginResponse, _ := auth.LoginWithPasskey(ctx, PasskeyLoginRequest{
		Response:          clientAssertion,
		ExpectedChallenge: authOptions.Challenge,
		Claims: map[string]any{
			"ip_address": "192.168.1.1",
			"user_agent": "Mozilla/5.0...",
		},
	})

	fmt.Printf("Logged in user: %s with token: %s\n",
		loginResponse.User.UserName, loginResponse.Token)

	// === CREDENTIAL MANAGEMENT ===

	// Get all credentials for a user
	credentials, _ := auth.GetPasskeyCredentials(ctx, 1)
	for i := range credentials {
		fmt.Printf("Credential: %s (created: %s, last used: %s)\n",
			credentials[i].Name, credentials[i].CreatedAt, credentials[i].LastUsedAt)
	}

	// Update credential name
	_ = auth.UpdatePasskeyCredentialName(ctx, 1, credential.ID, "My New iPhone")

	// Delete credential
	_ = auth.DeletePasskeyCredential(ctx, 1, credential.ID)
}
|
||||
|
||||
// PasskeyHTTPHandlersExample shows HTTP handlers for passkey authentication
|
||||
func PasskeyHTTPHandlersExample(auth *DatabaseAuthenticator) {
|
||||
// Store challenges in session/cache in production
|
||||
challenges := make(map[string][]byte)
|
||||
|
||||
// Begin registration endpoint
|
||||
http.HandleFunc("/api/passkey/register/begin", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID int `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
options, err := auth.BeginPasskeyRegistration(r.Context(), PasskeyBeginRegistrationRequest{
|
||||
UserID: req.UserID,
|
||||
Username: req.Username,
|
||||
DisplayName: req.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Store challenge for verification (use session ID as key in production)
|
||||
sessionID := "session-123"
|
||||
challenges[sessionID] = options.Challenge
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(options)
|
||||
})
|
||||
|
||||
// Complete registration endpoint
|
||||
http.HandleFunc("/api/passkey/register/complete", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID int `json:"user_id"`
|
||||
Response PasskeyRegistrationResponse `json:"response"`
|
||||
CredentialName string `json:"credential_name"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Get stored challenge (from session in production)
|
||||
sessionID := "session-123"
|
||||
challenge := challenges[sessionID]
|
||||
delete(challenges, sessionID)
|
||||
|
||||
credential, err := auth.CompletePasskeyRegistration(r.Context(), PasskeyRegisterRequest{
|
||||
UserID: req.UserID,
|
||||
Response: req.Response,
|
||||
ExpectedChallenge: challenge,
|
||||
CredentialName: req.CredentialName,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(credential)
|
||||
})
|
||||
|
||||
// Begin authentication endpoint
|
||||
http.HandleFunc("/api/passkey/login/begin", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"` // Optional
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
options, err := auth.BeginPasskeyAuthentication(r.Context(), PasskeyBeginAuthenticationRequest{
|
||||
Username: req.Username,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Store challenge for verification (use session ID as key in production)
|
||||
sessionID := "session-456"
|
||||
challenges[sessionID] = options.Challenge
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(options)
|
||||
})
|
||||
|
||||
// Complete authentication endpoint
|
||||
http.HandleFunc("/api/passkey/login/complete", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Response PasskeyAuthenticationResponse `json:"response"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Get stored challenge (from session in production)
|
||||
sessionID := "session-456"
|
||||
challenge := challenges[sessionID]
|
||||
delete(challenges, sessionID)
|
||||
|
||||
loginResponse, err := auth.LoginWithPasskey(r.Context(), PasskeyLoginRequest{
|
||||
Response: req.Response,
|
||||
ExpectedChallenge: challenge,
|
||||
Claims: map[string]any{
|
||||
"ip_address": r.RemoteAddr,
|
||||
"user_agent": r.UserAgent(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResponse.Token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(loginResponse)
|
||||
})
|
||||
|
||||
// List credentials endpoint
|
||||
http.HandleFunc("/api/passkey/credentials", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user from authenticated session
|
||||
userCtx, err := auth.Authenticate(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
credentials, err := auth.GetPasskeyCredentials(r.Context(), userCtx.UserID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(credentials)
|
||||
})
|
||||
|
||||
// Delete credential endpoint
|
||||
http.HandleFunc("/api/passkey/credentials/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, err := auth.Authenticate(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
CredentialID string `json:"credential_id"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
err = auth.DeletePasskeyCredential(r.Context(), userCtx.UserID, req.CredentialID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// PasskeyClientSideExample shows the client-side JavaScript code needed
// to drive the passkey endpoints (register/begin, register/complete,
// login/begin, login/complete).
//
// The returned snippet demonstrates the base64 <-> ArrayBuffer conversions
// required by the WebAuthn API and the calls to
// navigator.credentials.create() / navigator.credentials.get().
// It is documentation-as-code: callers display or embed the string; it is
// never executed server-side.
func PasskeyClientSideExample() string {
	return `
// === CLIENT-SIDE JAVASCRIPT FOR PASSKEY AUTHENTICATION ===

// Helper function to convert base64 to ArrayBuffer
function base64ToArrayBuffer(base64) {
    const binary = atob(base64);
    const bytes = new Uint8Array(binary.length);
    for (let i = 0; i < binary.length; i++) {
        bytes[i] = binary.charCodeAt(i);
    }
    return bytes.buffer;
}

// Helper function to convert ArrayBuffer to base64
function arrayBufferToBase64(buffer) {
    const bytes = new Uint8Array(buffer);
    let binary = '';
    for (let i = 0; i < bytes.length; i++) {
        binary += String.fromCharCode(bytes[i]);
    }
    return btoa(binary);
}

// === REGISTRATION ===

async function registerPasskey(userId, username, displayName) {
    // Step 1: Get registration options from server
    const optionsResponse = await fetch('/api/passkey/register/begin', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ user_id: userId, username, display_name: displayName })
    });
    const options = await optionsResponse.json();

    // Convert base64 strings to ArrayBuffers
    options.challenge = base64ToArrayBuffer(options.challenge);
    options.user.id = base64ToArrayBuffer(options.user.id);
    if (options.excludeCredentials) {
        options.excludeCredentials = options.excludeCredentials.map(cred => ({
            ...cred,
            id: base64ToArrayBuffer(cred.id)
        }));
    }

    // Step 2: Create credential using WebAuthn API
    const credential = await navigator.credentials.create({
        publicKey: options
    });

    // Step 3: Send credential to server
    const credentialResponse = {
        id: credential.id,
        rawId: arrayBufferToBase64(credential.rawId),
        type: credential.type,
        response: {
            clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
            attestationObject: arrayBufferToBase64(credential.response.attestationObject)
        },
        transports: credential.response.getTransports ? credential.response.getTransports() : []
    };

    const completeResponse = await fetch('/api/passkey/register/complete', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({
            user_id: userId,
            response: credentialResponse,
            credential_name: 'My Device'
        })
    });

    return await completeResponse.json();
}

// === AUTHENTICATION ===

async function loginWithPasskey(username) {
    // Step 1: Get authentication options from server
    const optionsResponse = await fetch('/api/passkey/login/begin', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ username })
    });
    const options = await optionsResponse.json();

    // Convert base64 strings to ArrayBuffers
    options.challenge = base64ToArrayBuffer(options.challenge);
    if (options.allowCredentials) {
        options.allowCredentials = options.allowCredentials.map(cred => ({
            ...cred,
            id: base64ToArrayBuffer(cred.id)
        }));
    }

    // Step 2: Get credential using WebAuthn API
    const credential = await navigator.credentials.get({
        publicKey: options
    });

    // Step 3: Send assertion to server
    const assertionResponse = {
        id: credential.id,
        rawId: arrayBufferToBase64(credential.rawId),
        type: credential.type,
        response: {
            clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
            authenticatorData: arrayBufferToBase64(credential.response.authenticatorData),
            signature: arrayBufferToBase64(credential.response.signature),
            userHandle: credential.response.userHandle ? arrayBufferToBase64(credential.response.userHandle) : null
        }
    };

    const loginResponse = await fetch('/api/passkey/login/complete', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ response: assertionResponse })
    });

    return await loginResponse.json();
}

// === USAGE ===

// Register a new passkey
document.getElementById('register-btn').addEventListener('click', async () => {
    try {
        const result = await registerPasskey(1, 'alice', 'Alice Smith');
        console.log('Passkey registered:', result);
    } catch (error) {
        console.error('Registration failed:', error);
    }
});

// Login with passkey
document.getElementById('login-btn').addEventListener('click', async () => {
    try {
        const result = await loginWithPasskey('alice');
        console.log('Logged in:', result);
    } catch (error) {
        console.error('Login failed:', error);
    }
});
`
}
|
||||
405
pkg/security/passkey_provider.go
Normal file
405
pkg/security/passkey_provider.go
Normal file
@@ -0,0 +1,405 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage.
//
// All persistence goes through resolvespec_passkey_* SQL functions; the
// provider itself only generates challenges and maps database rows to the
// Passkey* Go types.
type DatabasePasskeyProvider struct {
	db       *sql.DB // connection used for all resolvespec_passkey_* calls
	rpID     string  // Relying Party ID (domain)
	rpName   string  // Relying Party display name
	rpOrigin string  // Expected origin for WebAuthn
	timeout  int64   // Timeout in milliseconds (default: 60000)
}

// DatabasePasskeyProviderOptions configures the passkey provider.
// Zero values are acceptable for Timeout (a default is applied by
// NewDatabasePasskeyProvider); the RP fields have no fallback.
type DatabasePasskeyProviderOptions struct {
	// RPID is the Relying Party ID (typically your domain, e.g., "example.com")
	RPID string
	// RPName is the display name for your relying party
	RPName string
	// RPOrigin is the expected origin (e.g., "https://example.com")
	RPOrigin string
	// Timeout is the timeout for operations in milliseconds (default: 60000)
	Timeout int64
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions) *DatabasePasskeyProvider {
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = 60000 // 60 seconds default
|
||||
}
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// BeginRegistration creates registration options for a new passkey
|
||||
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
||||
// Generate challenge
|
||||
challenge := make([]byte, 32)
|
||||
if _, err := rand.Read(challenge); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate challenge: %w", err)
|
||||
}
|
||||
|
||||
// Get existing credentials to exclude
|
||||
credentials, err := p.GetCredentials(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get existing credentials: %w", err)
|
||||
}
|
||||
|
||||
excludeCredentials := make([]PasskeyCredentialDescriptor, 0, len(credentials))
|
||||
for i := range credentials {
|
||||
excludeCredentials = append(excludeCredentials, PasskeyCredentialDescriptor{
|
||||
Type: "public-key",
|
||||
ID: credentials[i].CredentialID,
|
||||
Transports: credentials[i].Transports,
|
||||
})
|
||||
}
|
||||
|
||||
// Create user handle (persistent user ID)
|
||||
userHandle := []byte(fmt.Sprintf("user_%d", userID))
|
||||
|
||||
return &PasskeyRegistrationOptions{
|
||||
Challenge: challenge,
|
||||
RelyingParty: PasskeyRelyingParty{
|
||||
ID: p.rpID,
|
||||
Name: p.rpName,
|
||||
},
|
||||
User: PasskeyUser{
|
||||
ID: userHandle,
|
||||
Name: username,
|
||||
DisplayName: displayName,
|
||||
},
|
||||
PubKeyCredParams: []PasskeyCredentialParam{
|
||||
{Type: "public-key", Alg: -7}, // ES256 (ECDSA with SHA-256)
|
||||
{Type: "public-key", Alg: -257}, // RS256 (RSASSA-PKCS1-v1_5 with SHA-256)
|
||||
},
|
||||
Timeout: p.timeout,
|
||||
ExcludeCredentials: excludeCredentials,
|
||||
AuthenticatorSelection: &PasskeyAuthenticatorSelection{
|
||||
RequireResidentKey: false,
|
||||
ResidentKey: "preferred",
|
||||
UserVerification: "preferred",
|
||||
},
|
||||
Attestation: "none",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CompleteRegistration verifies and stores a new passkey credential.
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
// like github.com/go-webauthn/webauthn to properly verify attestation and parse credentials.
//
// WARNING(review): expectedChallenge is accepted but never referenced in this
// body, and the raw attestation object is stored as the "public key" verbatim.
// The TODO list below must be completed before this can be considered secure.
func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error) {
	// TODO: Implement full WebAuthn verification
	// 1. Verify clientDataJSON contains correct challenge and origin
	// 2. Parse and verify attestationObject
	// 3. Extract public key and credential ID
	// 4. Verify attestation signature (if not "none")

	// For now, this is a placeholder that stores the credential data
	// In production, you MUST use a proper WebAuthn library

	// Payload for the resolvespec_passkey_store_credential SQL function;
	// binary fields are base64-encoded for the jsonb parameter.
	credData := map[string]any{
		"user_id":          userID,
		"credential_id":    base64.StdEncoding.EncodeToString(response.RawID),
		"public_key":       base64.StdEncoding.EncodeToString(response.Response.AttestationObject),
		"attestation_type": "none",
		"sign_count":       0,
		"transports":       response.Transports,
		"backup_eligible":  false,
		"backup_state":     false,
		"name":             "Passkey",
	}

	credJSON, err := json.Marshal(credData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal credential data: %w", err)
	}

	var success bool
	var errorMsg sql.NullString
	var credentialID sql.NullInt64

	// The SQL function reports success/error out-of-band and returns the new row ID.
	query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
	err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
	if err != nil {
		return nil, fmt.Errorf("failed to store credential: %w", err)
	}

	if !success {
		if errorMsg.Valid {
			return nil, fmt.Errorf("%s", errorMsg.String)
		}
		return nil, fmt.Errorf("failed to store credential")
	}

	// Echo the stored credential back; timestamps are approximated locally
	// rather than read back from the database row.
	return &PasskeyCredential{
		ID:              fmt.Sprintf("%d", credentialID.Int64),
		UserID:          userID,
		CredentialID:    response.RawID,
		PublicKey:       response.Response.AttestationObject,
		AttestationType: "none",
		Transports:      response.Transports,
		CreatedAt:       time.Now(),
		LastUsedAt:      time.Now(),
	}, nil
}
|
||||
|
||||
// BeginAuthentication creates authentication options for passkey login
|
||||
func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error) {
|
||||
// Generate challenge
|
||||
challenge := make([]byte, 32)
|
||||
if _, err := rand.Read(challenge); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate challenge: %w", err)
|
||||
}
|
||||
|
||||
// If username is provided, get user's credentials
|
||||
var allowCredentials []PasskeyCredentialDescriptor
|
||||
if username != "" {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userID sql.NullInt64
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get credentials")
|
||||
}
|
||||
|
||||
// Parse credentials
|
||||
var creds []struct {
|
||||
ID string `json:"credential_id"`
|
||||
Transports []string `json:"transports"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(credentialsJSON.String), &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials: %w", err)
|
||||
}
|
||||
|
||||
allowCredentials = make([]PasskeyCredentialDescriptor, 0, len(creds))
|
||||
for _, cred := range creds {
|
||||
credID, err := base64.StdEncoding.DecodeString(cred.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
allowCredentials = append(allowCredentials, PasskeyCredentialDescriptor{
|
||||
Type: "public-key",
|
||||
ID: credID,
|
||||
Transports: cred.Transports,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &PasskeyAuthenticationOptions{
|
||||
Challenge: challenge,
|
||||
Timeout: p.timeout,
|
||||
RelyingPartyID: p.rpID,
|
||||
AllowCredentials: allowCredentials,
|
||||
UserVerification: "preferred",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CompleteAuthentication verifies a passkey assertion and returns the user ID
|
||||
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
|
||||
// like github.com/go-webauthn/webauthn to properly verify the assertion signature.
|
||||
func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error) {
|
||||
// TODO: Implement full WebAuthn verification
|
||||
// 1. Verify clientDataJSON contains correct challenge and origin
|
||||
// 2. Verify authenticatorData
|
||||
// 3. Verify signature using stored public key
|
||||
// 4. Update sign counter and check for cloning
|
||||
|
||||
// Get credential from database
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return 0, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return 0, fmt.Errorf("credential not found")
|
||||
}
|
||||
|
||||
// Parse credential
|
||||
var cred struct {
|
||||
UserID int `json:"user_id"`
|
||||
SignCount uint32 `json:"sign_count"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(credentialJSON.String), &cred); err != nil {
|
||||
return 0, fmt.Errorf("failed to parse credential: %w", err)
|
||||
}
|
||||
|
||||
// TODO: Verify signature here
|
||||
// For now, we'll just update the counter as a placeholder
|
||||
|
||||
// Update counter (in production, this should be done after successful verification)
|
||||
newCounter := cred.SignCount + 1
|
||||
var updateSuccess bool
|
||||
var updateError sql.NullString
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
}
|
||||
|
||||
if cloneWarning.Valid && cloneWarning.Bool {
|
||||
return 0, fmt.Errorf("credential cloning detected")
|
||||
}
|
||||
|
||||
return cred.UserID, nil
|
||||
}
|
||||
|
||||
// GetCredentials returns all passkey credentials for a user
|
||||
func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get credentials")
|
||||
}
|
||||
|
||||
// Parse credentials
|
||||
var rawCreds []struct {
|
||||
ID int `json:"id"`
|
||||
UserID int `json:"user_id"`
|
||||
CredentialID string `json:"credential_id"`
|
||||
PublicKey string `json:"public_key"`
|
||||
AttestationType string `json:"attestation_type"`
|
||||
AAGUID string `json:"aaguid"`
|
||||
SignCount uint32 `json:"sign_count"`
|
||||
CloneWarning bool `json:"clone_warning"`
|
||||
Transports []string `json:"transports"`
|
||||
BackupEligible bool `json:"backup_eligible"`
|
||||
BackupState bool `json:"backup_state"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsedAt time.Time `json:"last_used_at"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(credentialsJSON.String), &rawCreds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials: %w", err)
|
||||
}
|
||||
|
||||
credentials := make([]PasskeyCredential, 0, len(rawCreds))
|
||||
for i := range rawCreds {
|
||||
raw := rawCreds[i]
|
||||
credID, err := base64.StdEncoding.DecodeString(raw.CredentialID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pubKey, err := base64.StdEncoding.DecodeString(raw.PublicKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
aaguid, _ := base64.StdEncoding.DecodeString(raw.AAGUID)
|
||||
|
||||
credentials = append(credentials, PasskeyCredential{
|
||||
ID: fmt.Sprintf("%d", raw.ID),
|
||||
UserID: raw.UserID,
|
||||
CredentialID: credID,
|
||||
PublicKey: pubKey,
|
||||
AttestationType: raw.AttestationType,
|
||||
AAGUID: aaguid,
|
||||
SignCount: raw.SignCount,
|
||||
CloneWarning: raw.CloneWarning,
|
||||
Transports: raw.Transports,
|
||||
BackupEligible: raw.BackupEligible,
|
||||
BackupState: raw.BackupState,
|
||||
Name: raw.Name,
|
||||
CreatedAt: raw.CreatedAt,
|
||||
LastUsedAt: raw.LastUsedAt,
|
||||
})
|
||||
}
|
||||
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
// DeleteCredential removes a passkey credential
|
||||
func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID int, credentialID string) error {
|
||||
credID, err := base64.StdEncoding.DecodeString(credentialID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credential ID: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return fmt.Errorf("failed to delete credential")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateCredentialName updates the friendly name of a credential
|
||||
func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
|
||||
credID, err := base64.StdEncoding.DecodeString(credentialID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credential ID: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return fmt.Errorf("failed to update credential name")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
330
pkg/security/passkey_test.go
Normal file
330
pkg/security/passkey_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
func TestDatabasePasskeyProvider_BeginRegistration(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mock db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
RPOrigin: "https://example.com",
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Mock get credentials query
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||
AddRow(true, nil, "[]")
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||
WithArgs(1).
|
||||
WillReturnRows(rows)
|
||||
|
||||
opts, err := provider.BeginRegistration(ctx, 1, "testuser", "Test User")
|
||||
if err != nil {
|
||||
t.Fatalf("BeginRegistration failed: %v", err)
|
||||
}
|
||||
|
||||
if opts.RelyingParty.ID != "example.com" {
|
||||
t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingParty.ID)
|
||||
}
|
||||
|
||||
if opts.User.Name != "testuser" {
|
||||
t.Errorf("expected username 'testuser', got '%s'", opts.User.Name)
|
||||
}
|
||||
|
||||
if len(opts.Challenge) != 32 {
|
||||
t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
|
||||
}
|
||||
|
||||
if len(opts.PubKeyCredParams) != 2 {
|
||||
t.Errorf("expected 2 credential params, got %d", len(opts.PubKeyCredParams))
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabasePasskeyProvider_BeginAuthentication(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mock db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
RPOrigin: "https://example.com",
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Mock get credentials by username query
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user_id", "p_credentials"}).
|
||||
AddRow(true, nil, 1, `[{"credential_id":"YWJjZGVm","transports":["internal"]}]`)
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username`).
|
||||
WithArgs("testuser").
|
||||
WillReturnRows(rows)
|
||||
|
||||
opts, err := provider.BeginAuthentication(ctx, "testuser")
|
||||
if err != nil {
|
||||
t.Fatalf("BeginAuthentication failed: %v", err)
|
||||
}
|
||||
|
||||
if opts.RelyingPartyID != "example.com" {
|
||||
t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingPartyID)
|
||||
}
|
||||
|
||||
if len(opts.Challenge) != 32 {
|
||||
t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
|
||||
}
|
||||
|
||||
if len(opts.AllowCredentials) != 1 {
|
||||
t.Errorf("expected 1 allowed credential, got %d", len(opts.AllowCredentials))
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabasePasskeyProvider_GetCredentials(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mock db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
credentialsJSON := `[{
|
||||
"id": 1,
|
||||
"user_id": 1,
|
||||
"credential_id": "YWJjZGVmMTIzNDU2",
|
||||
"public_key": "cHVibGlja2V5",
|
||||
"attestation_type": "none",
|
||||
"aaguid": "",
|
||||
"sign_count": 5,
|
||||
"clone_warning": false,
|
||||
"transports": ["internal"],
|
||||
"backup_eligible": true,
|
||||
"backup_state": false,
|
||||
"name": "My Phone",
|
||||
"created_at": "2026-01-01T00:00:00Z",
|
||||
"last_used_at": "2026-01-31T00:00:00Z"
|
||||
}]`
|
||||
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||
AddRow(true, nil, credentialsJSON)
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||
WithArgs(1).
|
||||
WillReturnRows(rows)
|
||||
|
||||
credentials, err := provider.GetCredentials(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
if len(credentials) != 1 {
|
||||
t.Fatalf("expected 1 credential, got %d", len(credentials))
|
||||
}
|
||||
|
||||
cred := credentials[0]
|
||||
if cred.UserID != 1 {
|
||||
t.Errorf("expected user ID 1, got %d", cred.UserID)
|
||||
}
|
||||
if cred.Name != "My Phone" {
|
||||
t.Errorf("expected name 'My Phone', got '%s'", cred.Name)
|
||||
}
|
||||
if cred.SignCount != 5 {
|
||||
t.Errorf("expected sign count 5, got %d", cred.SignCount)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabasePasskeyProvider_DeleteCredential(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mock db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
|
||||
AddRow(true, nil)
|
||||
mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_delete_credential`).
|
||||
WithArgs(1, sqlmock.AnyArg()).
|
||||
WillReturnRows(rows)
|
||||
|
||||
err = provider.DeleteCredential(ctx, 1, "YWJjZGVmMTIzNDU2")
|
||||
if err != nil {
|
||||
t.Errorf("DeleteCredential failed: %v", err)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabasePasskeyProvider_UpdateCredentialName(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mock db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
|
||||
AddRow(true, nil)
|
||||
mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_update_name`).
|
||||
WithArgs(1, sqlmock.AnyArg(), "New Name").
|
||||
WillReturnRows(rows)
|
||||
|
||||
err = provider.UpdateCredentialName(ctx, 1, "YWJjZGVmMTIzNDU2", "New Name")
|
||||
if err != nil {
|
||||
t.Errorf("UpdateCredentialName failed: %v", err)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseAuthenticator_PasskeyMethods(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mock db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
})
|
||||
|
||||
auth := NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
|
||||
PasskeyProvider: passkeyProvider,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("BeginPasskeyRegistration", func(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||
AddRow(true, nil, "[]")
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||
WithArgs(1).
|
||||
WillReturnRows(rows)
|
||||
|
||||
opts, err := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
DisplayName: "Test User",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("BeginPasskeyRegistration failed: %v", err)
|
||||
}
|
||||
|
||||
if opts == nil {
|
||||
t.Error("expected options, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetPasskeyCredentials", func(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||
AddRow(true, nil, "[]")
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||
WithArgs(1).
|
||||
WillReturnRows(rows)
|
||||
|
||||
credentials, err := auth.GetPasskeyCredentials(ctx, 1)
|
||||
if err != nil {
|
||||
t.Errorf("GetPasskeyCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
if credentials == nil {
|
||||
t.Error("expected credentials slice, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseAuthenticator_WithoutPasskey(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mock db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
DisplayName: "Test User",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error when passkey provider not configured, got nil")
|
||||
}
|
||||
|
||||
expectedMsg := "passkey provider not configured"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("expected error '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasskeyProvider_NilDB(t *testing.T) {
|
||||
// This test verifies that the provider can be created with nil DB
|
||||
// but operations will fail. In production, always provide a valid DB.
|
||||
var db *sql.DB
|
||||
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
})
|
||||
|
||||
if provider == nil {
|
||||
t.Error("expected provider to be created even with nil DB")
|
||||
}
|
||||
|
||||
// Verify that the provider has the correct configuration
|
||||
if provider.rpID != "example.com" {
|
||||
t.Errorf("expected RP ID 'example.com', got '%s'", provider.rpID)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
@@ -60,10 +61,19 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
||||
// resolvespec_session_update, resolvespec_refresh_token
|
||||
// See database_schema.sql for procedure definitions
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
oauth2ProvidersMutex sync.RWMutex
|
||||
|
||||
// Passkey provider (optional)
|
||||
passkeyProvider PasskeyProvider
|
||||
}
|
||||
|
||||
// DatabaseAuthenticatorOptions configures the database authenticator
|
||||
@@ -73,6 +83,8 @@ type DatabaseAuthenticatorOptions struct {
|
||||
CacheTTL time.Duration
|
||||
// Cache is an optional cache instance. If nil, uses the default cache
|
||||
Cache *cache.Cache
|
||||
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||
PasskeyProvider PasskeyProvider
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -92,9 +104,10 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
}
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
db: db,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
passkeyProvider: opts.PasskeyProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +145,41 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// Register implements Registrable interface
|
||||
func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error) {
|
||||
// Convert RegisterRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_register stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("registration failed")
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var response LoginResponse
|
||||
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse register response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// Convert LogoutRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -654,3 +702,135 @@ func generateRandomString(length int) string {
|
||||
// }
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// Passkey authentication methods
|
||||
// ==============================
|
||||
|
||||
// WithPasskey configures the DatabaseAuthenticator with a passkey provider
|
||||
func (a *DatabaseAuthenticator) WithPasskey(provider PasskeyProvider) *DatabaseAuthenticator {
|
||||
a.passkeyProvider = provider
|
||||
return a
|
||||
}
|
||||
|
||||
// BeginPasskeyRegistration initiates passkey registration for a user
|
||||
func (a *DatabaseAuthenticator) BeginPasskeyRegistration(ctx context.Context, req PasskeyBeginRegistrationRequest) (*PasskeyRegistrationOptions, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
return a.passkeyProvider.BeginRegistration(ctx, req.UserID, req.Username, req.DisplayName)
|
||||
}
|
||||
|
||||
// CompletePasskeyRegistration completes passkey registration
|
||||
func (a *DatabaseAuthenticator) CompletePasskeyRegistration(ctx context.Context, req PasskeyRegisterRequest) (*PasskeyCredential, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
|
||||
cred, err := a.passkeyProvider.CompleteRegistration(ctx, req.UserID, req.Response, req.ExpectedChallenge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update credential name if provided
|
||||
if req.CredentialName != "" && cred.ID != "" {
|
||||
_ = a.passkeyProvider.UpdateCredentialName(ctx, req.UserID, cred.ID, req.CredentialName)
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// BeginPasskeyAuthentication initiates passkey authentication
|
||||
func (a *DatabaseAuthenticator) BeginPasskeyAuthentication(ctx context.Context, req PasskeyBeginAuthenticationRequest) (*PasskeyAuthenticationOptions, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
return a.passkeyProvider.BeginAuthentication(ctx, req.Username)
|
||||
}
|
||||
|
||||
// LoginWithPasskey authenticates a user using a passkey and creates a session
|
||||
func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req PasskeyLoginRequest) (*LoginResponse, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
|
||||
// Verify passkey assertion
|
||||
userID, err := a.passkeyProvider.CompleteAuthentication(ctx, req.Response, req.ExpectedChallenge)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||
}
|
||||
|
||||
// Get user data from database
|
||||
var username, email, roles string
|
||||
var userLevel int
|
||||
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
||||
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
}
|
||||
|
||||
// Generate session token
|
||||
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// Extract IP and user agent from claims
|
||||
ipAddress := ""
|
||||
userAgent := ""
|
||||
if req.Claims != nil {
|
||||
if ip, ok := req.Claims["ip_address"].(string); ok {
|
||||
ipAddress = ip
|
||||
}
|
||||
if ua, ok := req.Claims["user_agent"].(string); ok {
|
||||
userAgent = ua
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
|
||||
VALUES ($1, $2, $3, $4, $5, now())`
|
||||
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
// Update last login
|
||||
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
|
||||
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
|
||||
|
||||
// Return login response
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
User: &UserContext{
|
||||
UserID: userID,
|
||||
UserName: username,
|
||||
Email: email,
|
||||
UserLevel: userLevel,
|
||||
SessionID: sessionToken,
|
||||
Roles: parseRoles(roles),
|
||||
},
|
||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPasskeyCredentials returns all passkey credentials for a user
|
||||
func (a *DatabaseAuthenticator) GetPasskeyCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
return a.passkeyProvider.GetCredentials(ctx, userID)
|
||||
}
|
||||
|
||||
// DeletePasskeyCredential removes a passkey credential
|
||||
func (a *DatabaseAuthenticator) DeletePasskeyCredential(ctx context.Context, userID int, credentialID string) error {
|
||||
if a.passkeyProvider == nil {
|
||||
return fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
return a.passkeyProvider.DeleteCredential(ctx, userID, credentialID)
|
||||
}
|
||||
|
||||
// UpdatePasskeyCredentialName updates the friendly name of a credential
|
||||
func (a *DatabaseAuthenticator) UpdatePasskeyCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
|
||||
if a.passkeyProvider == nil {
|
||||
return fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
return a.passkeyProvider.UpdateCredentialName(ctx, userID, credentialID, name)
|
||||
}
|
||||
|
||||
@@ -635,6 +635,94 @@ func TestDatabaseAuthenticator(t *testing.T) {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful registration", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
req := RegisterRequest{
|
||||
Username: "newuser",
|
||||
Password: "password123",
|
||||
Email: "newuser@example.com",
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, `{"token":"abc123","user":{"user_id":1,"user_name":"newuser","email":"newuser@example.com"},"expires_in":86400}`)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(rows)
|
||||
|
||||
resp, err := auth.Register(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if resp.Token != "abc123" {
|
||||
t.Errorf("expected token abc123, got %s", resp.Token)
|
||||
}
|
||||
if resp.User.UserName != "newuser" {
|
||||
t.Errorf("expected username newuser, got %s", resp.User.UserName)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("registration with duplicate username", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
req := RegisterRequest{
|
||||
Username: "existinguser",
|
||||
Password: "password123",
|
||||
Email: "new@example.com",
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(false, "Username already exists", nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(rows)
|
||||
|
||||
_, err := auth.Register(ctx, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate username")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("registration with duplicate email", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
req := RegisterRequest{
|
||||
Username: "newuser2",
|
||||
Password: "password123",
|
||||
Email: "existing@example.com",
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(false, "Email already exists", nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(rows)
|
||||
|
||||
_, err := auth.Register(ctx, req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate email")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test DatabaseAuthenticator RefreshToken
|
||||
|
||||
188
pkg/security/totp.go
Normal file
188
pkg/security/totp.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash"
|
||||
"math"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TwoFactorAuthProvider defines interface for 2FA operations
|
||||
type TwoFactorAuthProvider interface {
|
||||
// Generate2FASecret creates a new secret for a user
|
||||
Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error)
|
||||
|
||||
// Validate2FACode verifies a TOTP code
|
||||
Validate2FACode(secret string, code string) (bool, error)
|
||||
|
||||
// Enable2FA activates 2FA for a user (store secret in your database)
|
||||
Enable2FA(userID int, secret string, backupCodes []string) error
|
||||
|
||||
// Disable2FA deactivates 2FA for a user
|
||||
Disable2FA(userID int) error
|
||||
|
||||
// Get2FAStatus checks if user has 2FA enabled
|
||||
Get2FAStatus(userID int) (bool, error)
|
||||
|
||||
// Get2FASecret retrieves the user's 2FA secret
|
||||
Get2FASecret(userID int) (string, error)
|
||||
|
||||
// GenerateBackupCodes creates backup codes for 2FA
|
||||
GenerateBackupCodes(userID int, count int) ([]string, error)
|
||||
|
||||
// ValidateBackupCode checks and consumes a backup code
|
||||
ValidateBackupCode(userID int, code string) (bool, error)
|
||||
}
|
||||
|
||||
// TwoFactorSecret contains 2FA setup information
|
||||
type TwoFactorSecret struct {
|
||||
Secret string `json:"secret"` // Base32 encoded secret
|
||||
QRCodeURL string `json:"qr_code_url"` // URL for QR code generation
|
||||
BackupCodes []string `json:"backup_codes"` // One-time backup codes
|
||||
Issuer string `json:"issuer"` // Application name
|
||||
AccountName string `json:"account_name"` // User identifier (email/username)
|
||||
}
|
||||
|
||||
// TwoFactorConfig holds TOTP configuration
|
||||
type TwoFactorConfig struct {
|
||||
Algorithm string // SHA1, SHA256, SHA512
|
||||
Digits int // Number of digits in code (6 or 8)
|
||||
Period int // Time step in seconds (default 30)
|
||||
SkewWindow int // Number of time steps to check before/after (default 1)
|
||||
}
|
||||
|
||||
// DefaultTwoFactorConfig returns standard TOTP configuration
|
||||
func DefaultTwoFactorConfig() *TwoFactorConfig {
|
||||
return &TwoFactorConfig{
|
||||
Algorithm: "SHA1",
|
||||
Digits: 6,
|
||||
Period: 30,
|
||||
SkewWindow: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// TOTPGenerator handles TOTP code generation and validation
|
||||
type TOTPGenerator struct {
|
||||
config *TwoFactorConfig
|
||||
}
|
||||
|
||||
// NewTOTPGenerator creates a new TOTP generator with config
|
||||
func NewTOTPGenerator(config *TwoFactorConfig) *TOTPGenerator {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &TOTPGenerator{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSecret creates a random base32-encoded secret
|
||||
func (t *TOTPGenerator) GenerateSecret() (string, error) {
|
||||
secret := make([]byte, 20)
|
||||
_, err := rand.Read(secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate random secret: %w", err)
|
||||
}
|
||||
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secret), nil
|
||||
}
|
||||
|
||||
// GenerateQRCodeURL creates a URL for QR code generation
|
||||
func (t *TOTPGenerator) GenerateQRCodeURL(secret, issuer, accountName string) string {
|
||||
params := url.Values{}
|
||||
params.Set("secret", secret)
|
||||
params.Set("issuer", issuer)
|
||||
params.Set("algorithm", t.config.Algorithm)
|
||||
params.Set("digits", fmt.Sprintf("%d", t.config.Digits))
|
||||
params.Set("period", fmt.Sprintf("%d", t.config.Period))
|
||||
|
||||
label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, accountName))
|
||||
return fmt.Sprintf("otpauth://totp/%s?%s", label, params.Encode())
|
||||
}
|
||||
|
||||
// GenerateCode creates a TOTP code for a given time
|
||||
func (t *TOTPGenerator) GenerateCode(secret string, timestamp time.Time) (string, error) {
|
||||
// Decode secret
|
||||
key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid secret: %w", err)
|
||||
}
|
||||
|
||||
// Calculate counter (time steps since Unix epoch)
|
||||
counter := uint64(timestamp.Unix()) / uint64(t.config.Period)
|
||||
|
||||
// Generate HMAC
|
||||
h := t.getHashFunc()
|
||||
mac := hmac.New(h, key)
|
||||
|
||||
// Convert counter to 8-byte array
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, counter)
|
||||
mac.Write(buf)
|
||||
|
||||
sum := mac.Sum(nil)
|
||||
|
||||
// Dynamic truncation
|
||||
offset := sum[len(sum)-1] & 0x0f
|
||||
truncated := binary.BigEndian.Uint32(sum[offset:]) & 0x7fffffff
|
||||
|
||||
// Generate code with specified digits
|
||||
code := truncated % uint32(math.Pow10(t.config.Digits))
|
||||
|
||||
format := fmt.Sprintf("%%0%dd", t.config.Digits)
|
||||
return fmt.Sprintf(format, code), nil
|
||||
}
|
||||
|
||||
// ValidateCode checks if a code is valid for the secret
|
||||
func (t *TOTPGenerator) ValidateCode(secret, code string) (bool, error) {
|
||||
now := time.Now()
|
||||
|
||||
// Check current time and skew window
|
||||
for i := -t.config.SkewWindow; i <= t.config.SkewWindow; i++ {
|
||||
timestamp := now.Add(time.Duration(i*t.config.Period) * time.Second)
|
||||
expected, err := t.GenerateCode(secret, timestamp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if code == expected {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// getHashFunc returns the hash function based on algorithm
|
||||
func (t *TOTPGenerator) getHashFunc() func() hash.Hash {
|
||||
switch strings.ToUpper(t.config.Algorithm) {
|
||||
case "SHA256":
|
||||
return sha256.New
|
||||
case "SHA512":
|
||||
return sha512.New
|
||||
default:
|
||||
return sha1.New
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateBackupCodes creates random backup codes
|
||||
func GenerateBackupCodes(count int) ([]string, error) {
|
||||
codes := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
code := make([]byte, 4)
|
||||
_, err := rand.Read(code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate backup code: %w", err)
|
||||
}
|
||||
codes[i] = fmt.Sprintf("%08X", binary.BigEndian.Uint32(code))
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
399
pkg/security/totp_integration_test.go
Normal file
399
pkg/security/totp_integration_test.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
var ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
|
||||
// MockAuthenticator is a simple authenticator for testing 2FA
|
||||
type MockAuthenticator struct {
|
||||
users map[string]*security.UserContext
|
||||
}
|
||||
|
||||
func NewMockAuthenticator() *MockAuthenticator {
|
||||
return &MockAuthenticator{
|
||||
users: map[string]*security.UserContext{
|
||||
"testuser": {
|
||||
UserID: 1,
|
||||
UserName: "testuser",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
|
||||
user, exists := m.users[req.Username]
|
||||
if !exists || req.Password != "password" {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return &security.LoginResponse{
|
||||
Token: "mock-token",
|
||||
RefreshToken: "mock-refresh-token",
|
||||
User: user,
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||
return m.users["testuser"], nil
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Setup(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup 2FA
|
||||
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Setup2FA() error = %v", err)
|
||||
}
|
||||
|
||||
if secret.Secret == "" {
|
||||
t.Error("Setup2FA() returned empty secret")
|
||||
}
|
||||
|
||||
if secret.QRCodeURL == "" {
|
||||
t.Error("Setup2FA() returned empty QR code URL")
|
||||
}
|
||||
|
||||
if len(secret.BackupCodes) == 0 {
|
||||
t.Error("Setup2FA() returned no backup codes")
|
||||
}
|
||||
|
||||
if secret.Issuer != "TestApp" {
|
||||
t.Errorf("Setup2FA() Issuer = %s, want TestApp", secret.Issuer)
|
||||
}
|
||||
|
||||
if secret.AccountName != "test@example.com" {
|
||||
t.Errorf("Setup2FA() AccountName = %s, want test@example.com", secret.AccountName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Enable2FA(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup 2FA
|
||||
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Setup2FA() error = %v", err)
|
||||
}
|
||||
|
||||
// Generate valid code
|
||||
totp := security.NewTOTPGenerator(nil)
|
||||
code, err := totp.GenerateCode(secret.Secret, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode() error = %v", err)
|
||||
}
|
||||
|
||||
// Enable 2FA with valid code
|
||||
err = tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||
if err != nil {
|
||||
t.Errorf("Enable2FA() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify 2FA is enabled
|
||||
status, err := provider.Get2FAStatus(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Get2FAStatus() error = %v", err)
|
||||
}
|
||||
|
||||
if !status {
|
||||
t.Error("Enable2FA() did not enable 2FA")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Enable2FA_InvalidCode(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup 2FA
|
||||
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Setup2FA() error = %v", err)
|
||||
}
|
||||
|
||||
// Try to enable with invalid code
|
||||
err = tfaAuth.Enable2FA(1, secret.Secret, "000000")
|
||||
if err == nil {
|
||||
t.Error("Enable2FA() should fail with invalid code")
|
||||
}
|
||||
|
||||
// Verify 2FA is not enabled
|
||||
status, _ := provider.Get2FAStatus(1)
|
||||
if status {
|
||||
t.Error("Enable2FA() should not enable 2FA with invalid code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Login_Without2FA(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
req := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("Login() error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Requires2FA {
|
||||
t.Error("Login() should not require 2FA when not enabled")
|
||||
}
|
||||
|
||||
if resp.Token == "" {
|
||||
t.Error("Login() should return token when 2FA not required")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Login_With2FA_NoCode(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup and enable 2FA
|
||||
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
totp := security.NewTOTPGenerator(nil)
|
||||
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||
|
||||
// Try to login without 2FA code
|
||||
req := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("Login() error = %v", err)
|
||||
}
|
||||
|
||||
if !resp.Requires2FA {
|
||||
t.Error("Login() should require 2FA when enabled")
|
||||
}
|
||||
|
||||
if resp.Token != "" {
|
||||
t.Error("Login() should not return token when 2FA required but not provided")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Login_With2FA_ValidCode(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup and enable 2FA
|
||||
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
totp := security.NewTOTPGenerator(nil)
|
||||
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||
|
||||
// Generate new valid code for login
|
||||
newCode, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||
|
||||
// Login with 2FA code
|
||||
req := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
TwoFactorCode: newCode,
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("Login() error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Requires2FA {
|
||||
t.Error("Login() should not require 2FA when valid code provided")
|
||||
}
|
||||
|
||||
if resp.Token == "" {
|
||||
t.Error("Login() should return token when 2FA validated")
|
||||
}
|
||||
|
||||
if !resp.User.TwoFactorEnabled {
|
||||
t.Error("Login() should set TwoFactorEnabled on user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Login_With2FA_InvalidCode(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup and enable 2FA
|
||||
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
totp := security.NewTOTPGenerator(nil)
|
||||
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||
|
||||
// Try to login with invalid code
|
||||
req := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
TwoFactorCode: "000000",
|
||||
}
|
||||
|
||||
_, err := tfaAuth.Login(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Error("Login() should fail with invalid 2FA code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Login_WithBackupCode(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup and enable 2FA
|
||||
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
totp := security.NewTOTPGenerator(nil)
|
||||
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||
|
||||
// Get backup codes
|
||||
backupCodes, _ := tfaAuth.RegenerateBackupCodes(1, 10)
|
||||
|
||||
// Login with backup code
|
||||
req := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
TwoFactorCode: backupCodes[0],
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("Login() with backup code error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Token == "" {
|
||||
t.Error("Login() should return token when backup code validated")
|
||||
}
|
||||
|
||||
// Try to use same backup code again
|
||||
req2 := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
TwoFactorCode: backupCodes[0],
|
||||
}
|
||||
|
||||
_, err = tfaAuth.Login(context.Background(), req2)
|
||||
if err == nil {
|
||||
t.Error("Login() should fail when reusing backup code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_Disable2FA(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup and enable 2FA
|
||||
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
totp := security.NewTOTPGenerator(nil)
|
||||
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||
|
||||
// Disable 2FA
|
||||
err := tfaAuth.Disable2FA(1)
|
||||
if err != nil {
|
||||
t.Errorf("Disable2FA() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify 2FA is disabled
|
||||
status, _ := provider.Get2FAStatus(1)
|
||||
if status {
|
||||
t.Error("Disable2FA() did not disable 2FA")
|
||||
}
|
||||
|
||||
// Login should not require 2FA
|
||||
req := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("Login() error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Requires2FA {
|
||||
t.Error("Login() should not require 2FA after disabling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoFactorAuthenticator_RegenerateBackupCodes(t *testing.T) {
|
||||
baseAuth := NewMockAuthenticator()
|
||||
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||
|
||||
// Setup and enable 2FA
|
||||
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||
totp := security.NewTOTPGenerator(nil)
|
||||
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||
|
||||
// Get initial backup codes
|
||||
codes1, err := tfaAuth.RegenerateBackupCodes(1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("RegenerateBackupCodes() error = %v", err)
|
||||
}
|
||||
|
||||
if len(codes1) != 10 {
|
||||
t.Errorf("RegenerateBackupCodes() returned %d codes, want 10", len(codes1))
|
||||
}
|
||||
|
||||
// Regenerate backup codes
|
||||
codes2, err := tfaAuth.RegenerateBackupCodes(1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("RegenerateBackupCodes() error = %v", err)
|
||||
}
|
||||
|
||||
// Old codes should not work
|
||||
req := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
TwoFactorCode: codes1[0],
|
||||
}
|
||||
|
||||
_, err = tfaAuth.Login(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Error("Login() should fail with old backup code after regeneration")
|
||||
}
|
||||
|
||||
// New codes should work
|
||||
req2 := security.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
TwoFactorCode: codes2[0],
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(context.Background(), req2)
|
||||
if err != nil {
|
||||
t.Fatalf("Login() with new backup code error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Token == "" {
|
||||
t.Error("Login() should return token with new backup code")
|
||||
}
|
||||
}
|
||||
134
pkg/security/totp_middleware.go
Normal file
134
pkg/security/totp_middleware.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// TwoFactorAuthenticator wraps an Authenticator and adds 2FA support.
// It decorates the base authenticator's Login with a TOTP / backup-code
// verification step while delegating Logout and Authenticate unchanged.
type TwoFactorAuthenticator struct {
	baseAuth Authenticator         // underlying credential authenticator
	totp     *TOTPGenerator        // generates/validates time-based codes
	provider TwoFactorAuthProvider // persistence for secrets and backup codes
}
|
||||
|
||||
// NewTwoFactorAuthenticator creates a new 2FA-enabled authenticator
|
||||
func NewTwoFactorAuthenticator(baseAuth Authenticator, provider TwoFactorAuthProvider, config *TwoFactorConfig) *TwoFactorAuthenticator {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &TwoFactorAuthenticator{
|
||||
baseAuth: baseAuth,
|
||||
totp: NewTOTPGenerator(config),
|
||||
provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// Login authenticates with 2FA support.
//
// Flow: the wrapped authenticator verifies credentials first. If that user
// has 2FA enabled, either the caller is challenged (Requires2FA set, tokens
// stripped from the response) or the supplied code is checked — TOTP first,
// then one-time backup codes. Tokens are only returned once the second
// factor has been verified.
func (t *TwoFactorAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
	// First, perform standard authentication
	resp, err := t.baseAuth.Login(ctx, req)
	if err != nil {
		return nil, err
	}

	// Check if user has 2FA enabled. A response without a user carries no
	// identity to look 2FA status up with, so it is passed through as-is.
	if resp.User == nil {
		return resp, nil
	}

	has2FA, err := t.provider.Get2FAStatus(resp.User.UserID)
	if err != nil {
		return nil, fmt.Errorf("failed to check 2FA status: %w", err)
	}

	if !has2FA {
		// User doesn't have 2FA enabled, return normal response
		return resp, nil
	}

	// User has 2FA enabled
	if req.TwoFactorCode == "" {
		// No 2FA code provided, require it. The tokens are cleared so the
		// half-authenticated response cannot be used as a session.
		resp.Requires2FA = true
		resp.Token = "" // Don't return token until 2FA is verified
		resp.RefreshToken = ""
		return resp, nil
	}

	// Validate 2FA code
	secret, err := t.provider.Get2FASecret(resp.User.UserID)
	if err != nil {
		return nil, fmt.Errorf("failed to get 2FA secret: %w", err)
	}

	// Try TOTP code first
	valid, err := t.totp.ValidateCode(secret, req.TwoFactorCode)
	if err != nil {
		return nil, fmt.Errorf("failed to validate 2FA code: %w", err)
	}

	if !valid {
		// Try backup code. Providers may signal "already used" via an error,
		// which is surfaced to the caller here.
		valid, err = t.provider.ValidateBackupCode(resp.User.UserID, req.TwoFactorCode)
		if err != nil {
			return nil, fmt.Errorf("failed to validate backup code: %w", err)
		}
	}

	if !valid {
		return nil, fmt.Errorf("invalid 2FA code")
	}

	// 2FA verified, return full response with token
	resp.User.TwoFactorEnabled = true
	return resp, nil
}
|
||||
|
||||
// Logout delegates to the base authenticator unchanged; 2FA adds no state
// that needs tearing down at logout.
func (t *TwoFactorAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
	return t.baseAuth.Logout(ctx, req)
}
|
||||
|
||||
// Authenticate delegates to the base authenticator unchanged: per-request
// token validation is independent of the login-time 2FA challenge.
func (t *TwoFactorAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
	return t.baseAuth.Authenticate(r)
}
|
||||
|
||||
// Setup2FA initiates 2FA setup for a user: it asks the provider for a new
// secret (with QR-code URL and backup codes). 2FA is not active until the
// user confirms via Enable2FA.
func (t *TwoFactorAuthenticator) Setup2FA(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
	return t.provider.Generate2FASecret(userID, issuer, accountName)
}
|
||||
|
||||
// Enable2FA completes 2FA setup after the user confirms possession of the
// secret with a valid TOTP code, then activates 2FA in the provider.
//
// NOTE(review): the backup codes generated here are persisted by the
// provider but never returned to the caller, so the user has no way to see
// them from this path; they must call RegenerateBackupCodes afterwards to
// obtain a usable set — confirm this is intended.
// NOTE(review): GenerateBackupCodes may already store the codes in the
// provider before Enable2FA stores them again (both in-memory and database
// providers persist on generation) — verify the double write is harmless.
func (t *TwoFactorAuthenticator) Enable2FA(userID int, secret, verificationCode string) error {
	// Verify the code before enabling
	valid, err := t.totp.ValidateCode(secret, verificationCode)
	if err != nil {
		return fmt.Errorf("failed to validate code: %w", err)
	}

	if !valid {
		return fmt.Errorf("invalid verification code")
	}

	// Generate backup codes
	backupCodes, err := t.provider.GenerateBackupCodes(userID, 10)
	if err != nil {
		return fmt.Errorf("failed to generate backup codes: %w", err)
	}

	// Enable 2FA
	return t.provider.Enable2FA(userID, secret, backupCodes)
}
|
||||
|
||||
// Disable2FA removes 2FA from a user account by delegating to the provider,
// which discards the stored secret and backup codes.
func (t *TwoFactorAuthenticator) Disable2FA(userID int) error {
	return t.provider.Disable2FA(userID)
}
|
||||
|
||||
// RegenerateBackupCodes creates new backup codes for a user via the
// provider. The returned plaintext codes replace any previous set and are
// the only copy the user will ever see.
func (t *TwoFactorAuthenticator) RegenerateBackupCodes(userID int, count int) ([]string, error) {
	return t.provider.GenerateBackupCodes(userID, count)
}
|
||||
229
pkg/security/totp_provider_database.go
Normal file
229
pkg/security/totp_provider_database.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures.
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code.
// See totp_database_schema.sql for procedure definitions.
// Backup codes are stored as SHA-256 hex digests; plaintext codes are only
// ever returned to the caller at generation time.
type DatabaseTwoFactorProvider struct {
	db      *sql.DB        // connection pool for the stored-procedure calls
	totpGen *TOTPGenerator // local TOTP secret generation/validation
}
|
||||
|
||||
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &DatabaseTwoFactorProvider{
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
}
|
||||
}
|
||||
|
||||
// Generate2FASecret creates a new secret for a user
|
||||
func (p *DatabaseTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
|
||||
secret, err := p.totpGen.GenerateSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate secret: %w", err)
|
||||
}
|
||||
|
||||
qrURL := p.totpGen.GenerateQRCodeURL(secret, issuer, accountName)
|
||||
|
||||
backupCodes, err := GenerateBackupCodes(10)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate backup codes: %w", err)
|
||||
}
|
||||
|
||||
return &TwoFactorSecret{
|
||||
Secret: secret,
|
||||
QRCodeURL: qrURL,
|
||||
BackupCodes: backupCodes,
|
||||
Issuer: issuer,
|
||||
AccountName: accountName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate2FACode verifies a TOTP code against the given secret using the
// provider's configured TOTP generator (no database access involved).
func (p *DatabaseTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
	return p.totpGen.ValidateCode(secret, code)
}
|
||||
|
||||
// Enable2FA activates 2FA for a user
|
||||
func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
|
||||
// Hash backup codes for secure storage
|
||||
hashedCodes := make([]string, len(backupCodes))
|
||||
for i, code := range backupCodes {
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
hashedCodes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Convert to JSON array
|
||||
codesJSON, err := json.Marshal(hashedCodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup codes: %w", err)
|
||||
}
|
||||
|
||||
// Call stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
|
||||
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable 2FA query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return fmt.Errorf("failed to enable 2FA")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disable2FA deactivates 2FA for a user
|
||||
func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("disable 2FA query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return fmt.Errorf("failed to disable 2FA")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get2FAStatus checks if user has 2FA enabled
|
||||
func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var enabled bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return false, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return false, fmt.Errorf("failed to get 2FA status")
|
||||
}
|
||||
|
||||
return enabled, nil
|
||||
}
|
||||
|
||||
// Get2FASecret retrieves the user's 2FA secret
|
||||
func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var secret sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return "", fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return "", fmt.Errorf("failed to get 2FA secret")
|
||||
}
|
||||
|
||||
if !secret.Valid {
|
||||
return "", fmt.Errorf("2FA secret not found")
|
||||
}
|
||||
|
||||
return secret.String, nil
|
||||
}
|
||||
|
||||
// GenerateBackupCodes creates backup codes for 2FA
|
||||
func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
|
||||
codes, err := GenerateBackupCodes(count)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate backup codes: %w", err)
|
||||
}
|
||||
|
||||
// Hash backup codes for storage
|
||||
hashedCodes := make([]string, len(codes))
|
||||
for i, code := range codes {
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
hashedCodes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Convert to JSON array
|
||||
codesJSON, err := json.Marshal(hashedCodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal backup codes: %w", err)
|
||||
}
|
||||
|
||||
// Call stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
||||
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to regenerate backup codes")
|
||||
}
|
||||
|
||||
// Return unhashed codes to user (only time they see them)
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// ValidateBackupCode checks and consumes a backup code
|
||||
func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
|
||||
// Hash the code
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
codeHash := hex.EncodeToString(hash[:])
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var valid bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
|
||||
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return false, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return valid, nil
|
||||
}
|
||||
218
pkg/security/totp_provider_database_test.go
Normal file
218
pkg/security/totp_provider_database_test.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package security_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// Note: These tests require a PostgreSQL database with the schema from totp_database_schema.sql
|
||||
// Set TEST_DATABASE_URL environment variable or skip tests
|
||||
|
||||
// setupTestDB would open the PostgreSQL test database for this suite.
// It currently skips unconditionally: no TEST_DATABASE_URL wiring (nor a
// SQL driver import) exists yet. t.Skip halts the calling test immediately,
// so the nil return below is never reached at runtime; the `db == nil`
// guards in the tests are defensive only.
func setupTestDB(t *testing.T) *sql.DB {
	// Skip if no test database configured
	t.Skip("Database tests require TEST_DATABASE_URL environment variable")
	return nil
}
|
||||
|
||||
func TestDatabaseTwoFactorProvider_Enable2FA(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
|
||||
// Generate secret and backup codes
|
||||
secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Generate2FASecret() error = %v", err)
|
||||
}
|
||||
|
||||
// Enable 2FA
|
||||
err = provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
|
||||
if err != nil {
|
||||
t.Errorf("Enable2FA() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify enabled
|
||||
enabled, err := provider.Get2FAStatus(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Get2FAStatus() error = %v", err)
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
t.Error("Get2FAStatus() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseTwoFactorProvider_Disable2FA(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
|
||||
// Enable first
|
||||
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
|
||||
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
|
||||
|
||||
// Disable
|
||||
err := provider.Disable2FA(1)
|
||||
if err != nil {
|
||||
t.Errorf("Disable2FA() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify disabled
|
||||
enabled, err := provider.Get2FAStatus(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Get2FAStatus() error = %v", err)
|
||||
}
|
||||
|
||||
if enabled {
|
||||
t.Error("Get2FAStatus() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseTwoFactorProvider_GetSecret(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
|
||||
// Enable 2FA
|
||||
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
|
||||
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
|
||||
|
||||
// Retrieve secret
|
||||
retrieved, err := provider.Get2FASecret(1)
|
||||
if err != nil {
|
||||
t.Errorf("Get2FASecret() error = %v", err)
|
||||
}
|
||||
|
||||
if retrieved != secret.Secret {
|
||||
t.Errorf("Get2FASecret() = %v, want %v", retrieved, secret.Secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseTwoFactorProvider_ValidateBackupCode(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
|
||||
// Enable 2FA
|
||||
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
|
||||
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
|
||||
|
||||
// Validate backup code
|
||||
valid, err := provider.ValidateBackupCode(1, secret.BackupCodes[0])
|
||||
if err != nil {
|
||||
t.Errorf("ValidateBackupCode() error = %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("ValidateBackupCode() = false, want true")
|
||||
}
|
||||
|
||||
// Try to use same code again
|
||||
valid, err = provider.ValidateBackupCode(1, secret.BackupCodes[0])
|
||||
if err == nil {
|
||||
t.Error("ValidateBackupCode() should error on reuse")
|
||||
}
|
||||
|
||||
// Try invalid code
|
||||
valid, err = provider.ValidateBackupCode(1, "INVALID")
|
||||
if err != nil {
|
||||
t.Errorf("ValidateBackupCode() error = %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("ValidateBackupCode() = true for invalid code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseTwoFactorProvider_RegenerateBackupCodes(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
|
||||
// Enable 2FA
|
||||
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
|
||||
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
|
||||
|
||||
// Regenerate codes
|
||||
newCodes, err := provider.GenerateBackupCodes(1, 10)
|
||||
if err != nil {
|
||||
t.Errorf("GenerateBackupCodes() error = %v", err)
|
||||
}
|
||||
|
||||
if len(newCodes) != 10 {
|
||||
t.Errorf("GenerateBackupCodes() returned %d codes, want 10", len(newCodes))
|
||||
}
|
||||
|
||||
// Old codes should not work
|
||||
valid, _ := provider.ValidateBackupCode(1, secret.BackupCodes[0])
|
||||
if valid {
|
||||
t.Error("Old backup code should not work after regeneration")
|
||||
}
|
||||
|
||||
// New codes should work
|
||||
valid, err = provider.ValidateBackupCode(1, newCodes[0])
|
||||
if err != nil {
|
||||
t.Errorf("ValidateBackupCode() error = %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("ValidateBackupCode() = false for new code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseTwoFactorProvider_Generate2FASecret(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
|
||||
secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Generate2FASecret() error = %v", err)
|
||||
}
|
||||
|
||||
if secret.Secret == "" {
|
||||
t.Error("Generate2FASecret() returned empty secret")
|
||||
}
|
||||
|
||||
if secret.QRCodeURL == "" {
|
||||
t.Error("Generate2FASecret() returned empty QR code URL")
|
||||
}
|
||||
|
||||
if len(secret.BackupCodes) != 10 {
|
||||
t.Errorf("Generate2FASecret() returned %d backup codes, want 10", len(secret.BackupCodes))
|
||||
}
|
||||
|
||||
if secret.Issuer != "TestApp" {
|
||||
t.Errorf("Generate2FASecret() Issuer = %v, want TestApp", secret.Issuer)
|
||||
}
|
||||
|
||||
if secret.AccountName != "test@example.com" {
|
||||
t.Errorf("Generate2FASecret() AccountName = %v, want test@example.com", secret.AccountName)
|
||||
}
|
||||
}
|
||||
156
pkg/security/totp_provider_memory.go
Normal file
156
pkg/security/totp_provider_memory.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryTwoFactorProvider is an in-memory implementation of
// TwoFactorAuthProvider for testing/examples. All state is process-local and
// guarded by mu; nothing survives a restart.
type MemoryTwoFactorProvider struct {
	mu          sync.RWMutex            // guards secrets and backupCodes
	secrets     map[int]string          // userID -> secret
	backupCodes map[int]map[string]bool // userID -> backup codes (SHA-256 hex of code -> used)
	totpGen     *TOTPGenerator          // TOTP generation/validation
}
|
||||
|
||||
// NewMemoryTwoFactorProvider creates a new in-memory 2FA provider
|
||||
func NewMemoryTwoFactorProvider(config *TwoFactorConfig) *MemoryTwoFactorProvider {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &MemoryTwoFactorProvider{
|
||||
secrets: make(map[int]string),
|
||||
backupCodes: make(map[int]map[string]bool),
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
}
|
||||
}
|
||||
|
||||
// Generate2FASecret creates a new secret for a user
|
||||
func (m *MemoryTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
|
||||
secret, err := m.totpGen.GenerateSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qrURL := m.totpGen.GenerateQRCodeURL(secret, issuer, accountName)
|
||||
|
||||
backupCodes, err := GenerateBackupCodes(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TwoFactorSecret{
|
||||
Secret: secret,
|
||||
QRCodeURL: qrURL,
|
||||
BackupCodes: backupCodes,
|
||||
Issuer: issuer,
|
||||
AccountName: accountName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate2FACode verifies a TOTP code against the given secret using the
// provider's configured TOTP generator (no stored state consulted).
func (m *MemoryTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
	return m.totpGen.ValidateCode(secret, code)
}
|
||||
|
||||
// Enable2FA activates 2FA for a user
|
||||
func (m *MemoryTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.secrets[userID] = secret
|
||||
|
||||
// Store backup codes
|
||||
if m.backupCodes[userID] == nil {
|
||||
m.backupCodes[userID] = make(map[string]bool)
|
||||
}
|
||||
|
||||
for _, code := range backupCodes {
|
||||
// Hash backup codes for security
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disable2FA deactivates 2FA for a user
|
||||
func (m *MemoryTwoFactorProvider) Disable2FA(userID int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
delete(m.secrets, userID)
|
||||
delete(m.backupCodes, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get2FAStatus checks if user has 2FA enabled
|
||||
func (m *MemoryTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
_, exists := m.secrets[userID]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// Get2FASecret retrieves the user's 2FA secret
|
||||
func (m *MemoryTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
secret, exists := m.secrets[userID]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("user does not have 2FA enabled")
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// GenerateBackupCodes creates backup codes for 2FA
|
||||
func (m *MemoryTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
|
||||
codes, err := GenerateBackupCodes(count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Clear old backup codes and store new ones
|
||||
m.backupCodes[userID] = make(map[string]bool)
|
||||
for _, code := range codes {
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
|
||||
}
|
||||
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// ValidateBackupCode checks and consumes a backup code
|
||||
func (m *MemoryTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
userCodes, exists := m.backupCodes[userID]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Hash the provided code
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
|
||||
used, exists := userCodes[hashStr]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if used {
|
||||
return false, fmt.Errorf("backup code already used")
|
||||
}
|
||||
|
||||
// Mark as used
|
||||
userCodes[hashStr] = true
|
||||
return true, nil
|
||||
}
|
||||
292
pkg/security/totp_test.go
Normal file
292
pkg/security/totp_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTOTPGenerator_GenerateSecret(t *testing.T) {
|
||||
totp := NewTOTPGenerator(nil)
|
||||
|
||||
secret, err := totp.GenerateSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSecret() error = %v", err)
|
||||
}
|
||||
|
||||
if secret == "" {
|
||||
t.Error("GenerateSecret() returned empty secret")
|
||||
}
|
||||
|
||||
// Secret should be base32 encoded
|
||||
if len(secret) < 16 {
|
||||
t.Error("GenerateSecret() returned secret that is too short")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPGenerator_GenerateQRCodeURL(t *testing.T) {
|
||||
totp := NewTOTPGenerator(nil)
|
||||
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
issuer := "TestApp"
|
||||
accountName := "user@example.com"
|
||||
|
||||
url := totp.GenerateQRCodeURL(secret, issuer, accountName)
|
||||
|
||||
if !strings.HasPrefix(url, "otpauth://totp/") {
|
||||
t.Errorf("GenerateQRCodeURL() = %v, want otpauth://totp/ prefix", url)
|
||||
}
|
||||
|
||||
if !strings.Contains(url, "secret="+secret) {
|
||||
t.Errorf("GenerateQRCodeURL() missing secret parameter")
|
||||
}
|
||||
|
||||
if !strings.Contains(url, "issuer="+issuer) {
|
||||
t.Errorf("GenerateQRCodeURL() missing issuer parameter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPGenerator_GenerateCode(t *testing.T) {
|
||||
config := &TwoFactorConfig{
|
||||
Algorithm: "SHA1",
|
||||
Digits: 6,
|
||||
Period: 30,
|
||||
SkewWindow: 1,
|
||||
}
|
||||
totp := NewTOTPGenerator(config)
|
||||
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
|
||||
// Test with known time
|
||||
timestamp := time.Unix(1234567890, 0)
|
||||
code, err := totp.GenerateCode(secret, timestamp)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode() error = %v", err)
|
||||
}
|
||||
|
||||
if len(code) != 6 {
|
||||
t.Errorf("GenerateCode() returned code with length %d, want 6", len(code))
|
||||
}
|
||||
|
||||
// Code should be numeric
|
||||
for _, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("GenerateCode() returned non-numeric code: %s", code)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPGenerator_ValidateCode(t *testing.T) {
|
||||
config := &TwoFactorConfig{
|
||||
Algorithm: "SHA1",
|
||||
Digits: 6,
|
||||
Period: 30,
|
||||
SkewWindow: 1,
|
||||
}
|
||||
totp := NewTOTPGenerator(config)
|
||||
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
|
||||
// Generate a code for current time
|
||||
now := time.Now()
|
||||
code, err := totp.GenerateCode(secret, now)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode() error = %v", err)
|
||||
}
|
||||
|
||||
// Validate the code
|
||||
valid, err := totp.ValidateCode(secret, code)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCode() error = %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("ValidateCode() = false, want true for current code")
|
||||
}
|
||||
|
||||
// Test with invalid code
|
||||
valid, err = totp.ValidateCode(secret, "000000")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCode() error = %v", err)
|
||||
}
|
||||
|
||||
// This might occasionally pass if 000000 is the correct code, but very unlikely
|
||||
if valid && code != "000000" {
|
||||
t.Error("ValidateCode() = true for invalid code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPGenerator_ValidateCode_WithSkew(t *testing.T) {
|
||||
config := &TwoFactorConfig{
|
||||
Algorithm: "SHA1",
|
||||
Digits: 6,
|
||||
Period: 30,
|
||||
SkewWindow: 2, // Allow 2 periods before/after
|
||||
}
|
||||
totp := NewTOTPGenerator(config)
|
||||
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
|
||||
// Generate code for 1 period ago
|
||||
past := time.Now().Add(-30 * time.Second)
|
||||
code, err := totp.GenerateCode(secret, past)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode() error = %v", err)
|
||||
}
|
||||
|
||||
// Should still validate with skew window
|
||||
valid, err := totp.ValidateCode(secret, code)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCode() error = %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("ValidateCode() = false, want true for code within skew window")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPGenerator_DifferentAlgorithms(t *testing.T) {
|
||||
algorithms := []string{"SHA1", "SHA256", "SHA512"}
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
|
||||
for _, algo := range algorithms {
|
||||
t.Run(algo, func(t *testing.T) {
|
||||
config := &TwoFactorConfig{
|
||||
Algorithm: algo,
|
||||
Digits: 6,
|
||||
Period: 30,
|
||||
SkewWindow: 1,
|
||||
}
|
||||
totp := NewTOTPGenerator(config)
|
||||
|
||||
code, err := totp.GenerateCode(secret, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode() with %s error = %v", algo, err)
|
||||
}
|
||||
|
||||
valid, err := totp.ValidateCode(secret, code)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCode() with %s error = %v", algo, err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Errorf("ValidateCode() with %s = false, want true", algo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPGenerator_8Digits(t *testing.T) {
|
||||
config := &TwoFactorConfig{
|
||||
Algorithm: "SHA1",
|
||||
Digits: 8,
|
||||
Period: 30,
|
||||
SkewWindow: 1,
|
||||
}
|
||||
totp := NewTOTPGenerator(config)
|
||||
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
|
||||
code, err := totp.GenerateCode(secret, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCode() error = %v", err)
|
||||
}
|
||||
|
||||
if len(code) != 8 {
|
||||
t.Errorf("GenerateCode() returned code with length %d, want 8", len(code))
|
||||
}
|
||||
|
||||
valid, err := totp.ValidateCode(secret, code)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateCode() error = %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("ValidateCode() = false, want true for 8-digit code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateBackupCodes(t *testing.T) {
|
||||
count := 10
|
||||
codes, err := GenerateBackupCodes(count)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateBackupCodes() error = %v", err)
|
||||
}
|
||||
|
||||
if len(codes) != count {
|
||||
t.Errorf("GenerateBackupCodes() returned %d codes, want %d", len(codes), count)
|
||||
}
|
||||
|
||||
// Check uniqueness
|
||||
seen := make(map[string]bool)
|
||||
for _, code := range codes {
|
||||
if seen[code] {
|
||||
t.Errorf("GenerateBackupCodes() generated duplicate code: %s", code)
|
||||
}
|
||||
seen[code] = true
|
||||
|
||||
// Check format (8 hex characters)
|
||||
if len(code) != 8 {
|
||||
t.Errorf("GenerateBackupCodes() code length = %d, want 8", len(code))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTwoFactorConfig(t *testing.T) {
|
||||
config := DefaultTwoFactorConfig()
|
||||
|
||||
if config.Algorithm != "SHA1" {
|
||||
t.Errorf("DefaultTwoFactorConfig() Algorithm = %s, want SHA1", config.Algorithm)
|
||||
}
|
||||
|
||||
if config.Digits != 6 {
|
||||
t.Errorf("DefaultTwoFactorConfig() Digits = %d, want 6", config.Digits)
|
||||
}
|
||||
|
||||
if config.Period != 30 {
|
||||
t.Errorf("DefaultTwoFactorConfig() Period = %d, want 30", config.Period)
|
||||
}
|
||||
|
||||
if config.SkewWindow != 1 {
|
||||
t.Errorf("DefaultTwoFactorConfig() SkewWindow = %d, want 1", config.SkewWindow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPGenerator_InvalidSecret(t *testing.T) {
|
||||
totp := NewTOTPGenerator(nil)
|
||||
|
||||
// Test with invalid base32 secret
|
||||
_, err := totp.GenerateCode("INVALID!!!", time.Now())
|
||||
if err == nil {
|
||||
t.Error("GenerateCode() with invalid secret should return error")
|
||||
}
|
||||
|
||||
_, err = totp.ValidateCode("INVALID!!!", "123456")
|
||||
if err == nil {
|
||||
t.Error("ValidateCode() with invalid secret should return error")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkTOTPGenerator_GenerateCode(b *testing.B) {
|
||||
totp := NewTOTPGenerator(nil)
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
now := time.Now()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = totp.GenerateCode(secret, now)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTOTPGenerator_ValidateCode(b *testing.B) {
|
||||
totp := NewTOTPGenerator(nil)
|
||||
secret := "JBSWY3DPEHPK3PXP"
|
||||
code, _ := totp.GenerateCode(secret, time.Now())
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = totp.ValidateCode(secret, code)
|
||||
}
|
||||
}
|
||||
@@ -411,7 +411,9 @@ func newInstance(cfg Config) (*serverInstance, error) {
|
||||
return nil, fmt.Errorf("handler cannot be nil")
|
||||
}
|
||||
|
||||
// Set default timeouts
|
||||
// Set default timeouts with minimum of 10 minutes for connection timeouts
|
||||
minConnectionTimeout := 10 * time.Minute
|
||||
|
||||
if cfg.ShutdownTimeout == 0 {
|
||||
cfg.ShutdownTimeout = 30 * time.Second
|
||||
}
|
||||
@@ -419,13 +421,22 @@ func newInstance(cfg Config) (*serverInstance, error) {
|
||||
cfg.DrainTimeout = 25 * time.Second
|
||||
}
|
||||
if cfg.ReadTimeout == 0 {
|
||||
cfg.ReadTimeout = 15 * time.Second
|
||||
cfg.ReadTimeout = minConnectionTimeout
|
||||
} else if cfg.ReadTimeout < minConnectionTimeout {
|
||||
// Enforce minimum of 10 minutes
|
||||
cfg.ReadTimeout = minConnectionTimeout
|
||||
}
|
||||
if cfg.WriteTimeout == 0 {
|
||||
cfg.WriteTimeout = 15 * time.Second
|
||||
cfg.WriteTimeout = minConnectionTimeout
|
||||
} else if cfg.WriteTimeout < minConnectionTimeout {
|
||||
// Enforce minimum of 10 minutes
|
||||
cfg.WriteTimeout = minConnectionTimeout
|
||||
}
|
||||
if cfg.IdleTimeout == 0 {
|
||||
cfg.IdleTimeout = 60 * time.Second
|
||||
cfg.IdleTimeout = minConnectionTimeout
|
||||
} else if cfg.IdleTimeout < minConnectionTimeout {
|
||||
// Enforce minimum of 10 minutes
|
||||
cfg.IdleTimeout = minConnectionTimeout
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||
|
||||
@@ -4,6 +4,7 @@ package spectypes
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
@@ -60,7 +61,33 @@ func (n *SqlNull[T]) Scan(value any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try standard sql.Null[T] first.
|
||||
// Check if T is []byte, and decode base64 if applicable
|
||||
// Do this BEFORE trying sql.Null to ensure base64 is handled
|
||||
var zero T
|
||||
if _, ok := any(zero).([]byte); ok {
|
||||
// For []byte types, try to decode from base64
|
||||
var strVal string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
strVal = v
|
||||
case []byte:
|
||||
strVal = string(v)
|
||||
default:
|
||||
strVal = fmt.Sprintf("%v", value)
|
||||
}
|
||||
// Try base64 decode
|
||||
if decoded, err := base64.StdEncoding.DecodeString(strVal); err == nil {
|
||||
n.Val = any(decoded).(T)
|
||||
n.Valid = true
|
||||
return nil
|
||||
}
|
||||
// Fallback to raw bytes
|
||||
n.Val = any([]byte(strVal)).(T)
|
||||
n.Valid = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try standard sql.Null[T] for other types.
|
||||
var sqlNull sql.Null[T]
|
||||
if err := sqlNull.Scan(value); err == nil {
|
||||
n.Val = sqlNull.V
|
||||
@@ -74,6 +101,10 @@ func (n *SqlNull[T]) Scan(value any) error {
|
||||
return n.FromString(v)
|
||||
case []byte:
|
||||
return n.FromString(string(v))
|
||||
case float32, float64:
|
||||
return n.FromString(fmt.Sprintf("%f", value))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return n.FromString(fmt.Sprintf("%d", value))
|
||||
default:
|
||||
return n.FromString(fmt.Sprintf("%v", value))
|
||||
}
|
||||
@@ -94,6 +125,10 @@ func (n *SqlNull[T]) FromString(s string) error {
|
||||
reflect.ValueOf(&n.Val).Elem().SetInt(i)
|
||||
n.Valid = true
|
||||
}
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
reflect.ValueOf(&n.Val).Elem().SetInt(int64(f))
|
||||
n.Valid = true
|
||||
}
|
||||
case float32, float64:
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
reflect.ValueOf(&n.Val).Elem().SetFloat(f)
|
||||
@@ -114,6 +149,9 @@ func (n *SqlNull[T]) FromString(s string) error {
|
||||
n.Val = any(u).(T)
|
||||
n.Valid = true
|
||||
}
|
||||
case []byte:
|
||||
n.Val = any([]byte(s)).(T)
|
||||
n.Valid = true
|
||||
case string:
|
||||
n.Val = any(s).(T)
|
||||
n.Valid = true
|
||||
@@ -141,6 +179,14 @@ func (n SqlNull[T]) MarshalJSON() ([]byte, error) {
|
||||
if !n.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// Check if T is []byte, and encode to base64
|
||||
if _, ok := any(n.Val).([]byte); ok {
|
||||
// Encode []byte as base64
|
||||
encoded := base64.StdEncoding.EncodeToString(any(n.Val).([]byte))
|
||||
return json.Marshal(encoded)
|
||||
}
|
||||
|
||||
return json.Marshal(n.Val)
|
||||
}
|
||||
|
||||
@@ -152,8 +198,25 @@ func (n *SqlNull[T]) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try direct unmarshal.
|
||||
// Check if T is []byte, and decode from base64
|
||||
var val T
|
||||
if _, ok := any(val).([]byte); ok {
|
||||
// Unmarshal as string first (JSON representation)
|
||||
var s string
|
||||
if err := json.Unmarshal(b, &s); err == nil {
|
||||
// Decode from base64
|
||||
if decoded, err := base64.StdEncoding.DecodeString(s); err == nil {
|
||||
n.Val = any(decoded).(T)
|
||||
n.Valid = true
|
||||
return nil
|
||||
}
|
||||
// Fallback to raw string as bytes
|
||||
n.Val = any([]byte(s)).(T)
|
||||
n.Valid = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(b, &val); err == nil {
|
||||
n.Val = val
|
||||
n.Valid = true
|
||||
@@ -263,13 +326,14 @@ func (n SqlNull[T]) UUID() uuid.UUID {
|
||||
|
||||
// Type aliases for common types.
|
||||
type (
|
||||
SqlInt16 = SqlNull[int16]
|
||||
SqlInt32 = SqlNull[int32]
|
||||
SqlInt64 = SqlNull[int64]
|
||||
SqlFloat64 = SqlNull[float64]
|
||||
SqlBool = SqlNull[bool]
|
||||
SqlString = SqlNull[string]
|
||||
SqlUUID = SqlNull[uuid.UUID]
|
||||
SqlInt16 = SqlNull[int16]
|
||||
SqlInt32 = SqlNull[int32]
|
||||
SqlInt64 = SqlNull[int64]
|
||||
SqlFloat64 = SqlNull[float64]
|
||||
SqlBool = SqlNull[bool]
|
||||
SqlString = SqlNull[string]
|
||||
SqlByteArray = SqlNull[[]byte]
|
||||
SqlUUID = SqlNull[uuid.UUID]
|
||||
)
|
||||
|
||||
// SqlTimeStamp - Timestamp with custom formatting (YYYY-MM-DDTHH:MM:SS).
|
||||
@@ -573,6 +637,10 @@ func NewSqlString(v string) SqlString {
|
||||
return SqlString{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlByteArray(v []byte) SqlByteArray {
|
||||
return SqlByteArray{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
func NewSqlUUID(v uuid.UUID) SqlUUID {
|
||||
return SqlUUID{Val: v, Valid: true}
|
||||
}
|
||||
|
||||
@@ -565,3 +565,394 @@ func TestTryIfInt64(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlString tests SqlString without base64 (plain text)
|
||||
func TestSqlString_Scan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "plain string",
|
||||
input: "hello world",
|
||||
expected: "hello world",
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "plain text",
|
||||
input: "plain text",
|
||||
expected: "plain text",
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "bytes as string",
|
||||
input: []byte("raw bytes"),
|
||||
expected: "raw bytes",
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "nil value",
|
||||
input: nil,
|
||||
expected: "",
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var s SqlString
|
||||
if err := s.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if s.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, s.Valid)
|
||||
}
|
||||
if tt.valid && s.String() != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, s.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlString_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputValue string
|
||||
expectedJSON string
|
||||
expectedDecode string
|
||||
}{
|
||||
{
|
||||
name: "simple string",
|
||||
inputValue: "hello world",
|
||||
expectedJSON: `"hello world"`, // plain text, not base64
|
||||
expectedDecode: "hello world",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
inputValue: "test@#$%",
|
||||
expectedJSON: `"test@#$%"`, // plain text, not base64
|
||||
expectedDecode: "test@#$%",
|
||||
},
|
||||
{
|
||||
name: "unicode string",
|
||||
inputValue: "Hello 世界",
|
||||
expectedJSON: `"Hello 世界"`, // plain text, not base64
|
||||
expectedDecode: "Hello 世界",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
inputValue: "",
|
||||
expectedJSON: `""`,
|
||||
expectedDecode: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test MarshalJSON
|
||||
s := NewSqlString(tt.inputValue)
|
||||
data, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != tt.expectedJSON {
|
||||
t.Errorf("Marshal: expected %s, got %s", tt.expectedJSON, string(data))
|
||||
}
|
||||
|
||||
// Test UnmarshalJSON
|
||||
var s2 SqlString
|
||||
if err := json.Unmarshal(data, &s2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if !s2.Valid {
|
||||
t.Error("expected valid=true after unmarshal")
|
||||
}
|
||||
if s2.String() != tt.expectedDecode {
|
||||
t.Errorf("Unmarshal: expected %q, got %q", tt.expectedDecode, s2.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlString_JSON_Null(t *testing.T) {
|
||||
// Test null handling
|
||||
var s SqlString
|
||||
if err := json.Unmarshal([]byte("null"), &s); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
if s.Valid {
|
||||
t.Error("expected invalid after unmarshaling null")
|
||||
}
|
||||
|
||||
// Test marshal null
|
||||
data, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "null" {
|
||||
t.Errorf("expected null, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlByteArray_Base64 tests SqlByteArray with base64 encoding/decoding
|
||||
func TestSqlByteArray_Base64_Scan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected []byte
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "base64 encoded bytes from SQL",
|
||||
input: "aGVsbG8gd29ybGQ=", // "hello world" in base64
|
||||
expected: []byte("hello world"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "plain bytes fallback",
|
||||
input: "plain text",
|
||||
expected: []byte("plain text"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "bytes base64 encoded",
|
||||
input: []byte("SGVsbG8gR29waGVy"), // "Hello Gopher" in base64
|
||||
expected: []byte("Hello Gopher"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "bytes plain fallback",
|
||||
input: []byte("raw bytes"),
|
||||
expected: []byte("raw bytes"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "binary data",
|
||||
input: "AQIDBA==", // []byte{1, 2, 3, 4} in base64
|
||||
expected: []byte{1, 2, 3, 4},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "nil value",
|
||||
input: nil,
|
||||
expected: nil,
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var b SqlByteArray
|
||||
if err := b.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if b.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, b.Valid)
|
||||
}
|
||||
if tt.valid {
|
||||
if string(b.Val) != string(tt.expected) {
|
||||
t.Errorf("expected %q, got %q", tt.expected, b.Val)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlByteArray_Base64_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputValue []byte
|
||||
expectedJSON string
|
||||
expectedDecode []byte
|
||||
}{
|
||||
{
|
||||
name: "text bytes",
|
||||
inputValue: []byte("hello world"),
|
||||
expectedJSON: `"aGVsbG8gd29ybGQ="`, // base64 encoded
|
||||
expectedDecode: []byte("hello world"),
|
||||
},
|
||||
{
|
||||
name: "binary data",
|
||||
inputValue: []byte{0x01, 0x02, 0x03, 0x04, 0xFF},
|
||||
expectedJSON: `"AQIDBP8="`, // base64 encoded
|
||||
expectedDecode: []byte{0x01, 0x02, 0x03, 0x04, 0xFF},
|
||||
},
|
||||
{
|
||||
name: "empty bytes",
|
||||
inputValue: []byte{},
|
||||
expectedJSON: `""`, // base64 of empty bytes
|
||||
expectedDecode: []byte{},
|
||||
},
|
||||
{
|
||||
name: "unicode bytes",
|
||||
inputValue: []byte("Hello 世界"),
|
||||
expectedJSON: `"SGVsbG8g5LiW55WM"`, // base64 encoded
|
||||
expectedDecode: []byte("Hello 世界"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test MarshalJSON
|
||||
b := NewSqlByteArray(tt.inputValue)
|
||||
data, err := json.Marshal(b)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != tt.expectedJSON {
|
||||
t.Errorf("Marshal: expected %s, got %s", tt.expectedJSON, string(data))
|
||||
}
|
||||
|
||||
// Test UnmarshalJSON
|
||||
var b2 SqlByteArray
|
||||
if err := json.Unmarshal(data, &b2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if !b2.Valid {
|
||||
t.Error("expected valid=true after unmarshal")
|
||||
}
|
||||
if string(b2.Val) != string(tt.expectedDecode) {
|
||||
t.Errorf("Unmarshal: expected %v, got %v", tt.expectedDecode, b2.Val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlByteArray_Base64_JSON_Null(t *testing.T) {
|
||||
// Test null handling
|
||||
var b SqlByteArray
|
||||
if err := json.Unmarshal([]byte("null"), &b); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
if b.Valid {
|
||||
t.Error("expected invalid after unmarshaling null")
|
||||
}
|
||||
|
||||
// Test marshal null
|
||||
data, err := json.Marshal(b)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "null" {
|
||||
t.Errorf("expected null, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlByteArray_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlByteArray
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "valid bytes",
|
||||
input: NewSqlByteArray([]byte("test data")),
|
||||
expected: []byte("test data"),
|
||||
},
|
||||
{
|
||||
name: "empty bytes",
|
||||
input: NewSqlByteArray([]byte{}),
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: SqlByteArray{Valid: false},
|
||||
expected: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if tt.expected == nil && val != nil {
|
||||
t.Errorf("expected nil, got %v", val)
|
||||
}
|
||||
if tt.expected != nil && val == nil {
|
||||
t.Errorf("expected %v, got nil", tt.expected)
|
||||
}
|
||||
if tt.expected != nil && val != nil {
|
||||
if string(val.([]byte)) != string(tt.expected.([]byte)) {
|
||||
t.Errorf("expected %v, got %v", tt.expected, val)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlString_RoundTrip tests complete round-trip: Go -> JSON -> Go -> SQL -> Go
|
||||
func TestSqlString_RoundTrip(t *testing.T) {
|
||||
original := "Test String with Special Chars: @#$%^&*()"
|
||||
|
||||
// Go -> JSON
|
||||
s1 := NewSqlString(original)
|
||||
jsonData, err := json.Marshal(s1)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// JSON -> Go
|
||||
var s2 SqlString
|
||||
if err := json.Unmarshal(jsonData, &s2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Go -> SQL (Value)
|
||||
_, err = s2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
|
||||
// SQL -> Go (Scan plain text)
|
||||
var s3 SqlString
|
||||
// Simulate SQL driver returning plain text value
|
||||
if err := s3.Scan(original); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify round-trip
|
||||
if s3.String() != original {
|
||||
t.Errorf("Round-trip failed: expected %q, got %q", original, s3.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlByteArray_Base64_RoundTrip tests complete round-trip: Go -> JSON -> Go -> SQL -> Go
|
||||
func TestSqlByteArray_Base64_RoundTrip(t *testing.T) {
|
||||
original := []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0xFF, 0xFE} // "Hello " + binary data
|
||||
|
||||
// Go -> JSON
|
||||
b1 := NewSqlByteArray(original)
|
||||
jsonData, err := json.Marshal(b1)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// JSON -> Go
|
||||
var b2 SqlByteArray
|
||||
if err := json.Unmarshal(jsonData, &b2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Go -> SQL (Value)
|
||||
_, err = b2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
|
||||
// SQL -> Go (Scan with base64)
|
||||
var b3 SqlByteArray
|
||||
// Simulate SQL driver returning base64 encoded value
|
||||
if err := b3.Scan("SGVsbG8g//4="); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify round-trip
|
||||
if string(b3.Val) != string(original) {
|
||||
t.Errorf("Round-trip failed: expected %v, got %v", original, b3.Val)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -209,10 +210,14 @@ func (h *Handler) handleRead(conn *Connection, msg *Message, hookCtx *HookContex
|
||||
var metadata map[string]interface{}
|
||||
var err error
|
||||
|
||||
if hookCtx.ID != "" {
|
||||
// Read single record by ID
|
||||
// Check if FetchRowNumber is specified (treat as single record read)
|
||||
isFetchRowNumber := hookCtx.Options != nil && hookCtx.Options.FetchRowNumber != nil && *hookCtx.Options.FetchRowNumber != ""
|
||||
|
||||
if hookCtx.ID != "" || isFetchRowNumber {
|
||||
// Read single record by ID or FetchRowNumber
|
||||
data, err = h.readByID(hookCtx)
|
||||
metadata = map[string]interface{}{"total": 1}
|
||||
// The row number is already set on the record itself via setRowNumbersOnRecords
|
||||
} else {
|
||||
// Read multiple records
|
||||
data, metadata, err = h.readMultiple(hookCtx)
|
||||
@@ -509,10 +514,29 @@ func (h *Handler) notifySubscribers(schema, entity string, operation OperationTy
|
||||
// CRUD operation implementations
|
||||
|
||||
func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
|
||||
// Handle FetchRowNumber before building query
|
||||
var fetchedRowNumber *int64
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
|
||||
if hookCtx.Options != nil && hookCtx.Options.FetchRowNumber != nil && *hookCtx.Options.FetchRowNumber != "" {
|
||||
fetchRowNumberPKValue := *hookCtx.Options.FetchRowNumber
|
||||
logger.Debug("[WebSocketSpec] FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
|
||||
|
||||
rowNum, err := h.FetchRowNumber(hookCtx.Context, hookCtx.TableName, pkName, fetchRowNumberPKValue, hookCtx.Options, hookCtx.Model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch row number: %w", err)
|
||||
}
|
||||
|
||||
fetchedRowNumber = &rowNum
|
||||
logger.Debug("[WebSocketSpec] FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
|
||||
|
||||
// Override ID with FetchRowNumber value
|
||||
hookCtx.ID = fetchRowNumberPKValue
|
||||
}
|
||||
|
||||
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
// Apply columns
|
||||
@@ -532,6 +556,12 @@ func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to read record: %w", err)
|
||||
}
|
||||
|
||||
// Set the fetched row number on the record if FetchRowNumber was used
|
||||
if fetchedRowNumber != nil {
|
||||
logger.Debug("[WebSocketSpec] FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
|
||||
h.setRowNumbersOnRecords(hookCtx.ModelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
@@ -540,10 +570,8 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
|
||||
// Apply options (simplified implementation)
|
||||
if hookCtx.Options != nil {
|
||||
// Apply filters
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
// Apply filters with OR grouping support
|
||||
query = h.applyFilters(query, hookCtx.Options.Filters)
|
||||
|
||||
// Apply sorting
|
||||
for _, sort := range hookCtx.Options.Sort {
|
||||
@@ -578,6 +606,13 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
return nil, nil, fmt.Errorf("failed to read records: %w", err)
|
||||
}
|
||||
|
||||
// Set row numbers on records if RowNumber field exists
|
||||
offset := 0
|
||||
if hookCtx.Options != nil && hookCtx.Options.Offset != nil {
|
||||
offset = *hookCtx.Options.Offset
|
||||
}
|
||||
h.setRowNumbersOnRecords(hookCtx.ModelPtr, offset)
|
||||
|
||||
// Get count
|
||||
metadata = make(map[string]interface{})
|
||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
@@ -656,11 +691,14 @@ func (h *Handler) delete(hookCtx *HookContext) error {
|
||||
// Helper methods
|
||||
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
// Use entity as table name
|
||||
tableName := entity
|
||||
|
||||
if schema != "" {
|
||||
tableName = schema + "." + tableName
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
tableName = schema + "_" + tableName
|
||||
} else {
|
||||
tableName = schema + "." + tableName
|
||||
}
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
@@ -680,6 +718,133 @@ func (h *Handler) getMetadata(schema, entity string, model interface{}) map[stri
|
||||
}
|
||||
|
||||
// getOperatorSQL converts filter operator to SQL operator
|
||||
// applyFilters applies all filters with proper grouping for OR logic
|
||||
// Groups consecutive OR filters together to ensure proper query precedence
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
// Check if this starts an OR group (next filter has OR logic)
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
// Collect all consecutive filters that are OR'd together
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
|
||||
// Apply the OR group as a single grouped WHERE clause
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
// Single filter with AND logic (or first filter)
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// applyFilterGroup applies a group of filters that should be OR'd together
|
||||
// Always wraps them in parentheses and applies as a single WHERE clause
|
||||
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Build all conditions and collect args
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, filter := range filters {
|
||||
condition, filterArgs := h.buildFilterCondition(filter)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Single filter - no need for grouping
|
||||
if len(conditions) == 1 {
|
||||
return query.Where(conditions[0], args...)
|
||||
}
|
||||
|
||||
// Multiple conditions - group with parentheses and OR
|
||||
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
||||
return query.Where(groupedCondition, args...)
|
||||
}
|
||||
|
||||
// buildFilterCondition builds a filter condition and returns it with args
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
||||
var condition string
|
||||
var args []interface{}
|
||||
|
||||
operatorSQL := h.getOperatorSQL(filter.Operator)
|
||||
condition = fmt.Sprintf("%s %s ?", filter.Column, operatorSQL)
|
||||
args = []interface{}{filter.Value}
|
||||
|
||||
return condition, args
|
||||
}
|
||||
|
||||
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||
// The row number is calculated as offset + index + 1 (1-based)
|
||||
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||
// Get the reflect value of the records
|
||||
recordsValue := reflect.ValueOf(records)
|
||||
if recordsValue.Kind() == reflect.Ptr {
|
||||
recordsValue = recordsValue.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a slice
|
||||
if recordsValue.Kind() != reflect.Slice {
|
||||
logger.Debug("[WebSocketSpec] setRowNumbersOnRecords: records is not a slice, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
// Iterate through each record
|
||||
for i := 0; i < recordsValue.Len(); i++ {
|
||||
record := recordsValue.Index(i)
|
||||
|
||||
// Dereference if it's a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.IsNil() {
|
||||
continue
|
||||
}
|
||||
record = record.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a struct
|
||||
if record.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to find and set the RowNumber field
|
||||
rowNumberField := record.FieldByName("RowNumber")
|
||||
if rowNumberField.IsValid() && rowNumberField.CanSet() {
|
||||
// Check if the field is of type int64
|
||||
if rowNumberField.Kind() == reflect.Int64 {
|
||||
rowNum := int64(offset + i + 1)
|
||||
rowNumberField.SetInt(rowNum)
|
||||
logger.Debug("[WebSocketSpec] Set RowNumber=%d for record index %d", rowNum, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) getOperatorSQL(operator string) string {
|
||||
switch operator {
|
||||
case "eq":
|
||||
@@ -705,6 +870,92 @@ func (h *Handler) getOperatorSQL(operator string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// FetchRowNumber calculates the row number of a specific record based on sorting and filtering
|
||||
// Returns the 1-based row number of the record with the given primary key value
|
||||
func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName string, pkValue string, options *common.RequestOptions, model interface{}) (int64, error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("[WebSocketSpec] Panic during FetchRowNumber: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Build the sort order SQL
|
||||
sortSQL := ""
|
||||
if options != nil && len(options.Sort) > 0 {
|
||||
sortParts := make([]string, 0, len(options.Sort))
|
||||
for _, sort := range options.Sort {
|
||||
if sort.Column == "" {
|
||||
continue
|
||||
}
|
||||
direction := "ASC"
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
sortSQL = strings.Join(sortParts, ", ")
|
||||
} else {
|
||||
// Default sort by primary key
|
||||
sortSQL = fmt.Sprintf("%s ASC", pkName)
|
||||
}
|
||||
|
||||
// Build WHERE clause from filters
|
||||
whereSQL := ""
|
||||
var whereArgs []interface{}
|
||||
if options != nil && len(options.Filters) > 0 {
|
||||
var conditions []string
|
||||
for _, filter := range options.Filters {
|
||||
operatorSQL := h.getOperatorSQL(filter.Operator)
|
||||
conditions = append(conditions, fmt.Sprintf("%s.%s %s ?", tableName, filter.Column, operatorSQL))
|
||||
whereArgs = append(whereArgs, filter.Value)
|
||||
}
|
||||
if len(conditions) > 0 {
|
||||
whereSQL = "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
}
|
||||
|
||||
// Build the final query with parameterized PK value
|
||||
queryStr := fmt.Sprintf(`
|
||||
SELECT search.rn
|
||||
FROM (
|
||||
SELECT %[1]s.%[2]s,
|
||||
ROW_NUMBER() OVER(ORDER BY %[3]s) AS rn
|
||||
FROM %[1]s
|
||||
%[4]s
|
||||
) search
|
||||
WHERE search.%[2]s = ?
|
||||
`,
|
||||
tableName, // [1] - table name
|
||||
pkName, // [2] - primary key column name
|
||||
sortSQL, // [3] - sort order SQL
|
||||
whereSQL, // [4] - WHERE clause
|
||||
)
|
||||
|
||||
logger.Debug("[WebSocketSpec] FetchRowNumber query: %s, pkValue: %s", queryStr, pkValue)
|
||||
|
||||
// Append PK value to whereArgs
|
||||
whereArgs = append(whereArgs, pkValue)
|
||||
|
||||
// Execute the raw query with parameterized PK value
|
||||
var result []struct {
|
||||
RN int64 `bun:"rn"`
|
||||
}
|
||||
err := h.db.Query(ctx, &result, queryStr, whereArgs...)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to fetch row number: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
whereInfo := "none"
|
||||
if whereSQL != "" {
|
||||
whereInfo = whereSQL
|
||||
}
|
||||
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
|
||||
}
|
||||
|
||||
return result[0].RN, nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the handler
|
||||
func (h *Handler) Shutdown() {
|
||||
h.connManager.Shutdown()
|
||||
|
||||
@@ -82,6 +82,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
|
||||
return args.Get(0)
|
||||
}
|
||||
|
||||
func (m *MockDatabase) DriverName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
// MockSelectQuery is a mock implementation of common.SelectQuery
|
||||
type MockSelectQuery struct {
|
||||
mock.Mock
|
||||
|
||||
8
resolvespec-js/.changeset/README.md
Normal file
8
resolvespec-js/.changeset/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Changesets
|
||||
|
||||
Hello and welcome! This folder has been automatically generated by `@changesets/cli`, a build tool that works
|
||||
with multi-package repos, or single-package repos to help you version and publish your code. You can
|
||||
find the full documentation for it [in our repository](https://github.com/changesets/changesets)
|
||||
|
||||
We have a quick list of common questions to get you started engaging with this project in
|
||||
[our documentation](https://github.com/changesets/changesets/blob/main/docs/common-questions.md)
|
||||
11
resolvespec-js/.changeset/config.json
Normal file
11
resolvespec-js/.changeset/config.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"$schema": "https://unpkg.com/@changesets/config@3.1.2/schema.json",
|
||||
"changelog": "@changesets/cli/changelog",
|
||||
"commit": false,
|
||||
"fixed": [],
|
||||
"linked": [],
|
||||
"access": "restricted",
|
||||
"baseBranch": "main",
|
||||
"updateInternalDependencies": "patch",
|
||||
"ignore": []
|
||||
}
|
||||
7
resolvespec-js/CHANGELOG.md
Normal file
7
resolvespec-js/CHANGELOG.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# @warkypublic/resolvespec-js
|
||||
|
||||
## 1.0.1
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- Fixed headerpsec
|
||||
132
resolvespec-js/PLAN.md
Normal file
132
resolvespec-js/PLAN.md
Normal file
@@ -0,0 +1,132 @@
|
||||
# ResolveSpec JS - Implementation Plan
|
||||
|
||||
TypeScript client library for ResolveSpec, RestHeaderSpec, WebSocket and MQTT APIs.
|
||||
|
||||
---
|
||||
|
||||
## Status
|
||||
|
||||
| Phase | Description | Status |
|
||||
|-------|-------------|--------|
|
||||
| 0 | Restructure into folders | Done |
|
||||
| 1 | Fix types (align with Go) | Done |
|
||||
| 2 | Fix REST client | Done |
|
||||
| 3 | Build config | Done |
|
||||
| 4 | Tests | Done |
|
||||
| 5 | HeaderSpec client | Done |
|
||||
| 6 | MQTT client | Planned |
|
||||
| 6.5 | Unified class pattern + singleton factories | Done |
|
||||
| 7 | Response cache (TTL) | Planned |
|
||||
| 8 | TanStack Query integration | Planned |
|
||||
| 9 | React Hooks | Planned |
|
||||
|
||||
**Build:** `dist/index.js` (ES) + `dist/index.cjs` (CJS) + `.d.ts` declarations
|
||||
**Tests:** 65 passing (common: 10, resolvespec: 13, websocketspec: 15, headerspec: 27)
|
||||
|
||||
---
|
||||
|
||||
## Folder Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── common/
|
||||
│ ├── types.ts # Core types aligned with Go pkg/common/types.go
|
||||
│ └── index.ts
|
||||
├── resolvespec/
|
||||
│ ├── client.ts # ResolveSpecClient class + createResolveSpecClient singleton
|
||||
│ └── index.ts
|
||||
├── headerspec/
|
||||
│ ├── client.ts # HeaderSpecClient class + createHeaderSpecClient singleton + buildHeaders utility
|
||||
│ └── index.ts
|
||||
├── websocketspec/
|
||||
│ ├── types.ts # WS-specific types (WSMessage, WSOptions, etc.)
|
||||
│ ├── client.ts # WebSocketClient class + createWebSocketClient singleton
|
||||
│ └── index.ts
|
||||
├── mqttspec/ # Future
|
||||
│ ├── types.ts
|
||||
│ ├── client.ts
|
||||
│ └── index.ts
|
||||
├── __tests__/
|
||||
│ ├── common.test.ts
|
||||
│ ├── resolvespec.test.ts
|
||||
│ ├── headerspec.test.ts
|
||||
│ └── websocketspec.test.ts
|
||||
└── index.ts # Root barrel export
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Type Alignment with Go
|
||||
|
||||
Types in `src/common/types.ts` match `pkg/common/types.go`:
|
||||
|
||||
- **Operator**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, `contains`, `startswith`, `endswith`, `between`, `between_inclusive`, `is_null`, `is_not_null`
|
||||
- **FilterOption**: `column`, `operator`, `value`, `logic_operator` (AND/OR)
|
||||
- **Options**: `columns`, `omit_columns`, `filters`, `sort`, `limit`, `offset`, `preload`, `customOperators`, `computedColumns`, `parameters`, `cursor_forward`, `cursor_backward`, `fetch_row_number`
|
||||
- **PreloadOption**: `relation`, `table_name`, `columns`, `omit_columns`, `sort`, `filters`, `where`, `limit`, `offset`, `updatable`, `recursive`, `computed_ql`, `primary_key`, `related_key`, `foreign_key`, `recursive_child_key`, `sql_joins`, `join_aliases`
|
||||
- **Parameter**: `name`, `value`, `sequence?`
|
||||
- **Metadata**: `total`, `count`, `filtered`, `limit`, `offset`, `row_number?`
|
||||
- **APIError**: `code`, `message`, `details?`, `detail?`
|
||||
|
||||
---
|
||||
|
||||
## HeaderSpec Header Mapping
|
||||
|
||||
Maps Options to HTTP headers per Go `restheadspec/headers.go`:
|
||||
|
||||
| Header | Options field | Format |
|
||||
|--------|--------------|--------|
|
||||
| `X-Select-Fields` | `columns` | comma-separated |
|
||||
| `X-Not-Select-Fields` | `omit_columns` | comma-separated |
|
||||
| `X-FieldFilter-{col}` | `filters` (eq, AND) | value |
|
||||
| `X-SearchOp-{op}-{col}` | `filters` (AND) | value |
|
||||
| `X-SearchOr-{op}-{col}` | `filters` (OR) | value |
|
||||
| `X-Sort` | `sort` | `+col` (asc), `-col` (desc) |
|
||||
| `X-Limit` | `limit` | number |
|
||||
| `X-Offset` | `offset` | number |
|
||||
| `X-Cursor-Forward` | `cursor_forward` | string |
|
||||
| `X-Cursor-Backward` | `cursor_backward` | string |
|
||||
| `X-Preload` | `preload` | `Rel:col1,col2` pipe-separated |
|
||||
| `X-Fetch-RowNumber` | `fetch_row_number` | string |
|
||||
| `X-CQL-SEL-{col}` | `computedColumns` | expression |
|
||||
| `X-Custom-SQL-W` | `customOperators` | SQL AND-joined |
|
||||
|
||||
Complex values use `ZIP_` + base64 encoding.
|
||||
HTTP methods: GET=read, POST=create, PUT=update, DELETE=delete.
|
||||
|
||||
---
|
||||
|
||||
## Build & Test
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
pnpm run build # vite library mode → dist/
|
||||
pnpm run test # vitest
|
||||
pnpm run lint # eslint
|
||||
```
|
||||
|
||||
**Config files:** `tsconfig.json` (ES2020, strict, bundler), `vite.config.ts` (lib mode, dts via vite-plugin-dts)
|
||||
**Externals:** `uuid`, `semver`
|
||||
|
||||
---
|
||||
|
||||
## Remaining Work
|
||||
|
||||
- **Phase 6 — MQTT Client**: Topic-based CRUD over MQTT (optional/future)
|
||||
- **Phase 7 — Cache**: In-memory response cache with TTL, key = URL + options hash, auto-invalidation on CUD, `skipCache` flag
|
||||
- **Phase 8 — TanStack Query Integration**: Query/mutation hooks wrapping each client, query key factories, automatic cache invalidation
|
||||
- **Phase 9 — React Hooks**: `useResolveSpec`, `useHeaderSpec`, `useWebSocket` hooks with provider context, loading/error states
|
||||
- ESLint config may need updating for new folder structure
|
||||
|
||||
---
|
||||
|
||||
## Reference Files
|
||||
|
||||
| Purpose | Path |
|
||||
|---------|------|
|
||||
| Go types (source of truth) | `pkg/common/types.go` |
|
||||
| Go REST handler | `pkg/resolvespec/handler.go` |
|
||||
| Go HeaderSpec handler | `pkg/restheadspec/handler.go` |
|
||||
| Go HeaderSpec header parsing | `pkg/restheadspec/headers.go` |
|
||||
| Go test models | `pkg/testmodels/business.go` |
|
||||
| Go tests | `tests/crud_test.go` |
|
||||
213
resolvespec-js/README.md
Normal file
213
resolvespec-js/README.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# ResolveSpec JS
|
||||
|
||||
TypeScript client library for ResolveSpec APIs. Supports body-based REST, header-based REST, and WebSocket protocols.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
pnpm add @warkypublic/resolvespec-js
|
||||
```
|
||||
|
||||
## Clients
|
||||
|
||||
| Client | Protocol | Singleton Factory |
|
||||
| --- | --- | --- |
|
||||
| `ResolveSpecClient` | REST (body-based) | `getResolveSpecClient(config)` |
|
||||
| `HeaderSpecClient` | REST (header-based) | `getHeaderSpecClient(config)` |
|
||||
| `WebSocketClient` | WebSocket | `getWebSocketClient(config)` |
|
||||
|
||||
All clients use the class pattern. Singleton factories return cached instances keyed by URL.
|
||||
|
||||
## REST Client (Body-Based)
|
||||
|
||||
Options sent in JSON request body. Maps to Go `pkg/resolvespec`.
|
||||
|
||||
```typescript
|
||||
import { ResolveSpecClient, getResolveSpecClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
// Class instantiation
|
||||
const client = new ResolveSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
|
||||
// Or singleton factory (returns cached instance per baseUrl)
|
||||
const client = getResolveSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
|
||||
// Read with filters, sort, pagination
|
||||
const result = await client.read('public', 'users', undefined, {
|
||||
columns: ['id', 'name', 'email'],
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }],
|
||||
sort: [{ column: 'name', direction: 'asc' }],
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
preload: [{ relation: 'Posts', columns: ['id', 'title'] }],
|
||||
});
|
||||
|
||||
// Read by ID
|
||||
const user = await client.read('public', 'users', 42);
|
||||
|
||||
// Create
|
||||
const created = await client.create('public', 'users', { name: 'New User' });
|
||||
|
||||
// Update
|
||||
await client.update('public', 'users', { name: 'Updated' }, 42);
|
||||
|
||||
// Delete
|
||||
await client.delete('public', 'users', 42);
|
||||
|
||||
// Metadata
|
||||
const meta = await client.getMetadata('public', 'users');
|
||||
```
|
||||
|
||||
## HeaderSpec Client (Header-Based)
|
||||
|
||||
Options sent via HTTP headers. Maps to Go `pkg/restheadspec`.
|
||||
|
||||
```typescript
|
||||
import { HeaderSpecClient, getHeaderSpecClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
const client = new HeaderSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
// Or: const client = getHeaderSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
|
||||
// GET with options as headers
|
||||
const result = await client.read('public', 'users', undefined, {
|
||||
columns: ['id', 'name'],
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' },
|
||||
{ column: 'age', operator: 'gte', value: 18, logic_operator: 'AND' },
|
||||
],
|
||||
sort: [{ column: 'name', direction: 'asc' }],
|
||||
limit: 50,
|
||||
preload: [{ relation: 'Department', columns: ['id', 'name'] }],
|
||||
});
|
||||
|
||||
// POST create
|
||||
await client.create('public', 'users', { name: 'New User' });
|
||||
|
||||
// PUT update
|
||||
await client.update('public', 'users', '42', { name: 'Updated' });
|
||||
|
||||
// DELETE
|
||||
await client.delete('public', 'users', '42');
|
||||
```
|
||||
|
||||
### Header Mapping
|
||||
|
||||
| Header | Options Field | Format |
|
||||
| --- | --- | --- |
|
||||
| `X-Select-Fields` | `columns` | comma-separated |
|
||||
| `X-Not-Select-Fields` | `omit_columns` | comma-separated |
|
||||
| `X-FieldFilter-{col}` | `filters` (eq, AND) | value |
|
||||
| `X-SearchOp-{op}-{col}` | `filters` (AND) | value |
|
||||
| `X-SearchOr-{op}-{col}` | `filters` (OR) | value |
|
||||
| `X-Sort` | `sort` | `+col` asc, `-col` desc |
|
||||
| `X-Limit` / `X-Offset` | `limit` / `offset` | number |
|
||||
| `X-Cursor-Forward` | `cursor_forward` | string |
|
||||
| `X-Cursor-Backward` | `cursor_backward` | string |
|
||||
| `X-Preload` | `preload` | `Rel:col1,col2` pipe-separated |
|
||||
| `X-Fetch-RowNumber` | `fetch_row_number` | string |
|
||||
| `X-CQL-SEL-{col}` | `computedColumns` | expression |
|
||||
| `X-Custom-SQL-W` | `customOperators` | SQL AND-joined |
|
||||
|
||||
### Utility Functions
|
||||
|
||||
```typescript
|
||||
import { buildHeaders, encodeHeaderValue, decodeHeaderValue } from '@warkypublic/resolvespec-js';
|
||||
|
||||
const headers = buildHeaders({ columns: ['id', 'name'], limit: 10 });
|
||||
// => { 'X-Select-Fields': 'id,name', 'X-Limit': '10' }
|
||||
|
||||
const encoded = encodeHeaderValue('complex value'); // 'ZIP_...'
|
||||
const decoded = decodeHeaderValue(encoded); // 'complex value'
|
||||
```
|
||||
|
||||
## WebSocket Client
|
||||
|
||||
Real-time CRUD with subscriptions. Maps to Go `pkg/websocketspec`.
|
||||
|
||||
```typescript
|
||||
import { WebSocketClient, getWebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
const ws = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
heartbeatInterval: 30000,
|
||||
});
|
||||
// Or: const ws = getWebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
|
||||
await ws.connect();
|
||||
|
||||
// CRUD
|
||||
const users = await ws.read('users', { schema: 'public', limit: 10 });
|
||||
const created = await ws.create('users', { name: 'New' }, { schema: 'public' });
|
||||
await ws.update('users', '1', { name: 'Updated' });
|
||||
await ws.delete('users', '1');
|
||||
|
||||
// Subscribe to changes
|
||||
const subId = await ws.subscribe('users', (notification) => {
|
||||
console.log(notification.operation, notification.data);
|
||||
});
|
||||
|
||||
// Unsubscribe
|
||||
await ws.unsubscribe(subId);
|
||||
|
||||
// Events
|
||||
ws.on('connect', () => console.log('connected'));
|
||||
ws.on('disconnect', () => console.log('disconnected'));
|
||||
ws.on('error', (err) => console.error(err));
|
||||
|
||||
ws.disconnect();
|
||||
```
|
||||
|
||||
## Types
|
||||
|
||||
All types align with Go `pkg/common/types.go`.
|
||||
|
||||
### Key Types
|
||||
|
||||
```typescript
|
||||
interface Options {
|
||||
columns?: string[];
|
||||
omit_columns?: string[];
|
||||
filters?: FilterOption[];
|
||||
sort?: SortOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
preload?: PreloadOption[];
|
||||
customOperators?: CustomOperator[];
|
||||
computedColumns?: ComputedColumn[];
|
||||
parameters?: Parameter[];
|
||||
cursor_forward?: string;
|
||||
cursor_backward?: string;
|
||||
fetch_row_number?: string;
|
||||
}
|
||||
|
||||
interface FilterOption {
|
||||
column: string;
|
||||
operator: Operator | string;
|
||||
value: any;
|
||||
logic_operator?: 'AND' | 'OR';
|
||||
}
|
||||
|
||||
// Operators: eq, neq, gt, gte, lt, lte, like, ilike, in,
|
||||
// contains, startswith, endswith, between,
|
||||
// between_inclusive, is_null, is_not_null
|
||||
|
||||
interface APIResponse<T> {
|
||||
success: boolean;
|
||||
data: T;
|
||||
metadata?: Metadata;
|
||||
error?: APIError;
|
||||
}
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
pnpm run build # dist/index.js (ES) + dist/index.cjs (CJS) + .d.ts
|
||||
pnpm run test # vitest
|
||||
pnpm run lint # eslint
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
@@ -1,530 +0,0 @@
|
||||
# WebSocketSpec JavaScript Client
|
||||
|
||||
A TypeScript/JavaScript client for connecting to WebSocketSpec servers with full support for real-time subscriptions, CRUD operations, and automatic reconnection.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install @warkypublic/resolvespec-js
|
||||
# or
|
||||
yarn add @warkypublic/resolvespec-js
|
||||
# or
|
||||
pnpm add @warkypublic/resolvespec-js
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import { WebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
// Create client
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
debug: true
|
||||
});
|
||||
|
||||
// Connect
|
||||
await client.connect();
|
||||
|
||||
// Read records
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
limit: 10
|
||||
});
|
||||
|
||||
// Subscribe to changes
|
||||
const subscriptionId = await client.subscribe('users', (notification) => {
|
||||
console.log('User changed:', notification.operation, notification.data);
|
||||
}, { schema: 'public' });
|
||||
|
||||
// Clean up
|
||||
await client.unsubscribe(subscriptionId);
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **Real-Time Updates**: Subscribe to entity changes and receive instant notifications
|
||||
- **Full CRUD Support**: Create, read, update, and delete operations
|
||||
- **TypeScript Support**: Full type definitions included
|
||||
- **Auto Reconnection**: Automatic reconnection with configurable retry logic
|
||||
- **Heartbeat**: Built-in keepalive mechanism
|
||||
- **Event System**: Listen to connection, error, and message events
|
||||
- **Promise-based API**: All async operations return promises
|
||||
- **Filter & Sort**: Advanced querying with filters, sorting, and pagination
|
||||
- **Preloading**: Load related entities in a single query
|
||||
|
||||
## Configuration
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws', // WebSocket server URL
|
||||
reconnect: true, // Enable auto-reconnection
|
||||
reconnectInterval: 3000, // Reconnection delay (ms)
|
||||
maxReconnectAttempts: 10, // Max reconnection attempts
|
||||
heartbeatInterval: 30000, // Heartbeat interval (ms)
|
||||
debug: false // Enable debug logging
|
||||
});
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Connection Management
|
||||
|
||||
#### `connect(): Promise<void>`
|
||||
Connect to the WebSocket server.
|
||||
|
||||
```typescript
|
||||
await client.connect();
|
||||
```
|
||||
|
||||
#### `disconnect(): void`
|
||||
Disconnect from the server.
|
||||
|
||||
```typescript
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
#### `isConnected(): boolean`
|
||||
Check if currently connected.
|
||||
|
||||
```typescript
|
||||
if (client.isConnected()) {
|
||||
console.log('Connected!');
|
||||
}
|
||||
```
|
||||
|
||||
#### `getState(): ConnectionState`
|
||||
Get current connection state: `'connecting'`, `'connected'`, `'disconnecting'`, `'disconnected'`, or `'reconnecting'`.
|
||||
|
||||
```typescript
|
||||
const state = client.getState();
|
||||
console.log('State:', state);
|
||||
```
|
||||
|
||||
### CRUD Operations
|
||||
|
||||
#### `read<T>(entity: string, options?): Promise<T>`
|
||||
Read records from an entity.
|
||||
|
||||
```typescript
|
||||
// Read all active users
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
columns: ['id', 'name', 'email'],
|
||||
sort: [
|
||||
{ column: 'name', direction: 'asc' }
|
||||
],
|
||||
limit: 10,
|
||||
offset: 0
|
||||
});
|
||||
|
||||
// Read single record by ID
|
||||
const user = await client.read('users', {
|
||||
schema: 'public',
|
||||
record_id: '123'
|
||||
});
|
||||
|
||||
// Read with preloading
|
||||
const posts = await client.read('posts', {
|
||||
schema: 'public',
|
||||
preload: [
|
||||
{
|
||||
relation: 'user',
|
||||
columns: ['id', 'name', 'email']
|
||||
},
|
||||
{
|
||||
relation: 'comments',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'approved' }
|
||||
]
|
||||
}
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
#### `create<T>(entity: string, data: any, options?): Promise<T>`
|
||||
Create a new record.
|
||||
|
||||
```typescript
|
||||
const newUser = await client.create('users', {
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
status: 'active'
|
||||
}, {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `update<T>(entity: string, id: string, data: any, options?): Promise<T>`
|
||||
Update an existing record.
|
||||
|
||||
```typescript
|
||||
const updatedUser = await client.update('users', '123', {
|
||||
name: 'John Updated',
|
||||
email: 'john.new@example.com'
|
||||
}, {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `delete(entity: string, id: string, options?): Promise<void>`
|
||||
Delete a record.
|
||||
|
||||
```typescript
|
||||
await client.delete('users', '123', {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `meta<T>(entity: string, options?): Promise<T>`
|
||||
Get metadata for an entity.
|
||||
|
||||
```typescript
|
||||
const metadata = await client.meta('users', {
|
||||
schema: 'public'
|
||||
});
|
||||
console.log('Columns:', metadata.columns);
|
||||
console.log('Primary key:', metadata.primary_key);
|
||||
```
|
||||
|
||||
### Subscriptions
|
||||
|
||||
#### `subscribe(entity: string, callback: Function, options?): Promise<string>`
|
||||
Subscribe to entity changes.
|
||||
|
||||
```typescript
|
||||
const subscriptionId = await client.subscribe(
|
||||
'users',
|
||||
(notification) => {
|
||||
console.log('Operation:', notification.operation); // 'create', 'update', or 'delete'
|
||||
console.log('Data:', notification.data);
|
||||
console.log('Timestamp:', notification.timestamp);
|
||||
},
|
||||
{
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
]
|
||||
}
|
||||
);
|
||||
```
|
||||
|
||||
#### `unsubscribe(subscriptionId: string): Promise<void>`
|
||||
Unsubscribe from entity changes.
|
||||
|
||||
```typescript
|
||||
await client.unsubscribe(subscriptionId);
|
||||
```
|
||||
|
||||
#### `getSubscriptions(): Subscription[]`
|
||||
Get list of active subscriptions.
|
||||
|
||||
```typescript
|
||||
const subscriptions = client.getSubscriptions();
|
||||
console.log('Active subscriptions:', subscriptions.length);
|
||||
```
|
||||
|
||||
### Event Handling
|
||||
|
||||
#### `on(event: string, callback: Function): void`
|
||||
Add event listener.
|
||||
|
||||
```typescript
|
||||
// Connection events
|
||||
client.on('connect', () => {
|
||||
console.log('Connected!');
|
||||
});
|
||||
|
||||
client.on('disconnect', (event) => {
|
||||
console.log('Disconnected:', event.code, event.reason);
|
||||
});
|
||||
|
||||
client.on('error', (error) => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
|
||||
// State changes
|
||||
client.on('stateChange', (state) => {
|
||||
console.log('State:', state);
|
||||
});
|
||||
|
||||
// All messages
|
||||
client.on('message', (message) => {
|
||||
console.log('Message:', message);
|
||||
});
|
||||
```
|
||||
|
||||
#### `off(event: string): void`
|
||||
Remove event listener.
|
||||
|
||||
```typescript
|
||||
client.off('connect');
|
||||
```
|
||||
|
||||
## Filter Operators
|
||||
|
||||
- `eq` - Equal (=)
|
||||
- `neq` - Not Equal (!=)
|
||||
- `gt` - Greater Than (>)
|
||||
- `gte` - Greater Than or Equal (>=)
|
||||
- `lt` - Less Than (<)
|
||||
- `lte` - Less Than or Equal (<=)
|
||||
- `like` - LIKE (case-sensitive)
|
||||
- `ilike` - ILIKE (case-insensitive)
|
||||
- `in` - IN (array of values)
|
||||
|
||||
## Examples
|
||||
|
||||
### Basic CRUD
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Create
|
||||
const user = await client.create('users', {
|
||||
name: 'Alice',
|
||||
email: 'alice@example.com'
|
||||
});
|
||||
|
||||
// Read
|
||||
const users = await client.read('users', {
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
|
||||
});
|
||||
|
||||
// Update
|
||||
await client.update('users', user.id, { name: 'Alice Updated' });
|
||||
|
||||
// Delete
|
||||
await client.delete('users', user.id);
|
||||
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
### Real-Time Subscriptions
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Subscribe to all user changes
|
||||
const subId = await client.subscribe('users', (notification) => {
|
||||
switch (notification.operation) {
|
||||
case 'create':
|
||||
console.log('New user:', notification.data);
|
||||
break;
|
||||
case 'update':
|
||||
console.log('User updated:', notification.data);
|
||||
break;
|
||||
case 'delete':
|
||||
console.log('User deleted:', notification.data);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
// Later: unsubscribe
|
||||
await client.unsubscribe(subId);
|
||||
```
|
||||
|
||||
### React Integration
|
||||
|
||||
```typescript
|
||||
import { useEffect, useState } from 'react';
|
||||
import { WebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
function useWebSocket(url: string) {
|
||||
const [client] = useState(() => new WebSocketClient({ url }));
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
client.on('connect', () => setIsConnected(true));
|
||||
client.on('disconnect', () => setIsConnected(false));
|
||||
client.connect();
|
||||
|
||||
return () => client.disconnect();
|
||||
}, [client]);
|
||||
|
||||
return { client, isConnected };
|
||||
}
|
||||
|
||||
function UsersComponent() {
|
||||
const { client, isConnected } = useWebSocket('ws://localhost:8080/ws');
|
||||
const [users, setUsers] = useState([]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isConnected) return;
|
||||
|
||||
const loadUsers = async () => {
|
||||
// Subscribe to changes
|
||||
await client.subscribe('users', (notification) => {
|
||||
if (notification.operation === 'create') {
|
||||
setUsers(prev => [...prev, notification.data]);
|
||||
} else if (notification.operation === 'update') {
|
||||
setUsers(prev => prev.map(u =>
|
||||
u.id === notification.data.id ? notification.data : u
|
||||
));
|
||||
} else if (notification.operation === 'delete') {
|
||||
setUsers(prev => prev.filter(u => u.id !== notification.data.id));
|
||||
}
|
||||
});
|
||||
|
||||
// Load initial data
|
||||
const data = await client.read('users');
|
||||
setUsers(data);
|
||||
};
|
||||
|
||||
loadUsers();
|
||||
}, [client, isConnected]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h2>Users {isConnected ? '🟢' : '🔴'}</h2>
|
||||
{users.map(user => (
|
||||
<div key={user.id}>{user.name}</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### TypeScript with Typed Models

```typescript
interface User {
  id: number;
  name: string;
  email: string;
  status: 'active' | 'inactive';
}

interface Post {
  id: number;
  title: string;
  content: string;
  user_id: number;
  user?: User;
}

const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();

// Type-safe operations
const users = await client.read<User[]>('users', {
  filters: [{ column: 'status', operator: 'eq', value: 'active' }]
});

const newUser = await client.create<User>('users', {
  name: 'Bob',
  email: 'bob@example.com',
  status: 'active'
});

// Type-safe subscriptions
await client.subscribe(
  'posts',
  (notification) => {
    const post = notification.data as Post;
    console.log('Post:', post.title);
  }
);
```
### Error Handling

```typescript
const client = new WebSocketClient({
  url: 'ws://localhost:8080/ws',
  reconnect: true,
  maxReconnectAttempts: 5
});

client.on('error', (error) => {
  console.error('Connection error:', error);
});

client.on('stateChange', (state) => {
  console.log('State:', state);
  if (state === 'reconnecting') {
    console.log('Attempting to reconnect...');
  }
});

try {
  await client.connect();

  try {
    const user = await client.read('users', { record_id: '999' });
  } catch (error) {
    console.error('Record not found:', error);
  }

  try {
    await client.create('users', { /* invalid data */ });
  } catch (error) {
    console.error('Validation failed:', error);
  }

} catch (error) {
  console.error('Connection failed:', error);
}
```
### Multiple Subscriptions

```typescript
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();

// Subscribe to multiple entities
const userSub = await client.subscribe('users', (n) => {
  console.log('[Users]', n.operation, n.data);
});

const postSub = await client.subscribe('posts', (n) => {
  console.log('[Posts]', n.operation, n.data);
}, {
  filters: [{ column: 'status', operator: 'eq', value: 'published' }]
});

const commentSub = await client.subscribe('comments', (n) => {
  console.log('[Comments]', n.operation, n.data);
});

// Check active subscriptions
console.log('Active:', client.getSubscriptions().length);

// Clean up
await client.unsubscribe(userSub);
await client.unsubscribe(postSub);
await client.unsubscribe(commentSub);
```
## Best Practices

1. **Always Clean Up**: Call `disconnect()` when done to close the connection properly
2. **Use TypeScript**: Leverage type definitions for better type safety
3. **Handle Errors**: Always wrap operations in try-catch blocks
4. **Limit Subscriptions**: Don't create too many subscriptions per connection
5. **Use Filters**: Apply filters to subscriptions to reduce unnecessary notifications
6. **Connection State**: Check `isConnected()` before operations
7. **Event Listeners**: Remove event listeners when no longer needed with `off()`
8. **Reconnection**: Enable auto-reconnection for production apps
## Browser Support

- Chrome/Edge 88+
- Firefox 85+
- Safari 14+
- Node.js 14.16+

## License

MIT
1
resolvespec-js/dist/index.cjs
vendored
Normal file
1
resolvespec-js/dist/index.cjs
vendored
Normal file
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user