mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-02-16 13:26:12 +00:00
Compare commits (19 commits):

- e923b0a2a3
- ea4a4371ba
- b3694e50fe
- b76dae5991
- dc85008d7f
- fd77385dd6
- b322ef76a2
- a6c7edb0e4
- 71eeb8315e
- 4bf3d0224e
- 50d0caabc2
- 5269ae4de2
- 646620ed83
- 7600a6d1fb
- 2e7b3e7abd
- fdf9e118c5
- e11e6a8bf7
- 261f98eb29
- 0b8d11361c
.env.example (90 lines changed)
@@ -1,15 +1,22 @@
# ResolveSpec Environment Variables Example
# Environment variables override config file settings
# All variables are prefixed with RESOLVESPEC_
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
# Nested config uses underscores (e.g., servers.default_server -> RESOLVESPEC_SERVERS_DEFAULT_SERVER)

# Server Configuration
RESOLVESPEC_SERVER_ADDR=:8080
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
RESOLVESPEC_SERVERS_DEFAULT_SERVER=main
RESOLVESPEC_SERVERS_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVERS_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVERS_READ_TIMEOUT=10s
RESOLVESPEC_SERVERS_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVERS_IDLE_TIMEOUT=120s

# Server Instance Configuration (main)
RESOLVESPEC_SERVERS_INSTANCES_MAIN_NAME=main
RESOLVESPEC_SERVERS_INSTANCES_MAIN_HOST=0.0.0.0
RESOLVESPEC_SERVERS_INSTANCES_MAIN_PORT=8080
RESOLVESPEC_SERVERS_INSTANCES_MAIN_DESCRIPTION=Main API server
RESOLVESPEC_SERVERS_INSTANCES_MAIN_GZIP=true

# Tracing Configuration
RESOLVESPEC_TRACING_ENABLED=false
@@ -48,5 +55,70 @@ RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
RESOLVESPEC_CORS_MAX_AGE=3600

# Database Configuration
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable
# Error Tracking Configuration
RESOLVESPEC_ERROR_TRACKING_ENABLED=false
RESOLVESPEC_ERROR_TRACKING_PROVIDER=noop
RESOLVESPEC_ERROR_TRACKING_ENVIRONMENT=development
RESOLVESPEC_ERROR_TRACKING_DEBUG=false
RESOLVESPEC_ERROR_TRACKING_SAMPLE_RATE=1.0
RESOLVESPEC_ERROR_TRACKING_TRACES_SAMPLE_RATE=0.1

# Event Broker Configuration
RESOLVESPEC_EVENT_BROKER_ENABLED=false
RESOLVESPEC_EVENT_BROKER_PROVIDER=memory
RESOLVESPEC_EVENT_BROKER_MODE=sync
RESOLVESPEC_EVENT_BROKER_WORKER_COUNT=1
RESOLVESPEC_EVENT_BROKER_BUFFER_SIZE=100
RESOLVESPEC_EVENT_BROKER_INSTANCE_ID=

# Event Broker Redis Configuration
RESOLVESPEC_EVENT_BROKER_REDIS_STREAM_NAME=events
RESOLVESPEC_EVENT_BROKER_REDIS_CONSUMER_GROUP=app
RESOLVESPEC_EVENT_BROKER_REDIS_MAX_LEN=1000
RESOLVESPEC_EVENT_BROKER_REDIS_HOST=localhost
RESOLVESPEC_EVENT_BROKER_REDIS_PORT=6379
RESOLVESPEC_EVENT_BROKER_REDIS_PASSWORD=
RESOLVESPEC_EVENT_BROKER_REDIS_DB=0

# Event Broker NATS Configuration
RESOLVESPEC_EVENT_BROKER_NATS_URL=nats://localhost:4222
RESOLVESPEC_EVENT_BROKER_NATS_STREAM_NAME=events
RESOLVESPEC_EVENT_BROKER_NATS_STORAGE=file
RESOLVESPEC_EVENT_BROKER_NATS_MAX_AGE=24h

# Event Broker Database Configuration
RESOLVESPEC_EVENT_BROKER_DATABASE_TABLE_NAME=events
RESOLVESPEC_EVENT_BROKER_DATABASE_CHANNEL=events
RESOLVESPEC_EVENT_BROKER_DATABASE_POLL_INTERVAL=5s

# Event Broker Retry Policy Configuration
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_RETRIES=3
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_INITIAL_DELAY=1s
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_DELAY=1m
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_BACKOFF_FACTOR=2.0

# DB Manager Configuration
RESOLVESPEC_DBMANAGER_DEFAULT_CONNECTION=primary
RESOLVESPEC_DBMANAGER_MAX_OPEN_CONNS=25
RESOLVESPEC_DBMANAGER_MAX_IDLE_CONNS=5
RESOLVESPEC_DBMANAGER_CONN_MAX_LIFETIME=30m
RESOLVESPEC_DBMANAGER_CONN_MAX_IDLE_TIME=5m
RESOLVESPEC_DBMANAGER_RETRY_ATTEMPTS=3
RESOLVESPEC_DBMANAGER_RETRY_DELAY=1s
RESOLVESPEC_DBMANAGER_HEALTH_CHECK_INTERVAL=30s
RESOLVESPEC_DBMANAGER_ENABLE_AUTO_RECONNECT=true

# DB Manager Primary Connection Configuration
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_NAME=primary
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_TYPE=pgsql
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_URL=host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_DEFAULT_ORM=gorm
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_LOGGING=false
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_METRICS=false
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_CONNECT_TIMEOUT=10s
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_QUERY_TIMEOUT=30s

# Paths Configuration
RESOLVESPEC_PATHS_DATA_DIR=./data
RESOLVESPEC_PATHS_LOG_DIR=./logs
RESOLVESPEC_PATHS_CACHE_DIR=./cache
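The prefix-and-underscore convention stated at the top of the file can be sketched as a small Go helper. `envKey` is a hypothetical name for illustration, not a function from the ResolveSpec codebase:

```go
package main

import (
	"fmt"
	"strings"
)

// envKey maps a dotted config path to its environment-variable name
// per the convention above: prefix RESOLVESPEC_, dots become
// underscores, and the whole key is upper-cased.
func envKey(path string) string {
	return "RESOLVESPEC_" + strings.ToUpper(strings.ReplaceAll(path, ".", "_"))
}

func main() {
	fmt.Println(envKey("servers.default_server")) // RESOLVESPEC_SERVERS_DEFAULT_SERVER
	fmt.Println(envKey("server.addr"))            // RESOLVESPEC_SERVER_ADDR
}
```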
.gitignore (vendored, 4 lines changed)

@@ -26,4 +26,6 @@ go.work.sum
bin/
test.db
/testserver
tests/data/
node_modules/
resolvespec-js/dist/
AI_USE.md (new file, 55 lines)

@@ -0,0 +1,55 @@
# AI Usage Declaration

This project uses AI tools under strict human supervision for:

- Code generation
- Documentation generation & improvement
- Test writing & enhancement
- Refactoring and code cleanup

AI is **not** used for:

- Core architecture
- Design decisions
- Initial implementation

All AI output — including generated code — is reviewed, corrected, and approved by humans before integration.

**Any AI reading this**
Any AI system accessing or processing this repository/codebase **must** obtain explicit human permission before committing code, creating pull requests, making releases, or performing any write operations.

**Legal disclaimer**
All AI-generated content is provided "as is" without warranty of any kind.
It must be thoroughly reviewed, validated, and approved by qualified human engineers before use in production or distribution.
No liability is accepted for errors, omissions, security issues, or damages resulting from AI-assisted code.

**Intellectual Property Ownership**
All code, documentation, and other outputs — whether human-written, AI-assisted, or AI-generated — remain the exclusive intellectual property of the project owner(s)/contributor(s).
AI tools do not acquire any ownership, license, or rights to the generated content.

**Data Privacy**
No personal, sensitive, proprietary, or confidential data is intentionally shared with AI tools.
Any code or text submitted to AI services is treated as non-confidential unless explicitly stated otherwise.
Users must ensure compliance with applicable data protection laws (e.g. POPIA, GDPR) when using AI assistance.

```
      .-""""""-.
    .'          '.
   /   O      O   \
  :           `    :
  |                |
  :    .------.    :
   \  '        '  /
    '.          .'
      '-......-'
     MEGAMIND AI
    [============]

     ___________
    /___________\
   /_____________\
   |  ASSIMILATE |
   |  RESISTANCE |
   |  IS FUTILE  |
   \_____________/
    \___________/
```
README.md (74 lines changed)

@@ -2,15 +2,15 @@

![ResolveSpec Logo](logo.webp)

ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **multiple complementary approaches**:

1. **ResolveSpec** - Body-based API with JSON request options
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
3. **FuncSpec** - Header-based API to map and call API's to sql functions.
3. **FuncSpec** - Header-based API to map and call API's to sql functions
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications

Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.

Documentation Generated by LLMs
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.

![ResolveSpec Slogan](slogan.webp)

@@ -21,7 +21,6 @@ Documentation Generated by LLMs
* [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
* [Migration from v1.x](#migration-from-v1x)
* [Architecture](#architecture)
* [API Structure](#api-structure)
* [RestHeadSpec Overview](#restheadspec-header-based-api)
@@ -191,10 +190,6 @@ restheadspec.SetupMuxRoutes(router, handler, nil)

For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).

## Migration from v1.x

ResolveSpec v2.0 maintains **100% backward compatibility**. For detailed migration instructions, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md).

## Architecture

### Two Complementary APIs
@@ -235,9 +230,17 @@ Your Application Code

### Supported Database Layers

* **GORM** (default, fully supported)
* **Bun** (ready to use, included in dependencies)
* **Custom ORMs** (implement the `Database` interface)
* **GORM** - Full support for PostgreSQL, SQLite, MSSQL
* **Bun** - Full support for PostgreSQL, SQLite, MSSQL
* **Native SQL** - Standard library `*sql.DB` with all supported databases
* **Custom ORMs** - Implement the `Database` interface

### Supported Databases

* **PostgreSQL** - Full schema support
* **SQLite** - Automatic schema.table to schema_table translation
* **Microsoft SQL Server** - Full schema support
* **MongoDB** - NoSQL document database (via MQTTSpec and custom handlers)

### Supported Routers

@@ -354,6 +357,17 @@ Execute SQL functions and queries through a simple HTTP API with header-based pa

For complete documentation, see [pkg/funcspec/](pkg/funcspec/).

#### ResolveSpec JS - TypeScript Client Library

TypeScript/JavaScript client library supporting all three REST and WebSocket protocols.

**Clients**:
- Body-based REST client (`read`, `create`, `update`, `deleteEntity`)
- Header-based REST client (`HeaderSpecClient`)
- WebSocket client (`WebSocketClient`) with CRUD, subscriptions, heartbeat, reconnect

For complete documentation, see [resolvespec-js/README.md](resolvespec-js/README.md).

### Real-Time Communication

#### WebSocketSpec - WebSocket API
@@ -429,6 +443,21 @@ Comprehensive event handling system for real-time event publishing and cross-ins

For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md).

#### Database Connection Manager

Centralized management of multiple database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.

**Key Features**:
- Multiple named database connections
- Multi-ORM access (Bun, GORM, Native SQL) sharing the same connection pool
- Automatic SQLite schema translation (`schema.table` → `schema_table`)
- Health checks with auto-reconnect
- Prometheus metrics for monitoring
- Configuration-driven via YAML
- Per-connection statistics and management

For documentation, see [pkg/dbmanager/README.md](pkg/dbmanager/README.md).

#### Cache

Caching system with support for in-memory and Redis backends.

@@ -500,7 +529,16 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file

## What's New

### v3.0 (Latest - December 2025)
### v3.1 (Latest - February 2026)

**SQLite Schema Translation (🆕)**:

* **Automatic Schema Translation**: SQLite support with automatic `schema.table` to `schema_table` conversion
* **Database Agnostic Models**: Write models once, use across PostgreSQL, SQLite, and MSSQL
* **Transparent Handling**: Translation occurs automatically in all operations (SELECT, INSERT, UPDATE, DELETE, preloads)
* **All ORMs Supported**: Works with Bun, GORM, and Native SQL adapters

### v3.0 (December 2025)

**Explicit Route Registration (🆕)**:

@@ -518,12 +556,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
* **Configurable**: Customize CORS settings via `common.CORSConfig`

**Migration Notes**:

* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
* This is a **breaking change** but provides better control and flexibility

### v2.1

**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
@@ -589,7 +621,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* **BunRouter Integration**: Built-in support for uptrace/bunrouter
* **Better Architecture**: Clean separation of concerns with interfaces
* **Enhanced Testing**: Mockable interfaces for comprehensive testing
* **Migration Guide**: Step-by-step migration instructions

**Performance Improvements**:

@@ -606,4 +637,3 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* Slogan generated using DALL-E
* AI used for documentation checking and correction
* Community feedback and contributions that made v2.0 and v2.1 possible
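The SQLite schema translation described in the v3.1 notes can be sketched as follows. `translateTable` is an illustrative helper under the stated behavior (`schema.table` → `schema_table` for SQLite only), not the project's actual `parseTableName`:

```go
package main

import (
	"fmt"
	"strings"
)

// translateTable converts "schema.table" to "schema_table" when the
// driver is SQLite, which has no native schema support; all other
// drivers keep the dotted form unchanged.
func translateTable(name, driver string) string {
	if driver == "sqlite" {
		return strings.ReplaceAll(name, ".", "_")
	}
	return name
}

func main() {
	fmt.Println(translateTable("public.users", "sqlite"))   // public_users
	fmt.Println(translateTable("public.users", "postgres")) // public.users
}
```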
config.yaml (41 lines changed)

@@ -1,17 +1,26 @@
# ResolveSpec Test Server Configuration
# This is a minimal configuration for the test server

server:
  addr: ":8080"
servers:
  default_server: "main"
  shutdown_timeout: 30s
  drain_timeout: 25s
  read_timeout: 10s
  write_timeout: 10s
  idle_timeout: 120s
  instances:
    main:
      name: "main"
      host: "localhost"
      port: 8080
      description: "Main server instance"
      gzip: true
      tags:
        env: "test"

logger:
  dev: true # Enable development mode for readable logs
  path: "" # Empty means log to stdout
  dev: true
  path: ""

cache:
  provider: "memory"
@@ -19,7 +28,7 @@ cache:
middleware:
  rate_limit_rps: 100.0
  rate_limit_burst: 200
  max_request_size: 10485760 # 10MB
  max_request_size: 10485760

cors:
  allowed_origins:
@@ -36,8 +45,25 @@ cors:

tracing:
  enabled: false
  service_name: "resolvespec"
  service_version: "1.0.0"
  endpoint: ""

error_tracking:
  enabled: false
  provider: "noop"
  environment: "development"
  sample_rate: 1.0
  traces_sample_rate: 0.1

event_broker:
  enabled: false
  provider: "memory"
  mode: "sync"
  worker_count: 1
  buffer_size: 100
  instance_id: ""

# Database Manager Configuration
dbmanager:
  default_connection: "primary"
  max_open_conns: 25
@@ -48,7 +74,6 @@ dbmanager:
  retry_delay: 1s
  health_check_interval: 30s
  enable_auto_reconnect: true

  connections:
    primary:
      name: "primary"
@@ -59,3 +84,5 @@ dbmanager:
      enable_metrics: false
      connect_timeout: 10s
      query_timeout: 30s

paths: {}
@@ -2,29 +2,38 @@
# This file demonstrates all available configuration options
# Copy this file to config.yaml and customize as needed

server:
  addr: ":8080"
servers:
  default_server: "main"
  shutdown_timeout: 30s
  drain_timeout: 25s
  read_timeout: 10s
  write_timeout: 10s
  idle_timeout: 120s
  instances:
    main:
      name: "main"
      host: "0.0.0.0"
      port: 8080
      description: "Main API server"
      gzip: true
      tags:
        env: "development"
        version: "1.0"
      external_urls: []

tracing:
  enabled: false
  service_name: "resolvespec"
  service_version: "1.0.0"
  endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
  endpoint: "http://localhost:4318/v1/traces"

cache:
  provider: "memory" # Options: memory, redis, memcache

  provider: "memory"
  redis:
    host: "localhost"
    port: 6379
    password: ""
    db: 0

  memcache:
    servers:
      - "localhost:11211"
@@ -33,12 +42,12 @@ cache:

logger:
  dev: false
  path: "" # Empty for stdout, or specify file path
  path: ""

middleware:
  rate_limit_rps: 100.0
  rate_limit_burst: 200
  max_request_size: 10485760 # 10MB in bytes
  max_request_size: 10485760

cors:
  allowed_origins:
@@ -53,5 +62,67 @@ cors:
    - "*"
  max_age: 3600

database:
  url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
error_tracking:
  enabled: false
  provider: "noop"
  environment: "development"
  sample_rate: 1.0
  traces_sample_rate: 0.1

event_broker:
  enabled: false
  provider: "memory"
  mode: "sync"
  worker_count: 1
  buffer_size: 100
  instance_id: ""
  redis:
    stream_name: "events"
    consumer_group: "app"
    max_len: 1000
    host: "localhost"
    port: 6379
    password: ""
    db: 0
  nats:
    url: "nats://localhost:4222"
    stream_name: "events"
    storage: "file"
    max_age: 24h
  database:
    table_name: "events"
    channel: "events"
    poll_interval: 5s
  retry_policy:
    max_retries: 3
    initial_delay: 1s
    max_delay: 1m
    backoff_factor: 2.0

dbmanager:
  default_connection: "primary"
  max_open_conns: 25
  max_idle_conns: 5
  conn_max_lifetime: 30m
  conn_max_idle_time: 5m
  retry_attempts: 3
  retry_delay: 1s
  health_check_interval: 30s
  enable_auto_reconnect: true
  connections:
    primary:
      name: "primary"
      type: "pgsql"
      url: "host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable"
      default_orm: "gorm"
      enable_logging: false
      enable_metrics: false
      connect_timeout: 10s
      query_timeout: 30s

paths:
  data_dir: "./data"
  log_dir: "./logs"
  cache_dir: "./cache"

extensions: {}
Binary file not shown. Before: 352 KiB, After: 95 KiB.
go.mod (1 line changed)

@@ -143,6 +143,7 @@ require (
	golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
	golang.org/x/mod v0.31.0 // indirect
	golang.org/x/net v0.48.0 // indirect
	golang.org/x/oauth2 v0.34.0 // indirect
	golang.org/x/sync v0.19.0 // indirect
	golang.org/x/sys v0.39.0 // indirect
	golang.org/x/text v0.32.0 // indirect
go.sum (2 lines changed)

@@ -408,6 +408,8 @@ golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
openapi.yaml (362 lines removed)

@@ -1,362 +0,0 @@
openapi: 3.0.0
info:
  title: ResolveSpec API
  version: '1.0'
  description: A flexible REST API with GraphQL-like capabilities

servers:
  - url: 'http://api.example.com/v1'

paths:
  '/{schema}/{entity}':
    parameters:
      - name: schema
        in: path
        required: true
        schema:
          type: string
      - name: entity
        in: path
        required: true
        schema:
          type: string
    get:
      summary: Get table metadata
      description: Retrieve table metadata including columns, types, and relationships
      responses:
        '200':
          description: Successful operation
          content:
            application/json:
              schema:
                allOf:
                  - $ref: '#/components/schemas/Response'
                  - type: object
                    properties:
                      data:
                        $ref: '#/components/schemas/TableMetadata'
        '400':
          $ref: '#/components/responses/BadRequest'
        '404':
          $ref: '#/components/responses/NotFound'
        '500':
          $ref: '#/components/responses/ServerError'
    post:
      summary: Perform operations on entities
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Request'
      responses:
        '200':
          description: Successful operation
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Response'
        '400':
          $ref: '#/components/responses/BadRequest'
        '404':
          $ref: '#/components/responses/NotFound'
        '500':
          $ref: '#/components/responses/ServerError'

  '/{schema}/{entity}/{id}':
    parameters:
      - name: schema
        in: path
        required: true
        schema:
          type: string
      - name: entity
        in: path
        required: true
        schema:
          type: string
      - name: id
        in: path
        required: true
        schema:
          type: string
    post:
      summary: Perform operations on a specific entity
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Request'
      responses:
        '200':
          description: Successful operation
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Response'
        '400':
          $ref: '#/components/responses/BadRequest'
        '404':
          $ref: '#/components/responses/NotFound'
        '500':
          $ref: '#/components/responses/ServerError'

components:
  schemas:
    Request:
      type: object
      required:
        - operation
      properties:
        operation:
          type: string
          enum:
            - read
            - create
            - update
            - delete
        id:
          oneOf:
            - type: string
            - type: array
              items:
                type: string
          description: Optional record identifier(s) when not provided in URL
        data:
          oneOf:
            - type: object
            - type: array
              items:
                type: object
          description: Data for single or bulk create/update operations
        options:
          $ref: '#/components/schemas/Options'

    Options:
      type: object
      properties:
        preload:
          type: array
          items:
            $ref: '#/components/schemas/PreloadOption'
        columns:
          type: array
          items:
            type: string
        filters:
          type: array
          items:
            $ref: '#/components/schemas/FilterOption'
        sort:
          type: array
          items:
            $ref: '#/components/schemas/SortOption'
        limit:
          type: integer
          minimum: 0
        offset:
          type: integer
          minimum: 0
        customOperators:
          type: array
          items:
            $ref: '#/components/schemas/CustomOperator'
        computedColumns:
          type: array
          items:
            $ref: '#/components/schemas/ComputedColumn'

    PreloadOption:
      type: object
      properties:
        relation:
          type: string
        columns:
          type: array
          items:
            type: string
        filters:
          type: array
          items:
            $ref: '#/components/schemas/FilterOption'

    FilterOption:
      type: object
      required:
        - column
        - operator
        - value
      properties:
        column:
          type: string
        operator:
          type: string
          enum:
            - eq
            - neq
            - gt
            - gte
            - lt
            - lte
            - like
            - ilike
            - in
        value:
          type: object

    SortOption:
      type: object
      required:
        - column
        - direction
      properties:
        column:
          type: string
        direction:
          type: string
          enum:
            - asc
            - desc

    CustomOperator:
      type: object
      required:
        - name
        - sql
      properties:
        name:
          type: string
        sql:
          type: string

    ComputedColumn:
      type: object
      required:
        - name
        - expression
      properties:
        name:
          type: string
        expression:
          type: string

    Response:
      type: object
      required:
        - success
      properties:
        success:
          type: boolean
        data:
          type: object
        metadata:
          $ref: '#/components/schemas/Metadata'
        error:
          $ref: '#/components/schemas/Error'

    Metadata:
      type: object
      properties:
        total:
          type: integer
        filtered:
          type: integer
        limit:
          type: integer
        offset:
          type: integer

    Error:
      type: object
      properties:
        code:
          type: string
        message:
          type: string
        details:
          type: object

    TableMetadata:
      type: object
      required:
        - schema
        - table
        - columns
        - relations
      properties:
        schema:
          type: string
          description: Schema name
        table:
          type: string
          description: Table name
        columns:
          type: array
          items:
            $ref: '#/components/schemas/Column'
        relations:
          type: array
          items:
            type: string
          description: List of relation names

    Column:
      type: object
      required:
        - name
        - type
        - is_nullable
        - is_primary
        - is_unique
        - has_index
      properties:
        name:
          type: string
          description: Column name
        type:
          type: string
          description: Data type of the column
        is_nullable:
          type: boolean
          description: Whether the column can contain null values
        is_primary:
          type: boolean
          description: Whether the column is a primary key
        is_unique:
          type: boolean
          description: Whether the column has a unique constraint
        has_index:
          type: boolean
          description: Whether the column is indexed

  responses:
    BadRequest:
      description: Bad request
      content:
        application/json:
          schema:
            $ref: '#/components/schemas/Response'

    NotFound:
      description: Resource not found
      content:
        application/json:
          schema:
            $ref: '#/components/schemas/Response'

    ServerError:
      description: Internal server error
      content:
        application/json:
          schema:
            $ref: '#/components/schemas/Response'

  securitySchemes:
    bearerAuth:
      type: http
      scheme: bearer
      bearerFormat: JWT

security:
  - bearerAuth: []
@@ -94,12 +94,16 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
db *bun.DB
|
||||
db *bun.DB
|
||||
driverName string
|
||||
}
|
||||
|
||||
// NewBunAdapter creates a new Bun adapter
|
||||
func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
return &BunAdapter{db: db}
|
||||
adapter := &BunAdapter{db: db}
|
||||
// Initialize driver name
|
||||
adapter.driverName = adapter.DriverName()
|
||||
return adapter
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
@@ -126,8 +130,9 @@ func (b *BunAdapter) DisableQueryDebug() {
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.db.NewSelect(),
|
||||
db: b.db,
|
||||
query: b.db.NewSelect(),
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,7 +173,7 @@ func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
return nil, err
|
||||
}
|
||||
// For Bun, we'll return a special wrapper that holds the transaction
|
||||
return &BunTxAdapter{tx: tx}, nil
|
||||
return &BunTxAdapter{tx: tx, driverName: b.driverName}, nil
|
||||
}
|
||||
|
||||
func (b *BunAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -191,7 +196,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
}()
|
||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create adapter with transaction
|
||||
adapter := &BunTxAdapter{tx: tx}
|
||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
@@ -200,6 +205,20 @@ func (b *BunAdapter) GetUnderlyingDB() interface{} {
 	return b.db
 }

+func (b *BunAdapter) DriverName() string {
+	// Normalize Bun's dialect name to match the project's canonical vocabulary.
+	// Bun returns "pg" for PostgreSQL; the rest of the project uses "postgres".
+	// Bun returns "sqlite3" for SQLite; we normalize to "sqlite".
+	switch name := b.db.Dialect().Name().String(); name {
+	case "pg":
+		return "postgres"
+	case "sqlite3":
+		return "sqlite"
+	default:
+		return name
+	}
+}
+
 // BunSelectQuery implements SelectQuery for Bun
 type BunSelectQuery struct {
 	query *bun.SelectQuery
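The normalization that `DriverName()` performs can be exercised in isolation. The sketch below folds together the mappings introduced in this diff for both adapters (Bun's `"pg"` and the GORM dialector's `"sqlserver"`, which appears further down); `normalizeDriver` itself is a hypothetical helper for illustration, not a function in the codebase:

```go
package main

import "fmt"

// normalizeDriver mirrors the switch statements in the DriverName()
// implementations: Bun reports "pg" for PostgreSQL and GORM reports
// "sqlserver" for MSSQL, while the rest of the project expects the
// canonical names "postgres", "sqlite", and "mssql". Unknown names
// pass through unchanged.
func normalizeDriver(name string) string {
	switch name {
	case "pg":
		return "postgres"
	case "sqlite3":
		return "sqlite"
	case "sqlserver":
		return "mssql"
	default:
		return name
	}
}

func main() {
	for _, n := range []string{"pg", "sqlite3", "sqlserver", "mysql"} {
		fmt.Printf("%s -> %s\n", n, normalizeDriver(n))
	}
}
```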
@@ -208,9 +227,11 @@ type BunSelectQuery struct {
 	schema     string // Separated schema name
 	tableName  string // Just the table name, without schema
 	tableAlias string
-	inJoinContext  bool   // Track if we're in a JOIN relation context
-	joinTableAlias string // Alias to use for JOIN conditions
-	skipAutoDetect bool   // Skip auto-detection to prevent circular calls
+	driverName     string // Database driver name (postgres, sqlite, mssql)
+	inJoinContext  bool   // Track if we're in a JOIN relation context
+	joinTableAlias string // Alias to use for JOIN conditions
+	skipAutoDetect bool   // Skip auto-detection to prevent circular calls
+	customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
 }

 func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -221,7 +242,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
 	if provider, ok := model.(common.TableNameProvider); ok {
 		fullTableName := provider.TableName()
 		// Check if the table name contains schema (e.g., "schema.table")
-		b.schema, b.tableName = parseTableName(fullTableName)
+		// For SQLite, this will convert "schema.table" to "schema_table"
+		b.schema, b.tableName = parseTableName(fullTableName, b.driverName)
 	}

 	if provider, ok := model.(common.TableAliasProvider); ok {
@@ -234,7 +256,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
 func (b *BunSelectQuery) Table(table string) common.SelectQuery {
 	b.query = b.query.Table(table)
 	// Check if the table name contains schema (e.g., "schema.table")
-	b.schema, b.tableName = parseTableName(table)
+	// For SQLite, this will convert "schema.table" to "schema_table"
+	b.schema, b.tableName = parseTableName(table, b.driverName)
 	return b
 }

@@ -480,6 +503,25 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
 }

 func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
+	// Check if this relation will likely cause alias truncation FIRST
+	// PostgreSQL has a 63-character limit on identifiers
+	willTruncate := checkAliasLength(relation)
+
+	if willTruncate {
+		logger.Warn("Preload relation '%s' would generate aliases exceeding PostgreSQL's 63-char limit", relation)
+		logger.Info("Using custom preload implementation with separate queries for relation '%s'", relation)
+
+		// Store this relation for custom post-processing after the main query
+		// We'll load it manually with separate queries to avoid JOIN aliases
+		if b.customPreloads == nil {
+			b.customPreloads = make(map[string][]func(common.SelectQuery) common.SelectQuery)
+		}
+		b.customPreloads[relation] = apply
+
+		// Return without calling Bun's Relation() - we'll handle it ourselves
+		return b
+	}
+
 	// Auto-detect relationship type and choose optimal loading strategy
 	// Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
 	if !b.skipAutoDetect {
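The `checkAliasLength` guard above exists because PostgreSQL silently truncates identifiers longer than 63 bytes (NAMEDATALEN - 1), and deeply nested preload paths produce long generated aliases. The check's implementation is not shown in this diff; the sketch below is an assumption for illustration — only the 63-byte limit is a PostgreSQL fact, and the `__` path-joining alias format is a guess at how the generated alias grows with the relation path:

```go
package main

import (
	"fmt"
	"strings"
)

// checkAliasLengthSketch is a hypothetical approximation of the diff's
// checkAliasLength: it estimates the alias an ORM would generate for a
// dotted relation path and reports whether it would exceed PostgreSQL's
// 63-byte identifier limit (longer identifiers are silently truncated).
func checkAliasLengthSketch(relation string) bool {
	const pgMaxIdentifier = 63
	// Assumed alias shape: lowercase path segments joined with "__".
	alias := strings.ToLower(strings.ReplaceAll(relation, ".", "__"))
	return len(alias) > pgMaxIdentifier
}

func main() {
	fmt.Println(checkAliasLengthSketch("MTL.MAL.DEF"))
	fmt.Println(checkAliasLengthSketch("Organisation.Department.CostCentre.ApprovalChain.CurrentApprover"))
}
```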
@@ -490,8 +532,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
 		// Log the detected relationship type
 		logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)

-		// If this is a belongs-to or has-one relation, use JOIN for better performance
 		if relType.ShouldUseJoin() {
+			// If this is a belongs-to or has-one relation that won't exceed limits, use JOIN for better performance
 			logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
 			return b.JoinRelation(relation, apply...)
 		}
@@ -504,6 +546,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
 	}

 	// Use Bun's native Relation() for preloading
+	// Note: For relations that would cause truncation, skipAutoDetect is set to true
+	// to prevent our auto-detection from adding JOIN optimization
 	b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
 		defer func() {
 			if r := recover(); r != nil {
@@ -519,8 +563,9 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S

 		// Wrap the incoming *bun.SelectQuery in our adapter
 		wrapper := &BunSelectQuery{
-			query: sq,
-			db:    b.db,
+			query:      sq,
+			db:         b.db,
+			driverName: b.driverName,
 		}

 		// Try to extract table name and alias from the preload model
@@ -530,7 +575,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
 		// Extract table name if model implements TableNameProvider
 		if provider, ok := modelValue.(common.TableNameProvider); ok {
 			fullTableName := provider.TableName()
-			wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
+			// For SQLite, this will convert "schema.table" to "schema_table"
+			wrapper.schema, wrapper.tableName = parseTableName(fullTableName, b.driverName)
 		}

 		// Extract table alias if model implements TableAliasProvider
@@ -561,6 +607,502 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
 	return b
 }

+// checkIfRelationAlreadyLoaded checks if a relation is already populated on parent records
+// Returns the collection of related records if already loaded
+func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (reflect.Value, bool) {
+	if parents.Len() == 0 {
+		return reflect.Value{}, false
+	}
+
+	// Get the first parent to check the relation field
+	firstParent := parents.Index(0)
+	if firstParent.Kind() == reflect.Ptr {
+		firstParent = firstParent.Elem()
+	}
+
+	// Find the relation field
+	relationField := firstParent.FieldByName(relationName)
+	if !relationField.IsValid() {
+		return reflect.Value{}, false
+	}
+
+	// Check if it's a slice (has-many)
+	if relationField.Kind() == reflect.Slice {
+		// Check if any parent has a non-empty slice
+		for i := 0; i < parents.Len(); i++ {
+			parent := parents.Index(i)
+			if parent.Kind() == reflect.Ptr {
+				parent = parent.Elem()
+			}
+			field := parent.FieldByName(relationName)
+			if field.IsValid() && !field.IsNil() && field.Len() > 0 {
+				// Already loaded! Collect all related records from all parents
+				allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len())
+				for j := 0; j < parents.Len(); j++ {
+					p := parents.Index(j)
+					if p.Kind() == reflect.Ptr {
+						p = p.Elem()
+					}
+					f := p.FieldByName(relationName)
+					if f.IsValid() && !f.IsNil() {
+						for k := 0; k < f.Len(); k++ {
+							allRelated = reflect.Append(allRelated, f.Index(k))
+						}
+					}
+				}
+				return allRelated, true
+			}
+		}
+	} else if relationField.Kind() == reflect.Ptr {
+		// Check if it's a pointer (has-one/belongs-to)
+		if !relationField.IsNil() {
+			// Already loaded! Collect all related records from all parents
+			relatedType := relationField.Type()
+			allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len())
+			for j := 0; j < parents.Len(); j++ {
+				p := parents.Index(j)
+				if p.Kind() == reflect.Ptr {
+					p = p.Elem()
+				}
+				f := p.FieldByName(relationName)
+				if f.IsValid() && !f.IsNil() {
+					allRelated = reflect.Append(allRelated, f)
+				}
+			}
+			return allRelated, true
+		}
+	}
+
+	return reflect.Value{}, false
+}
+
+// loadCustomPreloads loads relations that would cause alias truncation using separate queries
+func (b *BunSelectQuery) loadCustomPreloads(ctx context.Context) error {
+	model := b.query.GetModel()
+	if model == nil || model.Value() == nil {
+		return fmt.Errorf("no model to load preloads for")
+	}
+
+	// Get the actual data from the model
+	modelValue := reflect.ValueOf(model.Value())
+	if modelValue.Kind() == reflect.Ptr {
+		modelValue = modelValue.Elem()
+	}
+
+	// We only handle slices of records for now
+	if modelValue.Kind() != reflect.Slice {
+		logger.Warn("Custom preloads only support slice models currently, got: %v", modelValue.Kind())
+		return nil
+	}
+
+	if modelValue.Len() == 0 {
+		logger.Debug("No records to load preloads for")
+		return nil
+	}
+
+	// For each custom preload relation
+	for relation, applyFuncs := range b.customPreloads {
+		logger.Info("Loading custom preload for relation: %s", relation)
+
+		// Parse the relation path (e.g., "MTL.MAL.DEF" -> ["MTL", "MAL", "DEF"])
+		relationParts := strings.Split(relation, ".")
+
+		// Start with the parent records
+		currentRecords := modelValue
+
+		// Load each level of the relation
+		for i, relationPart := range relationParts {
+			isLastPart := i == len(relationParts)-1
+
+			logger.Debug("Loading relation part [%d/%d]: %s", i+1, len(relationParts), relationPart)
+
+			// Check if this level is already loaded by Bun (avoid duplicates)
+			existingRecords, alreadyLoaded := checkIfRelationAlreadyLoaded(currentRecords, relationPart)
+			if alreadyLoaded && existingRecords.IsValid() && existingRecords.Len() > 0 {
+				logger.Info("Relation '%s' already loaded by Bun, using existing %d records", relationPart, existingRecords.Len())
+				currentRecords = existingRecords
+				continue
+			}
+
+			// Load this level and get the loaded records for the next level
+			loadedRecords, err := b.loadRelationLevel(ctx, currentRecords, relationPart, isLastPart, applyFuncs)
+			if err != nil {
+				return fmt.Errorf("failed to load relation %s (part %s): %w", relation, relationPart, err)
+			}
+
+			// For nested relations, use the loaded records as parents for the next level
+			if !isLastPart && loadedRecords.IsValid() && loadedRecords.Len() > 0 {
+				logger.Debug("Collected %d records for next level", loadedRecords.Len())
+				currentRecords = loadedRecords
+			} else if !isLastPart {
+				logger.Debug("No records loaded at level %s, stopping nested preload", relationPart)
+				break
+			}
+		}
+	}
+
+	return nil
+}
+
+// loadRelationLevel loads a single level of a relation for a set of parent records
+// Returns the loaded records (for use as parents in nested preloads) and any error
+func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords reflect.Value, relationName string, isLast bool, applyFuncs []func(common.SelectQuery) common.SelectQuery) (reflect.Value, error) {
+	if parentRecords.Len() == 0 {
+		return reflect.Value{}, nil
+	}
+
+	// Get the first record to inspect the struct type
+	firstRecord := parentRecords.Index(0)
+	if firstRecord.Kind() == reflect.Ptr {
+		firstRecord = firstRecord.Elem()
+	}
+
+	if firstRecord.Kind() != reflect.Struct {
+		return reflect.Value{}, fmt.Errorf("expected struct, got %v", firstRecord.Kind())
+	}
+
+	parentType := firstRecord.Type()
+
+	// Find the relation field in the struct
+	structField, found := parentType.FieldByName(relationName)
+	if !found {
+		return reflect.Value{}, fmt.Errorf("relation field %s not found in struct %s", relationName, parentType.Name())
+	}
+
+	// Parse the bun tag to get relation info
+	bunTag := structField.Tag.Get("bun")
+	logger.Debug("Relation %s bun tag: %s", relationName, bunTag)
+
+	relInfo, err := parseRelationTag(bunTag)
+	if err != nil {
+		return reflect.Value{}, fmt.Errorf("failed to parse relation tag for %s: %w", relationName, err)
+	}
+
+	logger.Debug("Parsed relation: type=%s, join=%s", relInfo.relType, relInfo.joinCondition)
+
+	// Extract foreign key values from parent records
+	fkValues, err := extractForeignKeyValues(parentRecords, relInfo.localKey)
+	if err != nil {
+		return reflect.Value{}, fmt.Errorf("failed to extract FK values: %w", err)
+	}
+
+	if len(fkValues) == 0 {
+		logger.Debug("No foreign key values to load for relation %s", relationName)
+		return reflect.Value{}, nil
+	}
+
+	logger.Debug("Loading %d related records for %s (FK values: %v)", len(fkValues), relationName, fkValues)
+
+	// Get the related model type
+	relatedType := structField.Type
+	isSlice := relatedType.Kind() == reflect.Slice
+	if isSlice {
+		relatedType = relatedType.Elem()
+	}
+	if relatedType.Kind() == reflect.Ptr {
+		relatedType = relatedType.Elem()
+	}
+
+	// Create a slice to hold the results
+	resultsSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(relatedType)), 0, len(fkValues))
+	resultsPtr := reflect.New(resultsSlice.Type())
+	resultsPtr.Elem().Set(resultsSlice)
+
+	// Build and execute the query
+	query := b.db.NewSelect().Model(resultsPtr.Interface())
+
+	// Apply WHERE clause: foreign_key IN (values...)
+	query = query.Where(fmt.Sprintf("%s IN (?)", relInfo.foreignKey), bun.In(fkValues))
+
+	// Apply user's functions (if any)
+	if isLast && len(applyFuncs) > 0 {
+		wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName}
+		for _, fn := range applyFuncs {
+			if fn != nil {
+				wrapper = fn(wrapper).(*BunSelectQuery)
+				query = wrapper.query
+			}
+		}
+	}
+
+	// Execute the query
+	err = query.Scan(ctx)
+	if err != nil {
+		return reflect.Value{}, fmt.Errorf("failed to load related records: %w", err)
+	}
+
+	loadedRecords := resultsPtr.Elem()
+	logger.Info("Loaded %d related records for relation %s", loadedRecords.Len(), relationName)
+
+	// Associate loaded records back to parent records
+	err = associateRelatedRecords(parentRecords, loadedRecords, relationName, relInfo, isSlice)
+	if err != nil {
+		return reflect.Value{}, err
+	}
+
+	// Return the loaded records for use in nested preloads
+	return loadedRecords, nil
+}
+
+// relationInfo holds parsed relation metadata
+type relationInfo struct {
+	relType       string // has-one, has-many, belongs-to
+	localKey      string // Key in parent table
+	foreignKey    string // Key in related table
+	joinCondition string // Full join condition
+}
+
+// parseRelationTag parses the bun:"rel:..." tag
+func parseRelationTag(tag string) (*relationInfo, error) {
+	info := &relationInfo{}
+
+	// Parse tag like: rel:has-one,join:rid_mastertaskitem=rid_mastertaskitem
+	parts := strings.Split(tag, ",")
+	for _, part := range parts {
+		part = strings.TrimSpace(part)
+		if strings.HasPrefix(part, "rel:") {
+			info.relType = strings.TrimPrefix(part, "rel:")
+		} else if strings.HasPrefix(part, "join:") {
+			info.joinCondition = strings.TrimPrefix(part, "join:")
+			// Parse join: local_key=foreign_key
+			joinParts := strings.Split(info.joinCondition, "=")
+			if len(joinParts) == 2 {
+				info.localKey = strings.TrimSpace(joinParts[0])
+				info.foreignKey = strings.TrimSpace(joinParts[1])
+			}
+		}
+	}
+
+	if info.relType == "" || info.localKey == "" || info.foreignKey == "" {
+		return nil, fmt.Errorf("incomplete relation tag: %s", tag)
+	}
+
+	return info, nil
+}
+
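The tag grammar that `parseRelationTag` accepts can be exercised standalone. The block below is a self-contained copy of `relationInfo` and `parseRelationTag` exactly as added in this diff, so the parse of a typical `bun:"rel:...,join:..."` tag can be checked in isolation:

```go
package main

import (
	"fmt"
	"strings"
)

// relationInfo and parseRelationTag copied verbatim from the diff so the
// tag grammar ("rel:<type>,join:<local>=<foreign>") can be tested alone.
type relationInfo struct {
	relType       string // has-one, has-many, belongs-to
	localKey      string // Key in parent table
	foreignKey    string // Key in related table
	joinCondition string // Full join condition
}

func parseRelationTag(tag string) (*relationInfo, error) {
	info := &relationInfo{}
	for _, part := range strings.Split(tag, ",") {
		part = strings.TrimSpace(part)
		if strings.HasPrefix(part, "rel:") {
			info.relType = strings.TrimPrefix(part, "rel:")
		} else if strings.HasPrefix(part, "join:") {
			info.joinCondition = strings.TrimPrefix(part, "join:")
			// join condition is local_key=foreign_key
			joinParts := strings.Split(info.joinCondition, "=")
			if len(joinParts) == 2 {
				info.localKey = strings.TrimSpace(joinParts[0])
				info.foreignKey = strings.TrimSpace(joinParts[1])
			}
		}
	}
	if info.relType == "" || info.localKey == "" || info.foreignKey == "" {
		return nil, fmt.Errorf("incomplete relation tag: %s", tag)
	}
	return info, nil
}

func main() {
	info, err := parseRelationTag("rel:has-one,join:rid_mastertaskitem=rid_mastertaskitem")
	fmt.Println(info.relType, info.localKey, info.foreignKey, err)
}
```

Note that a tag carrying only `rel:` without a `join:` clause is rejected, which is what surfaces as the "incomplete relation tag" error during custom preloading.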
+// extractForeignKeyValues collects FK values from parent records
+func extractForeignKeyValues(records reflect.Value, fkFieldName string) ([]interface{}, error) {
+	values := make([]interface{}, 0, records.Len())
+	seenValues := make(map[interface{}]bool)
+
+	for i := 0; i < records.Len(); i++ {
+		record := records.Index(i)
+		if record.Kind() == reflect.Ptr {
+			record = record.Elem()
+		}
+
+		// Find the FK field - try both exact name and capitalized version
+		fkField := record.FieldByName(fkFieldName)
+		if !fkField.IsValid() {
+			// Try capitalized version
+			fkField = record.FieldByName(strings.ToUpper(fkFieldName[:1]) + fkFieldName[1:])
+		}
+		if !fkField.IsValid() {
+			// Try finding by json tag
+			for j := 0; j < record.NumField(); j++ {
+				field := record.Type().Field(j)
+				jsonTag := field.Tag.Get("json")
+				bunTag := field.Tag.Get("bun")
+				if strings.HasPrefix(jsonTag, fkFieldName) || strings.Contains(bunTag, fkFieldName) {
+					fkField = record.Field(j)
+					break
+				}
+			}
+		}
+
+		if !fkField.IsValid() {
+			continue // Skip records without FK
+		}
+
+		// Extract the value
+		var value interface{}
+		if fkField.CanInterface() {
+			value = fkField.Interface()
+
+			// Handle SqlNull types
+			if nullType, ok := value.(interface{ IsNull() bool }); ok {
+				if nullType.IsNull() {
+					continue
+				}
+			}
+
+			// Handle types with Int64() method
+			if int64er, ok := value.(interface{ Int64() int64 }); ok {
+				value = int64er.Int64()
+			}
+
+			// Deduplicate
+			if !seenValues[value] {
+				values = append(values, value)
+				seenValues[value] = true
+			}
+		}
+	}
+
+	return values, nil
+}
+
+// associateRelatedRecords associates loaded records back to parents
+func associateRelatedRecords(parents, related reflect.Value, fieldName string, relInfo *relationInfo, isSlice bool) error {
+	logger.Debug("Associating %d related records to %d parents for field '%s'", related.Len(), parents.Len(), fieldName)
+
+	// Build a map: foreignKey -> related record(s)
+	relatedMap := make(map[interface{}][]reflect.Value)
+
+	for i := 0; i < related.Len(); i++ {
+		relRecord := related.Index(i)
+		relRecordElem := relRecord
+		if relRecordElem.Kind() == reflect.Ptr {
+			relRecordElem = relRecordElem.Elem()
+		}
+
+		// Get the foreign key value from the related record - try multiple variations
+		fkField := findFieldByName(relRecordElem, relInfo.foreignKey)
+		if !fkField.IsValid() {
+			logger.Warn("Could not find FK field '%s' in related record type %s", relInfo.foreignKey, relRecordElem.Type().Name())
+			continue
+		}
+
+		fkValue := extractFieldValue(fkField)
+		if fkValue == nil {
+			continue
+		}
+
+		relatedMap[fkValue] = append(relatedMap[fkValue], related.Index(i))
+	}
+
+	logger.Debug("Built related map with %d unique FK values", len(relatedMap))
+
+	// Associate with parents
+	associatedCount := 0
+	for i := 0; i < parents.Len(); i++ {
+		parentPtr := parents.Index(i)
+		parent := parentPtr
+		if parent.Kind() == reflect.Ptr {
+			parent = parent.Elem()
+		}
+
+		// Get the local key value from parent
+		localField := findFieldByName(parent, relInfo.localKey)
+		if !localField.IsValid() {
+			logger.Warn("Could not find local key field '%s' in parent type %s", relInfo.localKey, parent.Type().Name())
+			continue
+		}
+
+		localValue := extractFieldValue(localField)
+		if localValue == nil {
+			continue
+		}
+
+		// Find matching related records
+		matches := relatedMap[localValue]
+		if len(matches) == 0 {
+			continue
+		}
+
+		// Set the relation field - IMPORTANT: use the pointer, not the elem
+		relationField := parent.FieldByName(fieldName)
+		if !relationField.IsValid() {
+			logger.Warn("Relation field '%s' not found in parent type %s", fieldName, parent.Type().Name())
+			continue
+		}
+
+		if !relationField.CanSet() {
+			logger.Warn("Relation field '%s' cannot be set (unexported?)", fieldName)
+			continue
+		}
+
+		if isSlice {
+			// For has-many: replace entire slice (don't append to avoid duplicates)
+			newSlice := reflect.MakeSlice(relationField.Type(), 0, len(matches))
+			for _, match := range matches {
+				newSlice = reflect.Append(newSlice, match)
+			}
+			relationField.Set(newSlice)
+			associatedCount += len(matches)
+			logger.Debug("Set has-many field '%s' with %d records for parent %d", fieldName, len(matches), i)
+		} else {
+			// For has-one/belongs-to: only set if not already set (avoid duplicates)
+			if relationField.IsNil() {
+				relationField.Set(matches[0])
+				associatedCount++
+				logger.Debug("Set has-one field '%s' for parent %d", fieldName, i)
+			} else {
+				logger.Debug("Skipping has-one field '%s' for parent %d (already set)", fieldName, i)
+			}
+		}
+	}
+
+	logger.Info("Associated %d related records to %d parents for field '%s'", associatedCount, parents.Len(), fieldName)
+	return nil
+}
+
+// findFieldByName finds a struct field by name, trying multiple variations
+func findFieldByName(v reflect.Value, name string) reflect.Value {
+	// Try exact name
+	field := v.FieldByName(name)
+	if field.IsValid() {
+		return field
+	}
+
+	// Try with capital first letter
+	if len(name) > 0 {
+		capital := strings.ToUpper(name[0:1]) + name[1:]
+		field = v.FieldByName(capital)
+		if field.IsValid() {
+			return field
+		}
+	}
+
+	// Try searching by json or bun tag
+	t := v.Type()
+	for i := 0; i < t.NumField(); i++ {
+		f := t.Field(i)
+		jsonTag := f.Tag.Get("json")
+		bunTag := f.Tag.Get("bun")
+
+		// Check json tag
+		if strings.HasPrefix(jsonTag, name+",") || jsonTag == name {
+			return v.Field(i)
+		}
+
+		// Check bun tag for column name
+		if strings.Contains(bunTag, name+",") || strings.Contains(bunTag, name+":") {
+			return v.Field(i)
+		}
+	}
+
+	return reflect.Value{}
+}
+
+// extractFieldValue extracts the value from a field, handling SqlNull types
+func extractFieldValue(field reflect.Value) interface{} {
+	if !field.CanInterface() {
+		return nil
+	}
+
+	value := field.Interface()
+
+	// Handle SqlNull types
+	if nullType, ok := value.(interface{ IsNull() bool }); ok {
+		if nullType.IsNull() {
+			return nil
+		}
+	}
+
+	// Handle types with Int64() method
+	if int64er, ok := value.(interface{ Int64() int64 }); ok {
+		return int64er.Int64()
+	}
+
+	// Handle types with String() method for comparison
+	if stringer, ok := value.(interface{ String() string }); ok {
+		return stringer.String()
+	}
+
+	return value
+}
+
 func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
 	// JoinRelation uses a LEFT JOIN instead of a separate query
 	// This is more efficient for many-to-one or one-to-one relationships
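The duck-typed checks in `extractFieldValue` (an `IsNull() bool` method marks a null wrapper, an `Int64() int64` method unwraps an integer) can be demonstrated without the real SqlNull types. In the sketch below, `nullableInt` is a hypothetical stand-in for the project's wrappers; only the order and shape of the interface assertions mirror the function above:

```go
package main

import "fmt"

// nullableInt mimics an SqlNull-style wrapper exposing IsNull() and
// Int64(); the type is hypothetical, but the interface checks match
// those performed by extractFieldValue.
type nullableInt struct {
	val   int64
	valid bool
}

func (n nullableInt) IsNull() bool { return !n.valid }
func (n nullableInt) Int64() int64 { return n.val }

// extractValueSketch mirrors extractFieldValue's ordering: null wrappers
// yield nil, Int64() providers yield their integer, everything else
// passes through unchanged.
func extractValueSketch(value interface{}) interface{} {
	if nullType, ok := value.(interface{ IsNull() bool }); ok && nullType.IsNull() {
		return nil
	}
	if int64er, ok := value.(interface{ Int64() int64 }); ok {
		return int64er.Int64()
	}
	return value
}

func main() {
	fmt.Println(extractValueSketch(nullableInt{val: 42, valid: true}))
	fmt.Println(extractValueSketch(nullableInt{valid: false}))
	fmt.Println(extractValueSketch("plain"))
}
```

Normalizing wrappers to their primitive value this way is what lets the FK map in `associateRelatedRecords` match parent keys against related keys even when the two sides use different wrapper types.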
@@ -700,6 +1242,15 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
 		return err
 	}

+	// After main query, load custom preloads using separate queries
+	if len(b.customPreloads) > 0 {
+		logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
+		if err := b.loadCustomPreloads(ctx); err != nil {
+			logger.Error("Failed to load custom preloads: %v", err)
+			return err
+		}
+	}
+
 	return nil
 }

@@ -950,13 +1501,15 @@ func (b *BunResult) LastInsertId() (int64, error) {

 // BunTxAdapter wraps a Bun transaction to implement the Database interface
 type BunTxAdapter struct {
-	tx bun.Tx
+	tx         bun.Tx
+	driverName string
 }

 func (b *BunTxAdapter) NewSelect() common.SelectQuery {
 	return &BunSelectQuery{
-		query: b.tx.NewSelect(),
-		db:    b.tx,
+		query:      b.tx.NewSelect(),
+		db:         b.tx,
+		driverName: b.driverName,
 	}
 }

@@ -1000,3 +1553,7 @@ func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
 func (b *BunTxAdapter) GetUnderlyingDB() interface{} {
 	return b.tx
 }
+
+func (b *BunTxAdapter) DriverName() string {
+	return b.driverName
+}

@@ -15,12 +15,16 @@ import (

 // GormAdapter adapts GORM to work with our Database interface
 type GormAdapter struct {
-	db *gorm.DB
+	db         *gorm.DB
+	driverName string
 }

 // NewGormAdapter creates a new GORM adapter
 func NewGormAdapter(db *gorm.DB) *GormAdapter {
-	return &GormAdapter{db: db}
+	adapter := &GormAdapter{db: db}
+	// Initialize driver name
+	adapter.driverName = adapter.DriverName()
+	return adapter
 }

 // EnableQueryDebug enables query debugging which logs all SQL queries including preloads
@@ -40,7 +44,7 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
 }

 func (g *GormAdapter) NewSelect() common.SelectQuery {
-	return &GormSelectQuery{db: g.db}
+	return &GormSelectQuery{db: g.db, driverName: g.driverName}
 }

 func (g *GormAdapter) NewInsert() common.InsertQuery {
@@ -79,7 +83,7 @@ func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
 	if tx.Error != nil {
 		return nil, tx.Error
 	}
-	return &GormAdapter{db: tx}, nil
+	return &GormAdapter{db: tx, driverName: g.driverName}, nil
 }

 func (g *GormAdapter) CommitTx(ctx context.Context) error {
@@ -97,7 +101,7 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
 		}
 	}()
 	return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
-		adapter := &GormAdapter{db: tx}
+		adapter := &GormAdapter{db: tx, driverName: g.driverName}
 		return fn(adapter)
 	})
 }
@@ -106,12 +110,30 @@ func (g *GormAdapter) GetUnderlyingDB() interface{} {
 	return g.db
 }

+func (g *GormAdapter) DriverName() string {
+	if g.db.Dialector == nil {
+		return ""
+	}
+	// Normalize GORM's dialector name to match the project's canonical vocabulary.
+	// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
+	// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
+	switch name := g.db.Name(); name {
+	case "sqlserver":
+		return "mssql"
+	case "sqlite3":
+		return "sqlite"
+	default:
+		return name
+	}
+}
+
 // GormSelectQuery implements SelectQuery for GORM
 type GormSelectQuery struct {
 	db         *gorm.DB
 	schema     string // Separated schema name
 	tableName  string // Just the table name, without schema
 	tableAlias string
+	driverName     string // Database driver name (postgres, sqlite, mssql)
 	inJoinContext  bool   // Track if we're in a JOIN relation context
 	joinTableAlias string // Alias to use for JOIN conditions
 }
@@ -123,7 +145,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
 	if provider, ok := model.(common.TableNameProvider); ok {
 		fullTableName := provider.TableName()
 		// Check if the table name contains schema (e.g., "schema.table")
-		g.schema, g.tableName = parseTableName(fullTableName)
+		// For SQLite, this will convert "schema.table" to "schema_table"
+		g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
 	}

 	if provider, ok := model.(common.TableAliasProvider); ok {
@@ -136,7 +159,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
 func (g *GormSelectQuery) Table(table string) common.SelectQuery {
 	g.db = g.db.Table(table)
 	// Check if the table name contains schema (e.g., "schema.table")
-	g.schema, g.tableName = parseTableName(table)
+	// For SQLite, this will convert "schema.table" to "schema_table"
+	g.schema, g.tableName = parseTableName(table, g.driverName)

 	return g
 }
@@ -322,7 +346,8 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
 	}

 	wrapper := &GormSelectQuery{
-		db: db,
+		db:         db,
+		driverName: g.driverName,
 	}

 	current := common.SelectQuery(wrapper)
@@ -360,6 +385,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel

 	wrapper := &GormSelectQuery{
 		db:             db,
+		driverName:     g.driverName,
 		inJoinContext:  true,                      // Mark as JOIN context
 		joinTableAlias: strings.ToLower(relation), // Use relation name as alias
 	}

@@ -16,12 +16,19 @@ import (

 // PgSQLAdapter adapts standard database/sql to work with our Database interface
 // This provides a lightweight PostgreSQL adapter without ORM overhead
 type PgSQLAdapter struct {
-	db *sql.DB
+	db         *sql.DB
+	driverName string
 }

-// NewPgSQLAdapter creates a new PostgreSQL adapter
-func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter {
-	return &PgSQLAdapter{db: db}
+// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
+// An optional driverName (e.g. "postgres", "sqlite", "mssql") can be provided;
+// it defaults to "postgres" when omitted.
+func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
+	name := "postgres"
+	if len(driverName) > 0 && driverName[0] != "" {
+		name = driverName[0]
+	}
+	return &PgSQLAdapter{db: db, driverName: name}
 }

 // EnableQueryDebug enables query debugging for development
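The variadic `driverName ...string` parameter in the new `NewPgSQLAdapter` signature keeps the constructor backward-compatible: existing `NewPgSQLAdapter(db)` call sites still compile and default to `"postgres"`. The default-resolution logic can be isolated as a small sketch (`resolveDriverName` is a hypothetical helper, not part of the codebase):

```go
package main

import "fmt"

// resolveDriverName mirrors the defaulting logic inside NewPgSQLAdapter:
// no argument or an empty string yields "postgres"; a non-empty first
// argument wins. Extra arguments beyond the first are ignored.
func resolveDriverName(driverName ...string) string {
	name := "postgres"
	if len(driverName) > 0 && driverName[0] != "" {
		name = driverName[0]
	}
	return name
}

func main() {
	fmt.Println(resolveDriverName())         // no argument: default
	fmt.Println(resolveDriverName("sqlite")) // explicit driver
	fmt.Println(resolveDriverName(""))       // empty string: default
}
```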
@@ -31,22 +38,25 @@ func (p *PgSQLAdapter) EnableQueryDebug() {

 func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
 	return &PgSQLSelectQuery{
-		db:      p.db,
-		columns: []string{"*"},
-		args:    make([]interface{}, 0),
+		db:         p.db,
+		driverName: p.driverName,
+		columns:    []string{"*"},
+		args:       make([]interface{}, 0),
 	}
 }

 func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
 	return &PgSQLInsertQuery{
-		db:     p.db,
-		values: make(map[string]interface{}),
+		db:         p.db,
+		driverName: p.driverName,
+		values:     make(map[string]interface{}),
 	}
 }

 func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
 	return &PgSQLUpdateQuery{
 		db:           p.db,
+		driverName:   p.driverName,
 		sets:         make(map[string]interface{}),
 		args:         make([]interface{}, 0),
 		whereClauses: make([]string, 0),
@@ -56,6 +66,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
 func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
 	return &PgSQLDeleteQuery{
 		db:           p.db,
+		driverName:   p.driverName,
 		args:         make([]interface{}, 0),
 		whereClauses: make([]string, 0),
 	}
@@ -98,7 +109,7 @@ func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PgSQLTxAdapter{tx: tx}, nil
|
||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -121,7 +132,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
return err
|
||||
}
|
||||
|
||||
adapter := &PgSQLTxAdapter{tx: tx}
|
||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
@@ -141,6 +152,10 @@ func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
 	return p.db
 }

+func (p *PgSQLAdapter) DriverName() string {
+	return p.driverName
+}
+
 // preloadConfig represents a relationship to be preloaded
 type preloadConfig struct {
 	relation string
@@ -165,6 +180,7 @@ type PgSQLSelectQuery struct {
 	model        interface{}
 	tableName    string
 	tableAlias   string
+	driverName   string // Database driver name (postgres, sqlite, mssql)
 	columns      []string
 	columnExprs  []string
 	whereClauses []string
@@ -183,7 +199,9 @@ type PgSQLSelectQuery struct {
 func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
 	p.model = model
 	if provider, ok := model.(common.TableNameProvider); ok {
-		p.tableName = provider.TableName()
+		fullTableName := provider.TableName()
+		// For SQLite, convert "schema.table" to "schema_table"
+		_, p.tableName = parseTableName(fullTableName, p.driverName)
 	}
 	if provider, ok := model.(common.TableAliasProvider); ok {
 		p.tableAlias = provider.TableAlias()
@@ -192,7 +210,8 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
 }

 func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
-	p.tableName = table
+	// For SQLite, convert "schema.table" to "schema_table"
+	_, p.tableName = parseTableName(table, p.driverName)
 	return p
 }

@@ -501,16 +520,19 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)

 // PgSQLInsertQuery implements InsertQuery for PostgreSQL
 type PgSQLInsertQuery struct {
-	db        *sql.DB
-	tx        *sql.Tx
-	tableName string
-	values    map[string]interface{}
-	returning []string
+	db         *sql.DB
+	tx         *sql.Tx
+	tableName  string
+	driverName string
+	values     map[string]interface{}
+	returning  []string
 }

 func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
 	if provider, ok := model.(common.TableNameProvider); ok {
-		p.tableName = provider.TableName()
+		fullTableName := provider.TableName()
+		// For SQLite, convert "schema.table" to "schema_table"
+		_, p.tableName = parseTableName(fullTableName, p.driverName)
 	}
 	// Extract values from model using reflection
 	// This is a simplified implementation
@@ -518,7 +540,8 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
 }

 func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
-	p.tableName = table
+	// For SQLite, convert "schema.table" to "schema_table"
+	_, p.tableName = parseTableName(table, p.driverName)
 	return p
 }

@@ -591,6 +614,7 @@ type PgSQLUpdateQuery struct {
 	db           *sql.DB
 	tx           *sql.Tx
 	tableName    string
+	driverName   string
 	model        interface{}
 	sets         map[string]interface{}
 	whereClauses []string
@@ -602,13 +626,16 @@ type PgSQLUpdateQuery struct {
 func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
 	p.model = model
 	if provider, ok := model.(common.TableNameProvider); ok {
-		p.tableName = provider.TableName()
+		fullTableName := provider.TableName()
+		// For SQLite, convert "schema.table" to "schema_table"
+		_, p.tableName = parseTableName(fullTableName, p.driverName)
 	}
 	return p
 }

 func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
-	p.tableName = table
+	// For SQLite, convert "schema.table" to "schema_table"
+	_, p.tableName = parseTableName(table, p.driverName)
 	if p.model == nil {
 		model, err := modelregistry.GetModelByName(table)
 		if err == nil {
@@ -749,6 +776,7 @@ type PgSQLDeleteQuery struct {
 	db           *sql.DB
 	tx           *sql.Tx
 	tableName    string
+	driverName   string
 	whereClauses []string
 	args         []interface{}
 	paramCounter int
@@ -756,13 +784,16 @@ type PgSQLDeleteQuery struct {

 func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
 	if provider, ok := model.(common.TableNameProvider); ok {
-		p.tableName = provider.TableName()
+		fullTableName := provider.TableName()
+		// For SQLite, convert "schema.table" to "schema_table"
+		_, p.tableName = parseTableName(fullTableName, p.driverName)
 	}
 	return p
 }

 func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
-	p.tableName = table
+	// For SQLite, convert "schema.table" to "schema_table"
+	_, p.tableName = parseTableName(table, p.driverName)
 	return p
 }

@@ -835,27 +866,31 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {

 // PgSQLTxAdapter wraps a PostgreSQL transaction
 type PgSQLTxAdapter struct {
-	tx *sql.Tx
+	tx         *sql.Tx
+	driverName string
 }

 func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
 	return &PgSQLSelectQuery{
-		tx:      p.tx,
-		columns: []string{"*"},
-		args:    make([]interface{}, 0),
+		tx:         p.tx,
+		driverName: p.driverName,
+		columns:    []string{"*"},
+		args:       make([]interface{}, 0),
 	}
 }

 func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
 	return &PgSQLInsertQuery{
-		tx:     p.tx,
-		values: make(map[string]interface{}),
+		tx:         p.tx,
+		driverName: p.driverName,
+		values:     make(map[string]interface{}),
 	}
 }

 func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
 	return &PgSQLUpdateQuery{
 		tx:           p.tx,
+		driverName:   p.driverName,
 		sets:         make(map[string]interface{}),
 		args:         make([]interface{}, 0),
 		whereClauses: make([]string, 0),
@@ -865,6 +900,7 @@ func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
 func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
 	return &PgSQLDeleteQuery{
 		tx:           p.tx,
+		driverName:   p.driverName,
 		args:         make([]interface{}, 0),
 		whereClauses: make([]string, 0),
 	}
@@ -912,6 +948,10 @@ func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
 	return p.tx
 }

+func (p *PgSQLTxAdapter) DriverName() string {
+	return p.driverName
+}
+
 // applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
 func (p *PgSQLSelectQuery) applyJoinPreloads() {
 	for _, preload := range p.preloads {
@@ -1036,9 +1076,9 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
 	// Create a new select query for the related table
 	var db common.Database
 	if p.tx != nil {
-		db = &PgSQLTxAdapter{tx: p.tx}
+		db = &PgSQLTxAdapter{tx: p.tx, driverName: p.driverName}
 	} else {
-		db = &PgSQLAdapter{db: p.db}
+		db = &PgSQLAdapter{db: p.db, driverName: p.driverName}
 	}

 	query := db.NewSelect().
@@ -11,15 +11,71 @@ import (
 	"gorm.io/driver/sqlite"
 	"gorm.io/driver/sqlserver"
 	"gorm.io/gorm"

 	"github.com/bitechdev/ResolveSpec/pkg/logger"
 )

+// PostgreSQL identifier length limit (63 bytes + null terminator = 64 bytes total)
+const postgresIdentifierLimit = 63
+
+// checkAliasLength checks if a preload relation path will generate aliases that exceed PostgreSQL's limit.
+// Returns true if the alias is likely to be truncated.
+func checkAliasLength(relation string) bool {
+	// Bun generates aliases like: parentalias__childalias__columnname
+	// For nested preloads, it uses the pattern: relation1__relation2__relation3__columnname
+	parts := strings.Split(relation, ".")
+	if len(parts) <= 1 {
+		return false // Single-level relations are fine
+	}
+
+	// Calculate the actual alias prefix length that Bun will generate.
+	// Bun uses double underscores (__) between each relation level
+	// and converts the relation names to lowercase with underscores.
+	aliasPrefix := strings.ToLower(strings.Join(parts, "__"))
+	aliasPrefixLen := len(aliasPrefix)
+
+	// We need to add 2 more underscores for the column name separator, plus the column name length.
+	// Column names in the error were things like "rid_mastertype_hubtype" (23 chars).
+	// To be safe, assume the longest column name could be around 35 chars.
+	maxColumnNameLen := 35
+	estimatedMaxLen := aliasPrefixLen + 2 + maxColumnNameLen
+
+	// Check if this would exceed PostgreSQL's identifier limit
+	if estimatedMaxLen > postgresIdentifierLimit {
+		logger.Warn("Preload relation '%s' will generate aliases up to %d chars (prefix: %d + column: %d), exceeding PostgreSQL's %d char limit",
+			relation, estimatedMaxLen, aliasPrefixLen, maxColumnNameLen, postgresIdentifierLimit)
+		return true
+	}
+
+	// Also check if just the prefix is getting close (within 15 chars of the limit).
+	// This leaves room for column names.
+	if aliasPrefixLen > (postgresIdentifierLimit - 15) {
+		logger.Warn("Preload relation '%s' has alias prefix of %d chars, which may cause truncation with longer column names (limit: %d)",
+			relation, aliasPrefixLen, postgresIdentifierLimit)
+		return true
+	}
+
+	return false
+}
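The arithmetic behind that warning can be reproduced standalone. `estimateAliasLen` below is a hypothetical helper (not part of the package) mirroring the heuristic: lowercase the dotted relation path, join with `__`, then add a 2-character separator and a worst-case 35-character column name.

```go
package main

import (
	"fmt"
	"strings"
)

const postgresIdentifierLimit = 63

// estimateAliasLen estimates the longest alias a nested preload can produce:
// len(lowercased "a__b__c" prefix) + 2 separator chars + worst-case column name.
func estimateAliasLen(relation string, maxColumnNameLen int) int {
	parts := strings.Split(relation, ".")
	prefix := strings.ToLower(strings.Join(parts, "__"))
	return len(prefix) + 2 + maxColumnNameLen
}

func main() {
	est := estimateAliasLen("Department.Manager.Profile", 35)
	fmt.Println(est, est > postgresIdentifierLimit) // 65 true: would be truncated
}
```

With a three-level path the prefix alone is 28 characters, so the estimate of 65 already exceeds the 63-character limit, which is exactly the case the warning targets.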

 // parseTableName splits a table name that may contain schema into separate schema and table
 // For example: "public.users" -> ("public", "users")
 //
 // "users" -> ("", "users")
-func parseTableName(fullTableName string) (schema, table string) {
+//
+// For SQLite, schema.table is translated to schema_table since SQLite doesn't support schemas
+// in the same way as PostgreSQL/MSSQL
+func parseTableName(fullTableName, driverName string) (schema, table string) {
 	if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
-		return fullTableName[:idx], fullTableName[idx+1:]
+		schema = fullTableName[:idx]
+		table = fullTableName[idx+1:]
+
+		// For SQLite, convert schema.table to schema_table
+		if driverName == "sqlite" || driverName == "sqlite3" {
+			table = schema + "_" + table
+			schema = ""
+		}
+		return schema, table
 	}
 	return "", fullTableName
 }

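The new two-argument behaviour can be exercised in isolation. This sketch copies the helper as it appears in the diff and shows the per-driver results:

```go
package main

import (
	"fmt"
	"strings"
)

// parseTableName splits an optional "schema." prefix from a table name.
// For SQLite drivers, "schema.table" is folded into "schema_table" because
// SQLite has no PostgreSQL/MSSQL-style schemas.
func parseTableName(fullTableName, driverName string) (schema, table string) {
	if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
		schema = fullTableName[:idx]
		table = fullTableName[idx+1:]
		if driverName == "sqlite" || driverName == "sqlite3" {
			table = schema + "_" + table
			schema = ""
		}
		return schema, table
	}
	return "", fullTableName
}

func main() {
	s, t := parseTableName("auth.users", "postgres")
	fmt.Printf("%q %q\n", s, t) // "auth" "users"
	s, t = parseTableName("auth.users", "sqlite")
	fmt.Printf("%q %q\n", s, t) // "" "auth_users"
}
```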
@@ -30,6 +30,12 @@ type Database interface {
 	// For Bun, this returns *bun.DB
 	// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
 	GetUnderlyingDB() interface{}
+
+	// DriverName returns the canonical name of the underlying database driver.
+	// Possible values: "postgres", "sqlite", "mssql", "mysql".
+	// All adapters normalise vendor-specific strings (e.g. Bun's "pg", GORM's
+	// "sqlserver") to the values above before returning.
+	DriverName() string
 }

 // SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)

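One motivating use of the canonical `DriverName()` values is dialect-specific SQL, such as bind-placeholder syntax. A minimal sketch, assuming a caller switches on the documented values; the `placeholder` helper is illustrative, not part of the package:

```go
package main

import "fmt"

// placeholder returns the bind-parameter syntax for the n-th parameter,
// keyed on the canonical DriverName() values ("postgres", "sqlite",
// "mssql", "mysql").
func placeholder(driverName string, n int) string {
	switch driverName {
	case "postgres":
		return fmt.Sprintf("$%d", n) // $1, $2, ...
	case "mssql":
		return fmt.Sprintf("@p%d", n) // @p1, @p2, ...
	default: // sqlite, mysql
		return "?"
	}
}

func main() {
	fmt.Println(placeholder("postgres", 1)) // $1
	fmt.Println(placeholder("mssql", 2))    // @p2
	fmt.Println(placeholder("sqlite", 3))   // ?
}
```

Because every adapter normalises vendor strings first, such a switch never has to know whether the connection came from Bun, GORM, or the native adapter.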
@@ -50,6 +50,9 @@ func (m *mockDatabase) RollbackTx(ctx context.Context) error {
 func (m *mockDatabase) GetUnderlyingDB() interface{} {
 	return nil
 }
+func (m *mockDatabase) DriverName() string {
+	return "postgres"
+}

 // Mock SelectQuery
 type mockSelectQuery struct{}

@@ -11,6 +11,7 @@ A comprehensive database connection manager for Go that provides centralized man
 - **GORM** - Popular Go ORM
 - **Native** - Standard library `*sql.DB`
 - All three share the same underlying connection pool
+- **SQLite Schema Translation**: Automatic conversion of `schema.table` to `schema_table` for SQLite compatibility
 - **Configuration-Driven**: YAML configuration with Viper integration
 - **Production-Ready Features**:
   - Automatic health checks and reconnection
@@ -179,6 +180,35 @@ if err != nil {
 rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
 ```

+#### Cross-Database Example with SQLite
+
+```go
+// Same model works across all databases
+type User struct {
+	ID       int    `bun:"id,pk"`
+	Username string `bun:"username"`
+	Email    string `bun:"email"`
+}
+
+func (User) TableName() string {
+	return "auth.users"
+}
+
+// PostgreSQL connection
+pgConn, _ := mgr.Get("primary")
+pgDB, _ := pgConn.Bun()
+var pgUsers []User
+pgDB.NewSelect().Model(&pgUsers).Scan(ctx)
+// Executes: SELECT * FROM auth.users
+
+// SQLite connection
+sqliteConn, _ := mgr.Get("cache-db")
+sqliteDB, _ := sqliteConn.Bun()
+var sqliteUsers []User
+sqliteDB.NewSelect().Model(&sqliteUsers).Scan(ctx)
+// Executes: SELECT * FROM auth_users (schema.table → schema_table)
+```
+
 #### Use MongoDB

 ```go
@@ -368,6 +398,37 @@ Providers handle:
 - Connection statistics
 - Connection cleanup

+### SQLite Schema Handling
+
+SQLite doesn't support schemas in the same way as PostgreSQL or MSSQL. To ensure compatibility when using models designed for multi-schema databases:
+
+**Automatic Translation**: When a table name contains a schema prefix (e.g., `myschema.mytable`), it is automatically converted to `myschema_mytable` for SQLite databases.
+
+```go
+// Model definition (works across all databases)
+func (User) TableName() string {
+	return "auth.users" // PostgreSQL/MSSQL: "auth"."users"
+	                    // SQLite: "auth_users"
+}
+
+// Query execution
+db.NewSelect().Model(&User{}).Scan(ctx)
+// PostgreSQL/MSSQL: SELECT * FROM auth.users
+// SQLite:           SELECT * FROM auth_users
+```
+
+**How it Works**:
+- Bun, GORM, and Native adapters detect the driver type
+- `parseTableName()` automatically translates schema.table → schema_table for SQLite
+- Translation happens transparently in all database operations (SELECT, INSERT, UPDATE, DELETE)
+- Preload and relation queries are also handled automatically
+
+**Benefits**:
+- Write database-agnostic code
+- Use the same models across PostgreSQL, MSSQL, and SQLite
+- No conditional logic needed in your application
+- Schema separation maintained through naming convention in SQLite
+
 ## Best Practices

 1. **Use Named Connections**: Be explicit about which database you're accessing

@@ -467,13 +467,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
 	// Create a native adapter based on database type
 	switch c.dbType {
 	case DatabaseTypePostgreSQL:
-		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
+		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
 	case DatabaseTypeSQLite:
 		// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB
-		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
+		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
 	case DatabaseTypeMSSQL:
 		// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB
-		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
+		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
 	default:
 		return nil, ErrUnsupportedDatabase
 	}

@@ -231,12 +231,14 @@ func (m *connectionManager) Connect(ctx context.Context) error {

 // Close closes all database connections
 func (m *connectionManager) Close() error {
+	// Stop the health checker before taking mu. performHealthCheck acquires
+	// a read lock, so waiting for the goroutine while holding the write lock
+	// would deadlock.
+	m.stopHealthChecker()
+
 	m.mu.Lock()
 	defer m.mu.Unlock()

-	// Stop health checker
-	m.stopHealthChecker()
-
 	// Close all connections
 	var errors []error
 	for name, conn := range m.connections {
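The comment in that hunk describes a classic RWMutex shutdown deadlock: if `Close` holds the write lock while waiting for a goroutine that is blocked acquiring the read lock, neither side can proceed. A minimal, self-contained sketch of the corrected ordering; the `manager`/`healthLoop` names here are illustrative, not the package's actual API:

```go
package main

import (
	"fmt"
	"sync"
)

// manager sketches the Close ordering fix: the checker goroutine periodically
// takes mu.RLock, so Close must wait for it to exit BEFORE taking mu.Lock.
type manager struct {
	mu   sync.RWMutex
	stop chan struct{}
	done chan struct{}
}

func (m *manager) healthLoop() {
	defer close(m.done)
	for {
		select {
		case <-m.stop:
			return
		default:
			m.mu.RLock() // stands in for performHealthCheck's read lock
			m.mu.RUnlock()
		}
	}
}

func (m *manager) Close() {
	close(m.stop) // signal the checker first...
	<-m.done      // ...and wait for it to exit
	m.mu.Lock()   // only now is it safe to take the write lock
	defer m.mu.Unlock()
	fmt.Println("closed")
}

func main() {
	m := &manager{stop: make(chan struct{}), done: make(chan struct{})}
	go m.healthLoop()
	m.Close()
}
```

Reversing the two steps in `Close` (lock first, then wait) is exactly the pattern the commit removes.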
@@ -74,6 +74,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
 	return m
 }

+func (m *MockDatabase) DriverName() string {
+	return "postgres"
+}
+
 // MockResult implements common.Result interface for testing
 type MockResult struct {
 	rows int64

@@ -645,11 +645,14 @@ func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
 // Database operation helpers (adapted from websocketspec)

 func (h *Handler) getTableName(schema, entity string, model interface{}) string {
 	// Use entity as table name
 	tableName := entity

 	if schema != "" {
-		tableName = schema + "." + tableName
+		if h.db.DriverName() == "sqlite" {
+			tableName = schema + "_" + tableName
+		} else {
+			tableName = schema + "." + tableName
+		}
 	}
 	return tableName
 }

572	pkg/resolvespec/EXAMPLES.md	Normal file
@@ -0,0 +1,572 @@
# ResolveSpec Query Features Examples

This document provides examples of using the advanced query features in ResolveSpec, including OR logic filters, Custom Operators, and FetchRowNumber.

## OR Logic in Filters (SearchOr)

### Basic OR Filter Example

Find all users with status "active" OR "pending":

```json
POST /users
{
  "operation": "read",
  "options": {
    "filters": [
      {
        "column": "status",
        "operator": "eq",
        "value": "active"
      },
      {
        "column": "status",
        "operator": "eq",
        "value": "pending",
        "logic_operator": "OR"
      }
    ]
  }
}
```

### Combined AND/OR Filters

Find users with (status="active" OR status="pending") AND age >= 18:

```json
{
  "operation": "read",
  "options": {
    "filters": [
      {
        "column": "status",
        "operator": "eq",
        "value": "active"
      },
      {
        "column": "status",
        "operator": "eq",
        "value": "pending",
        "logic_operator": "OR"
      },
      {
        "column": "age",
        "operator": "gte",
        "value": 18
      }
    ]
  }
}
```

**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND age >= 18`

**Important Notes:**
- By default, filters use AND logic
- Consecutive filters with `"logic_operator": "OR"` are automatically grouped with parentheses
- This grouping ensures OR conditions don't interfere with AND conditions
- You don't need to specify `"logic_operator": "AND"` as it's the default

### Multiple OR Groups

You can have multiple separate OR groups:

```json
{
  "operation": "read",
  "options": {
    "filters": [
      {
        "column": "status",
        "operator": "eq",
        "value": "active"
      },
      {
        "column": "status",
        "operator": "eq",
        "value": "pending",
        "logic_operator": "OR"
      },
      {
        "column": "priority",
        "operator": "eq",
        "value": "high"
      },
      {
        "column": "priority",
        "operator": "eq",
        "value": "urgent",
        "logic_operator": "OR"
      }
    ]
  }
}
```

**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND (priority = 'high' OR priority = 'urgent')`

## Custom Operators

### Simple Custom SQL Condition

Filter by email domain using custom SQL:

```json
{
  "operation": "read",
  "options": {
    "customOperators": [
      {
        "name": "company_emails",
        "sql": "email LIKE '%@company.com'"
      }
    ]
  }
}
```

### Multiple Custom Operators

Combine multiple custom SQL conditions:

```json
{
  "operation": "read",
  "options": {
    "customOperators": [
      {
        "name": "recent_active",
        "sql": "last_login > NOW() - INTERVAL '30 days'"
      },
      {
        "name": "high_score",
        "sql": "score > 1000"
      }
    ]
  }
}
```

### Complex Custom Operator

Use complex SQL expressions:

```json
{
  "operation": "read",
  "options": {
    "customOperators": [
      {
        "name": "priority_users",
        "sql": "(subscription = 'premium' AND points > 500) OR (subscription = 'enterprise')"
      }
    ]
  }
}
```

### Combining Custom Operators with Regular Filters

Mix custom operators with standard filters:

```json
{
  "operation": "read",
  "options": {
    "filters": [
      {
        "column": "country",
        "operator": "eq",
        "value": "USA"
      }
    ],
    "customOperators": [
      {
        "name": "active_last_month",
        "sql": "last_activity > NOW() - INTERVAL '1 month'"
      }
    ]
  }
}
```

## Row Numbers

### Two Ways to Get Row Numbers

There are two different features for row numbers:

1. **`fetch_row_number`** - Get the position of ONE specific record in a sorted/filtered set
2. **`RowNumber` field in models** - Automatically number all records in the response

### 1. FetchRowNumber - Get Position of Specific Record

Get the rank/position of a specific user in a leaderboard. **Important:** When `fetch_row_number` is specified, the response contains **ONLY that specific record**, not all records.

```json
{
  "operation": "read",
  "options": {
    "sort": [
      {
        "column": "score",
        "direction": "desc"
      }
    ],
    "fetch_row_number": "12345"
  }
}
```

**Response - Contains ONLY the specified user:**
```json
{
  "success": true,
  "data": {
    "id": 12345,
    "name": "Alice Smith",
    "score": 9850,
    "level": 42
  },
  "metadata": {
    "total": 10000,
    "count": 1,
    "filtered": 10000,
    "row_number": 42
  }
}
```

**Result:** User "12345" is ranked #42 out of 10,000 users. The response includes only Alice's data, not the other 9,999 users.

### Row Number with Filters

Find position within a filtered subset (e.g., "What's my rank in my country?"):

```json
{
  "operation": "read",
  "options": {
    "filters": [
      {
        "column": "country",
        "operator": "eq",
        "value": "USA"
      },
      {
        "column": "status",
        "operator": "eq",
        "value": "active"
      }
    ],
    "sort": [
      {
        "column": "score",
        "direction": "desc"
      }
    ],
    "fetch_row_number": "12345"
  }
}
```

**Response:**
```json
{
  "success": true,
  "data": {
    "id": 12345,
    "name": "Bob Johnson",
    "country": "USA",
    "score": 7200,
    "status": "active"
  },
  "metadata": {
    "total": 2500,
    "count": 1,
    "filtered": 2500,
    "row_number": 156
  }
}
```

**Result:** Bob is ranked #156 out of 2,500 active USA users. Only Bob's record is returned.

### 2. RowNumber Field - Auto-Number All Records

If your model has a `RowNumber int64` field, restheadspec will automatically populate it for paginated results.

**Model Definition:**
```go
type Player struct {
	ID        int64  `json:"id"`
	Name      string `json:"name"`
	Score     int64  `json:"score"`
	RowNumber int64  `json:"row_number"` // Will be auto-populated
}
```

**Request (with pagination):**
```json
{
  "operation": "read",
  "options": {
    "sort": [{"column": "score", "direction": "desc"}],
    "limit": 10,
    "offset": 20
  }
}
```

**Response - RowNumber automatically set:**
```json
{
  "success": true,
  "data": [
    {
      "id": 456,
      "name": "Player21",
      "score": 8900,
      "row_number": 21
    },
    {
      "id": 789,
      "name": "Player22",
      "score": 8850,
      "row_number": 22
    },
    {
      "id": 123,
      "name": "Player23",
      "score": 8800,
      "row_number": 23
    }
    // ... records 24-30 ...
  ]
}
```

**How It Works:**
- `row_number = offset + index + 1` (1-based)
- With offset=20, the first record gets row_number=21
- With offset=20, the second record gets row_number=22
- Perfect for displaying "Rank" in paginated tables

**Use Case:** Displaying leaderboards with rank numbers:
```
Rank | Player   | Score
-----|----------|-------
21   | Player21 | 8900
22   | Player22 | 8850
23   | Player23 | 8800
```

**Note:** This feature is available in all three packages: resolvespec, restheadspec, and websocketspec.

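The `row_number = offset + index + 1` rule above can be sketched directly; `rowNumber` is a hypothetical helper name, the packages do this arithmetic internally when populating the field:

```go
package main

import "fmt"

// rowNumber computes the 1-based rank of the i-th record of a page that was
// fetched with the given offset.
func rowNumber(offset, index int) int {
	return offset + index + 1
}

func main() {
	offset := 20 // as in the request above
	for i, name := range []string{"Player21", "Player22", "Player23"} {
		fmt.Println(rowNumber(offset, i), name)
	}
	// 21 Player21
	// 22 Player22
	// 23 Player23
}
```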
### When to Use Each Feature
|
||||
|
||||
| Feature | Use Case | Returns | Performance |
|
||||
|---------|----------|---------|-------------|
|
||||
| `fetch_row_number` | "What's my rank?" | 1 record with position | Fast - 1 record |
|
||||
| `RowNumber` field | "Show top 10 with ranks" | Many records numbered | Fast - simple math |
|
||||
|
||||
**Combined Example - Full Leaderboard UI:**
|
||||
|
||||
```javascript
|
||||
// Request 1: Get current user's rank
|
||||
const userRank = await api.read({
|
||||
fetch_row_number: currentUserId,
|
||||
sort: [{column: "score", direction: "desc"}]
|
||||
});
|
||||
// Returns: {id: 123, name: "You", score: 7500, row_number: 156}
|
||||
|
||||
// Request 2: Get top 10 with rank numbers
|
||||
const top10 = await api.read({
|
||||
sort: [{column: "score", direction: "desc"}],
|
||||
limit: 10,
|
||||
offset: 0
|
||||
});
|
||||
// Returns: [{row_number: 1, ...}, {row_number: 2, ...}, ...]
|
||||
|
||||
// Display:
|
||||
// "Your Rank: #156"
|
||||
// "Top Players:"
|
||||
// "#1 - Alice - 9999"
|
||||
// "#2 - Bob - 9876"
|
||||
// ...
|
||||
```
|
||||
|
||||
## Complete Example: Advanced Query
|
||||
|
||||
Combine all features for a complex query:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"columns": ["id", "name", "email", "score", "status"],
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "trial",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "score",
|
||||
"operator": "gte",
|
||||
"value": 100
|
||||
}
|
||||
],
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "recent_activity",
|
||||
"sql": "last_login > NOW() - INTERVAL '7 days'"
|
||||
},
|
||||
{
|
||||
"name": "verified_email",
|
||||
"sql": "email_verified = true"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
},
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "asc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345",
|
||||
"limit": 50,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This query:
|
||||
- Selects specific columns
|
||||
- Filters for users with status "active" OR "trial"
|
||||
- AND score >= 100
|
||||
- Applies custom SQL conditions for recent activity and verified emails
|
||||
- Sorts by score (descending) then creation date (ascending)
|
||||
- Returns the row number of user "12345" in this filtered/sorted set
|
||||
- Returns 50 records starting from the first one
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Leaderboards - Get Current User's Rank
|
||||
|
||||
Get the current user's position and data (returns only their record):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "game_id",
|
||||
"operator": "eq",
|
||||
"value": "game123"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "current_user_id"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Tip:** For full leaderboards, make two requests:
|
||||
1. One with `fetch_row_number` to get user's rank
|
||||
2. One with `limit` and `offset` to get top players list
|
||||
|
||||
### 2. Multi-Status Search
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "pending"
|
||||
},
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "processing",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "shipped",
|
||||
"logic_operator": "OR"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Advanced Date Filtering
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "this_month",
|
||||
"sql": "created_at >= DATE_TRUNC('month', CURRENT_DATE)"
|
||||
},
|
||||
{
|
||||
"name": "business_hours",
|
||||
"sql": "EXTRACT(HOUR FROM created_at) BETWEEN 9 AND 17"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```

## Security Considerations

**Warning:** Custom operators allow raw SQL, which can be a security risk if not properly handled:

1. **Never** directly interpolate user input into custom operator SQL
2. Always validate and sanitize custom operator SQL on the backend
3. Consider using a whitelist of allowed custom operators
4. Use prepared statements or parameterized queries when possible
5. Implement proper authorization checks before executing queries

Example of safe custom operator handling in Go:

```go
// Whitelist of allowed custom operators
allowedOperators := map[string]string{
	"recent_week":  "created_at > NOW() - INTERVAL '7 days'",
	"active_users": "status = 'active' AND last_login > NOW() - INTERVAL '30 days'",
	"premium_only": "subscription_level = 'premium'",
}

// Validate custom operators from request.
// Index into the slice so the whitelisted SQL is written back to the request;
// assigning through the range variable would only modify a copy.
for i, op := range req.Options.CustomOperators {
	if sql, ok := allowedOperators[op.Name]; ok {
		req.Options.CustomOperators[i].SQL = sql // Use whitelisted SQL
	} else {
		return errors.New("custom operator not allowed: " + op.Name)
	}
}
```

@@ -214,6 +214,146 @@ Content-Type: application/json

```json
{
  "filters": [
    {
      "column": "status",
      "operator": "eq",
      "value": "active"
    },
    {
      "column": "status",
      "operator": "eq",
      "value": "pending",
      "logic_operator": "OR"
    },
    {
      "column": "age",
      "operator": "gte",
      "value": 18
    }
  ]
}
```

Produces: `WHERE (status = 'active' OR status = 'pending') AND age >= 18`

This grouping ensures OR conditions don't interfere with other AND conditions in the query.
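The grouping rule can be sketched in isolation: consecutive filters flagged `logic_operator: "OR"` are collected into one parenthesised group, and groups are joined with AND. The `Filter` type and `buildWhere` helper below are illustrative (not the library's API), and values are inlined for readability where a real implementation would use bound placeholders.

```go
package main

import (
	"fmt"
	"strings"
)

// Filter mirrors the filter options used in requests.
type Filter struct {
	Column   string
	Operator string // "eq", "gte", ...
	Value    any
	Logic    string // "" (AND, the default) or "OR"
}

// cond renders one filter as a SQL fragment.
func cond(f Filter) string {
	ops := map[string]string{"eq": "=", "neq": "!=", "gt": ">", "gte": ">=", "lt": "<", "lte": "<="}
	if s, ok := f.Value.(string); ok {
		return fmt.Sprintf("%s %s '%s'", f.Column, ops[f.Operator], s)
	}
	return fmt.Sprintf("%s %s %v", f.Column, ops[f.Operator], f.Value)
}

// buildWhere wraps each run of consecutive OR filters in parentheses
// and joins the resulting groups with AND.
func buildWhere(filters []Filter) string {
	var parts []string
	for i := 0; i < len(filters); {
		if i+1 < len(filters) && strings.EqualFold(filters[i+1].Logic, "OR") {
			group := []string{cond(filters[i])}
			j := i + 1
			for j < len(filters) && strings.EqualFold(filters[j].Logic, "OR") {
				group = append(group, cond(filters[j]))
				j++
			}
			parts = append(parts, "("+strings.Join(group, " OR ")+")")
			i = j
		} else {
			parts = append(parts, cond(filters[i]))
			i++
		}
	}
	return strings.Join(parts, " AND ")
}

func main() {
	fmt.Println(buildWhere([]Filter{
		{Column: "status", Operator: "eq", Value: "active"},
		{Column: "status", Operator: "eq", Value: "pending", Logic: "OR"},
		{Column: "age", Operator: "gte", Value: 18},
	}))
	// prints: (status = 'active' OR status = 'pending') AND age >= 18
}
```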

### Custom Operators

Add custom SQL conditions when needed:

```json
{
  "operation": "read",
  "options": {
    "customOperators": [
      {
        "name": "email_domain_filter",
        "sql": "LOWER(email) LIKE '%@example.com'"
      },
      {
        "name": "recent_records",
        "sql": "created_at > NOW() - INTERVAL '7 days'"
      }
    ]
  }
}
```

Custom operators are applied as additional WHERE conditions to your query.

### Fetch Row Number

Get the row number (position) of a specific record in the filtered and sorted result set. **When `fetch_row_number` is specified, only that specific record is returned** (not all records).

```json
{
  "operation": "read",
  "options": {
    "filters": [
      {
        "column": "status",
        "operator": "eq",
        "value": "active"
      }
    ],
    "sort": [
      {
        "column": "score",
        "direction": "desc"
      }
    ],
    "fetch_row_number": "12345"
  }
}
```

**Response - Returns ONLY the specified record with its position:**

```json
{
  "success": true,
  "data": {
    "id": 12345,
    "name": "John Doe",
    "score": 850,
    "status": "active"
  },
  "metadata": {
    "total": 1000,
    "count": 1,
    "filtered": 1000,
    "row_number": 42
  }
}
```

**Use Case:** Perfect for "Show me this user and their ranking" - you get just that one user with their position in the leaderboard.

**Note:** This is different from the `RowNumber` field feature, which automatically numbers all records in a paginated response based on offset. That feature uses simple math (`offset + index + 1`), while `fetch_row_number` uses SQL window functions to calculate the actual position in a sorted/filtered set. To use the `RowNumber` field feature, simply add a `RowNumber int64` field to your model - it will be automatically populated with the row position based on pagination.
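The offset-based `RowNumber` behaviour described in the note is plain arithmetic over a page of results. A minimal sketch, assuming a hypothetical `Player` model that opts in by declaring the field:

```go
package main

import "fmt"

// Player is a hypothetical model with the optional RowNumber field.
type Player struct {
	Name      string
	RowNumber int64
}

// numberPage fills RowNumber as offset + index + 1 (1-based),
// matching the pagination-based numbering described above.
func numberPage(players []Player, offset int) {
	for i := range players {
		players[i].RowNumber = int64(offset + i + 1)
	}
}

func main() {
	page := []Player{{Name: "a"}, {Name: "b"}, {Name: "c"}}
	numberPage(page, 20) // e.g. the second page when limit=20, offset=20
	fmt.Println(page[0].RowNumber, page[2].RowNumber) // prints: 21 23
}
```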

## Preloading

Load related entities with custom configuration:

```json
{
  "operation": "read",
  "options": {
    "columns": ["id", "name", "email"],
    "preload": [
      {
        "relation": "posts",
        "columns": ["id", "title", "created_at"],
        "filters": [
          {
            "column": "status",
            "operator": "eq",
            "value": "published"
          }
        ],
        "sort": [
          {
            "column": "created_at",
            "direction": "desc"
          }
        ],
        "limit": 5
      },
      {
        "relation": "profile",
        "columns": ["bio", "website"]
      }
    ]
  }
}
```

## Cursor Pagination

Efficient pagination for large datasets:

### First Request (No Cursor)

```json
@@ -427,7 +567,7 @@ Define virtual columns using SQL expressions:
	// Check permissions
	if !userHasPermission(ctx.Context, ctx.Entity) {
		return fmt.Errorf("unauthorized access to %s", ctx.Entity)
		return nil
	}

	// Modify query options
	if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
@@ -435,17 +575,24 @@ Add custom SQL conditions when needed:
	}

	return nil
		users[i].Email = maskEmail(users[i].Email)
	}
})

// Register an after-read hook (e.g., for data transformation)
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
})
	// Transform or filter results
	if users, ok := ctx.Result.([]User); ok {
		for i := range users {
			users[i].Email = maskEmail(users[i].Email)
		}
	}
	return nil
})

// Register a before-create hook (e.g., for validation)
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
	// Validate data
	if user, ok := ctx.Data.(*User); ok {
		if user.Email == "" {
			return fmt.Errorf("email is required")
		}
	// Add timestamps

143
pkg/resolvespec/filter_test.go
Normal file
@@ -0,0 +1,143 @@
package resolvespec

import (
	"strings"
	"testing"

	"github.com/bitechdev/ResolveSpec/pkg/common"
)

// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
	h := &Handler{}

	tests := []struct {
		name              string
		filter            common.FilterOption
		expectedCondition string
		expectedArgsCount int
	}{
		{
			name: "Equal operator",
			filter: common.FilterOption{
				Column:   "status",
				Operator: "eq",
				Value:    "active",
			},
			expectedCondition: "status = ?",
			expectedArgsCount: 1,
		},
		{
			name: "Greater than operator",
			filter: common.FilterOption{
				Column:   "age",
				Operator: "gt",
				Value:    18,
			},
			expectedCondition: "age > ?",
			expectedArgsCount: 1,
		},
		{
			name: "IN operator",
			filter: common.FilterOption{
				Column:   "status",
				Operator: "in",
				Value:    []string{"active", "pending"},
			},
			expectedCondition: "status IN (?)",
			expectedArgsCount: 1,
		},
		{
			name: "LIKE operator",
			filter: common.FilterOption{
				Column:   "email",
				Operator: "like",
				Value:    "%@example.com",
			},
			expectedCondition: "email LIKE ?",
			expectedArgsCount: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			condition, args := h.buildFilterCondition(tt.filter)

			if condition != tt.expectedCondition {
				t.Errorf("Expected condition '%s', got '%s'", tt.expectedCondition, condition)
			}

			if len(args) != tt.expectedArgsCount {
				t.Errorf("Expected %d args, got %d", tt.expectedArgsCount, len(args))
			}

			// Note: Skip value comparison for slices as they can't be compared with ==
			// The important part is that args are populated correctly
		})
	}
}

// TestORGrouping tests that consecutive OR filters are properly grouped
func TestORGrouping(t *testing.T) {
	// This is a conceptual test - in practice we'd need a mock SelectQuery
	// to verify the actual SQL grouping behavior
	t.Run("Consecutive OR filters should be grouped", func(t *testing.T) {
		filters := []common.FilterOption{
			{Column: "status", Operator: "eq", Value: "active"},
			{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
			{Column: "status", Operator: "eq", Value: "trial", LogicOperator: "OR"},
			{Column: "age", Operator: "gte", Value: 18},
		}

		// Expected behavior: (status='active' OR status='pending' OR status='trial') AND age>=18
		// The first three filters should be grouped together
		// The fourth filter should be separate with AND

		// Count OR groups
		orGroupCount := 0
		inORGroup := false

		for i := 1; i < len(filters); i++ {
			if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
				orGroupCount++
				inORGroup = true
			} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
				inORGroup = false
			}
		}

		// We should have detected one OR group
		if orGroupCount != 1 {
			t.Errorf("Expected 1 OR group, detected %d", orGroupCount)
		}
	})

	t.Run("Multiple OR groups should be handled correctly", func(t *testing.T) {
		filters := []common.FilterOption{
			{Column: "status", Operator: "eq", Value: "active"},
			{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
			{Column: "priority", Operator: "eq", Value: "high"},
			{Column: "priority", Operator: "eq", Value: "urgent", LogicOperator: "OR"},
		}

		// Expected: (status='active' OR status='pending') AND (priority='high' OR priority='urgent')
		// Should have two OR groups

		orGroupCount := 0
		inORGroup := false

		for i := 1; i < len(filters); i++ {
			if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
				orGroupCount++
				inORGroup = true
			} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
				inORGroup = false
			}
		}

		// We should have detected two OR groups
		if orGroupCount != 2 {
			t.Errorf("Expected 2 OR groups, detected %d", orGroupCount)
		}
	})
}

@@ -280,10 +280,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
		}
	}

	// Apply filters
	for _, filter := range options.Filters {
		logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
		query = h.applyFilter(query, filter)
	// Apply filters with proper grouping for OR logic
	query = h.applyFilters(query, options.Filters)

	// Apply custom operators
	for _, customOp := range options.CustomOperators {
		logger.Debug("Applying custom operator: %s - %s", customOp.Name, customOp.SQL)
		query = query.Where(customOp.SQL)
	}

	// Apply sorting
@@ -381,24 +384,105 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
		}
	}

	// Apply pagination
	if options.Limit != nil && *options.Limit > 0 {
		logger.Debug("Applying limit: %d", *options.Limit)
		query = query.Limit(*options.Limit)
	// Handle FetchRowNumber if requested
	var rowNumber *int64
	if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
		logger.Debug("Fetching row number for ID: %s", *options.FetchRowNumber)
		pkName := reflection.GetPrimaryKeyName(model)

		// Build ROW_NUMBER window function SQL
		rowNumberSQL := "ROW_NUMBER() OVER ("
		if len(options.Sort) > 0 {
			rowNumberSQL += "ORDER BY "
			for i, sort := range options.Sort {
				if i > 0 {
					rowNumberSQL += ", "
				}
				direction := "ASC"
				if strings.EqualFold(sort.Direction, "desc") {
					direction = "DESC"
				}
				rowNumberSQL += fmt.Sprintf("%s %s", sort.Column, direction)
			}
		}
		rowNumberSQL += ")"

		// Create a query to fetch the row number using a subquery approach
		// We'll select the PK and row_number, then filter by the target ID
		type RowNumResult struct {
			RowNum int64 `bun:"row_num"`
		}

		rowNumQuery := h.db.NewSelect().Table(tableName).
			ColumnExpr(fmt.Sprintf("%s AS row_num", rowNumberSQL)).
			Column(pkName)

		// Apply the same filters as the main query
		for _, filter := range options.Filters {
			rowNumQuery = h.applyFilter(rowNumQuery, filter)
		}

		// Apply custom operators
		for _, customOp := range options.CustomOperators {
			rowNumQuery = rowNumQuery.Where(customOp.SQL)
		}

		// Filter for the specific ID we want the row number for
		rowNumQuery = rowNumQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *options.FetchRowNumber)

		// Execute query to get row number
		var result RowNumResult
		if err := rowNumQuery.Scan(ctx, &result); err != nil {
			if err == sql.ErrNoRows {
				// Build filter description for error message
				filterInfo := fmt.Sprintf("filters: %d", len(options.Filters))
				if len(options.CustomOperators) > 0 {
					customOps := make([]string, 0, len(options.CustomOperators))
					for _, op := range options.CustomOperators {
						customOps = append(customOps, op.SQL)
					}
					filterInfo += fmt.Sprintf(", custom operators: [%s]", strings.Join(customOps, "; "))
				}
				logger.Warn("No row found for primary key %s=%s with %s", pkName, *options.FetchRowNumber, filterInfo)
			} else {
				logger.Warn("Error fetching row number: %v", err)
			}
		} else {
			rowNumber = &result.RowNum
			logger.Debug("Found row number: %d", *rowNumber)
		}
	}
	if options.Offset != nil && *options.Offset > 0 {
		logger.Debug("Applying offset: %d", *options.Offset)
		query = query.Offset(*options.Offset)

	// Apply pagination (skip if FetchRowNumber is set - we want only that record)
	if options.FetchRowNumber == nil || *options.FetchRowNumber == "" {
		if options.Limit != nil && *options.Limit > 0 {
			logger.Debug("Applying limit: %d", *options.Limit)
			query = query.Limit(*options.Limit)
		}
		if options.Offset != nil && *options.Offset > 0 {
			logger.Debug("Applying offset: %d", *options.Offset)
			query = query.Offset(*options.Offset)
		}
	}

	// Execute query
	var result interface{}
	if id != "" {
		logger.Debug("Querying single record with ID: %s", id)
	if id != "" || (options.FetchRowNumber != nil && *options.FetchRowNumber != "") {
		// Single record query - either by URL ID or FetchRowNumber
		var targetID string
		if id != "" {
			targetID = id
			logger.Debug("Querying single record with URL ID: %s", id)
		} else {
			targetID = *options.FetchRowNumber
			logger.Debug("Querying single record with FetchRowNumber ID: %s", targetID)
		}

		// For single record, create a new pointer to the struct type
		singleResult := reflect.New(modelType).Interface()
		pkName := reflection.GetPrimaryKeyName(singleResult)

		query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id)
		query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
		if err := query.Scan(ctx, singleResult); err != nil {
			logger.Error("Error querying record: %v", err)
			h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)

@@ -418,20 +502,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st

	logger.Info("Successfully retrieved records")

	// Build metadata
	limit := 0
	if options.Limit != nil {
		limit = *options.Limit
	}
	offset := 0
	if options.Offset != nil {
		offset = *options.Offset
	count := int64(total)

	// When FetchRowNumber is used, we only return 1 record
	if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
		count = 1
		// Set the fetched row number on the record
		if rowNumber != nil {
			logger.Debug("FetchRowNumber: Setting row number %d on record", *rowNumber)
			h.setRowNumbersOnRecords(result, int(*rowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
		}
	} else {
		if options.Limit != nil {
			limit = *options.Limit
		}
		if options.Offset != nil {
			offset = *options.Offset
		}

		// Set row numbers on records if RowNumber field exists
		// Only for multiple records (not when fetching single record)
		h.setRowNumbersOnRecords(result, offset)
	}

	h.sendResponse(w, result, &common.Metadata{
		Total: int64(total),
		Filtered: int64(total),
		Limit: limit,
		Offset: offset,
		Total:     int64(total),
		Filtered:  int64(total),
		Count:     count,
		Limit:     limit,
		Offset:    offset,
		RowNumber: rowNumber,
	})
}

@@ -1303,29 +1406,161 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
	h.sendResponse(w, recordToDelete, nil)
}

func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
// applyFilters applies all filters with proper grouping for OR logic
// Groups consecutive OR filters together to ensure proper query precedence
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
	if len(filters) == 0 {
		return query
	}

	i := 0
	for i < len(filters) {
		// Check if this starts an OR group (current or next filter has OR logic)
		startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")

		if startORGroup {
			// Collect all consecutive filters that are OR'd together
			orGroup := []common.FilterOption{filters[i]}
			j := i + 1
			for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
				orGroup = append(orGroup, filters[j])
				j++
			}

			// Apply the OR group as a single grouped WHERE clause
			query = h.applyFilterGroup(query, orGroup)
			i = j
		} else {
			// Single filter with AND logic (or first filter)
			condition, args := h.buildFilterCondition(filters[i])
			if condition != "" {
				query = query.Where(condition, args...)
			}
			i++
		}
	}

	return query
}

// applyFilterGroup applies a group of filters that should be OR'd together
// Always wraps them in parentheses and applies as a single WHERE clause
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
	if len(filters) == 0 {
		return query
	}

	// Build all conditions and collect args
	var conditions []string
	var args []interface{}

	for _, filter := range filters {
		condition, filterArgs := h.buildFilterCondition(filter)
		if condition != "" {
			conditions = append(conditions, condition)
			args = append(args, filterArgs...)
		}
	}

	if len(conditions) == 0 {
		return query
	}

	// Single filter - no need for grouping
	if len(conditions) == 1 {
		return query.Where(conditions[0], args...)
	}

	// Multiple conditions - group with parentheses and OR
	groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
	return query.Where(groupedCondition, args...)
}

// buildFilterCondition builds a filter condition and returns it with args
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
	var condition string
	var args []interface{}

	switch filter.Operator {
	case "eq":
		return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s = ?", filter.Column)
		args = []interface{}{filter.Value}
	case "neq":
		return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s != ?", filter.Column)
		args = []interface{}{filter.Value}
	case "gt":
		return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s > ?", filter.Column)
		args = []interface{}{filter.Value}
	case "gte":
		return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s >= ?", filter.Column)
		args = []interface{}{filter.Value}
	case "lt":
		return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s < ?", filter.Column)
		args = []interface{}{filter.Value}
	case "lte":
		return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s <= ?", filter.Column)
		args = []interface{}{filter.Value}
	case "like":
		return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s LIKE ?", filter.Column)
		args = []interface{}{filter.Value}
	case "ilike":
		return query.Where(fmt.Sprintf("%s ILIKE ?", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
		args = []interface{}{filter.Value}
	case "in":
		return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value)
		condition = fmt.Sprintf("%s IN (?)", filter.Column)
		args = []interface{}{filter.Value}
	default:
		return "", nil
	}

	return condition, args
}

func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
	// Determine which method to use based on LogicOperator
	useOrLogic := strings.EqualFold(filter.LogicOperator, "OR")

	var condition string
	var args []interface{}

	switch filter.Operator {
	case "eq":
		condition = fmt.Sprintf("%s = ?", filter.Column)
		args = []interface{}{filter.Value}
	case "neq":
		condition = fmt.Sprintf("%s != ?", filter.Column)
		args = []interface{}{filter.Value}
	case "gt":
		condition = fmt.Sprintf("%s > ?", filter.Column)
		args = []interface{}{filter.Value}
	case "gte":
		condition = fmt.Sprintf("%s >= ?", filter.Column)
		args = []interface{}{filter.Value}
	case "lt":
		condition = fmt.Sprintf("%s < ?", filter.Column)
		args = []interface{}{filter.Value}
	case "lte":
		condition = fmt.Sprintf("%s <= ?", filter.Column)
		args = []interface{}{filter.Value}
	case "like":
		condition = fmt.Sprintf("%s LIKE ?", filter.Column)
		args = []interface{}{filter.Value}
	case "ilike":
		condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
		args = []interface{}{filter.Value}
	case "in":
		condition = fmt.Sprintf("%s IN (?)", filter.Column)
		args = []interface{}{filter.Value}
	default:
		return query
	}

	// Apply filter with appropriate logic operator
	if useOrLogic {
		return query.WhereOr(condition, args...)
	}
	return query.Where(condition, args...)
}

// parseTableName splits a table name that may contain schema into separate schema and table

@@ -1380,10 +1615,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
	return schema, entity
}

// getTableName returns the full table name including schema (schema.table)
// getTableName returns the full table name including schema.
// For most drivers the result is "schema.table". For SQLite, which does not
// support schema-qualified names, the schema and table are joined with an
// underscore: "schema_table".
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
	schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
	if schemaName != "" {
		if h.db.DriverName() == "sqlite" {
			return fmt.Sprintf("%s_%s", schemaName, tableName)
		}
		return fmt.Sprintf("%s.%s", schemaName, tableName)
	}
	return tableName
@@ -1703,6 +1944,51 @@ func toSnakeCase(s string) string {
	return strings.ToLower(result.String())
}

// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
// The row number is calculated as offset + index + 1 (1-based)
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
	// Get the reflect value of the records
	recordsValue := reflect.ValueOf(records)
	if recordsValue.Kind() == reflect.Ptr {
		recordsValue = recordsValue.Elem()
	}

	// Ensure it's a slice
	if recordsValue.Kind() != reflect.Slice {
		logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
		return
	}

	// Iterate through each record
	for i := 0; i < recordsValue.Len(); i++ {
		record := recordsValue.Index(i)

		// Dereference if it's a pointer
		if record.Kind() == reflect.Ptr {
			if record.IsNil() {
				continue
			}
			record = record.Elem()
		}

		// Ensure it's a struct
		if record.Kind() != reflect.Struct {
			continue
		}

		// Try to find and set the RowNumber field
		rowNumberField := record.FieldByName("RowNumber")
		if rowNumberField.IsValid() && rowNumberField.CanSet() {
			// Check if the field is of type int64
			if rowNumberField.Kind() == reflect.Int64 {
				rowNum := int64(offset + i + 1)
				rowNumberField.SetInt(rowNum)
				logger.Debug("Set RowNumber=%d for record index %d", rowNum, i)
			}
		}
	}
}

// HandleOpenAPI generates and returns the OpenAPI specification
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
	if h.openAPIGenerator == nil {

@@ -216,9 +216,30 @@ type BunRouterHandler interface {
|
||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||
}
|
||||
|
||||
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||
if authMiddleware == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = handler(w, req)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
wrappedHandler := authMiddleware(httpHandler)
|
||||
wrappedHandler.ServeHTTP(w, req.Request)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||
// Accepts bunrouter.Router or bunrouter.Group
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
@@ -256,7 +277,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
currentEntity := entity
|
||||
|
||||
// POST route without ID
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -267,10 +288,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||
|
||||
// POST route with ID
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -282,10 +304,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// GET route without ID
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -296,10 +319,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))

// GET route with ID
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -311,9 +335,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))

// OPTIONS route without ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewHTTPRequest(req.Request)
@@ -330,6 +356,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
})

// OPTIONS route with ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewHTTPRequest(req.Request)
@@ -355,8 +382,8 @@ func ExampleWithBunRouter(bunDB *bun.DB) {
// Create bunrouter
bunRouter := bunrouter.New()

// Setup ResolveSpec routes with bunrouter
SetupBunRouterRoutes(bunRouter, handler)
// Setup ResolveSpec routes with bunrouter without authentication
SetupBunRouterRoutes(bunRouter, handler, nil)

// Start server
// http.ListenAndServe(":8080", bunRouter)
@@ -377,8 +404,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
// Create bunrouter
bunRouter := bunrouter.New()

// Setup ResolveSpec routes
SetupBunRouterRoutes(bunRouter, handler)
// Setup ResolveSpec routes without authentication
SetupBunRouterRoutes(bunRouter, handler, nil)

// This gives you the full uptrace stack: bunrouter + Bun ORM
// http.ListenAndServe(":8080", bunRouter)
@@ -396,8 +423,87 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
apiGroup := bunRouter.NewGroup("/api")

// Setup ResolveSpec routes on the group - routes will be under /api
SetupBunRouterRoutes(apiGroup, handler)
SetupBunRouterRoutes(apiGroup, handler, nil)

// Start server
// http.ListenAndServe(":8080", bunRouter)
}

// ExampleWithGORMAndAuth shows how to use ResolveSpec with GORM and authentication
func ExampleWithGORMAndAuth(db *gorm.DB) {
// Create handler using GORM
_ = NewHandlerWithGORM(db)

// Create auth middleware
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
//     return security.NewAuthHandler(secList, h)
// }

// Setup router with authentication
_ = mux.NewRouter()
// SetupMuxRoutes(muxRouter, handler, authMiddleware)

// Register models
// handler.RegisterModel("public", "users", &User{})

// Start server
// http.ListenAndServe(":8080", muxRouter)
}

// ExampleWithBunAndAuth shows how to use ResolveSpec with Bun and authentication
func ExampleWithBunAndAuth(bunDB *bun.DB) {
// Create Bun adapter
dbAdapter := database.NewBunAdapter(bunDB)

// Create model registry
registry := modelregistry.NewModelRegistry()
// registry.RegisterModel("public.users", &User{})

// Create handler
_ = NewHandler(dbAdapter, registry)

// Create auth middleware
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
//     return security.NewAuthHandler(secList, h)
// }

// Setup routes with authentication
_ = mux.NewRouter()
// SetupMuxRoutes(muxRouter, handler, authMiddleware)

// Start server
// http.ListenAndServe(":8080", muxRouter)
}

// ExampleBunRouterWithBunDBAndAuth shows the full uptrace stack with authentication
func ExampleBunRouterWithBunDBAndAuth(bunDB *bun.DB) {
// Create Bun database adapter
dbAdapter := database.NewBunAdapter(bunDB)

// Create model registry
registry := modelregistry.NewModelRegistry()
// registry.RegisterModel("public.users", &User{})

// Create handler with Bun
_ = NewHandler(dbAdapter, registry)

// Create auth middleware
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
//     return security.NewAuthHandler(secList, h)
// }

// Create bunrouter
_ = bunrouter.New()

// Setup ResolveSpec routes with authentication
// SetupBunRouterRoutes(bunRouter, handler, authMiddleware)

// This gives you the full uptrace stack: bunrouter + Bun ORM with authentication
// http.ListenAndServe(":8080", bunRouter)
}
@@ -549,8 +549,30 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
}

// If ID is provided, filter by ID
if id != "" {
// Handle FetchRowNumber before applying ID filter
// This must happen before the query to get the row position, then filter by PK
var fetchedRowNumber *int64
var fetchRowNumberPKValue string
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
pkName := reflection.GetPrimaryKeyName(model)
fetchRowNumberPKValue = *options.FetchRowNumber

logger.Debug("FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)

rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, fetchRowNumberPKValue, options, model)
if err != nil {
logger.Error("Failed to fetch row number: %v", err)
h.sendError(w, http.StatusBadRequest, "fetch_rownumber_error", "Failed to fetch row number", err)
return
}

fetchedRowNumber = &rowNum
logger.Debug("FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)

// Now filter the main query to this specific primary key
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), fetchRowNumberPKValue)
} else if id != "" {
// If ID is provided (and not FetchRowNumber), filter by ID
pkName := reflection.GetPrimaryKeyName(model)
logger.Debug("Filtering by ID=%s: %s", pkName, id)

@@ -730,7 +752,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}

// Set row numbers on each record if the model has a RowNumber field
h.setRowNumbersOnRecords(modelPtr, offset)
// If FetchRowNumber was used, set the fetched row number instead of offset-based
if fetchedRowNumber != nil {
// FetchRowNumber: set the actual row position on the record
logger.Debug("FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
h.setRowNumbersOnRecords(modelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
} else {
h.setRowNumbersOnRecords(modelPtr, offset)
}

metadata := &common.Metadata{
Total: int64(total),
@@ -740,21 +769,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset,
}

// Fetch row number for a specific record if requested
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
pkName := reflection.GetPrimaryKeyName(model)
pkValue := *options.FetchRowNumber

logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)

rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
if err != nil {
logger.Warn("Failed to fetch row number: %v", err)
// Don't fail the entire request, just log the warning
} else {
metadata.RowNumber = &rowNum
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
}
// If FetchRowNumber was used, also set it in metadata
if fetchedRowNumber != nil {
metadata.RowNumber = fetchedRowNumber
logger.Debug("FetchRowNumber: Row number %d set in metadata", *fetchedRowNumber)
}

// Execute AfterRead hooks
@@ -2015,11 +2033,18 @@ func (h *Handler) processChildRelationsForField(
return nil
}

// getTableNameForRelatedModel gets the table name for a related model
// getTableNameForRelatedModel gets the table name for a related model.
// If the model's TableName() is schema-qualified (e.g. "public.users") the
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
if provider, ok := model.(common.TableNameProvider); ok {
tableName := provider.TableName()
if tableName != "" {
if schema, table := h.parseTableName(tableName); schema != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schema, table)
}
}
return tableName
}
}
@@ -2264,10 +2289,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return schema, entity
}

// getTableName returns the full table name including schema (schema.table)
// getTableName returns the full table name including schema.
// For most drivers the result is "schema.table". For SQLite, which does not
// support schema-qualified names, the schema and table are joined with an
// underscore: "schema_table".
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
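The driver-dependent naming rule in `getTableName` can be sketched standalone. This is a minimal sketch based only on the behavior documented in the comment above; the real method lives on `Handler` and asks the database adapter for its driver name, while `tableName` here is a hypothetical free function:

```go
package main

import "fmt"

// tableName mirrors the rule described above: most drivers get
// "schema.table", while SQLite (no schema support) gets "schema_table".
func tableName(driver, schema, table string) string {
	if schema == "" {
		return table
	}
	if driver == "sqlite" {
		return fmt.Sprintf("%s_%s", schema, table)
	}
	return fmt.Sprintf("%s.%s", schema, table)
}

func main() {
	fmt.Println(tableName("postgres", "public", "users")) // public.users
	fmt.Println(tableName("sqlite", "public", "users"))   // public_users
}
```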
@@ -2589,21 +2620,8 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
}

// Build WHERE clauses from filters
whereClauses := make([]string, 0)
for i := range options.Filters {
filter := &options.Filters[i]
whereClause := h.buildFilterSQL(filter, tableName)
if whereClause != "" {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
}
}

// Combine WHERE clauses
whereSQL := ""
if len(whereClauses) > 0 {
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
}
// Build WHERE clause from filters with proper OR grouping
whereSQL := h.buildWhereClauseWithORGrouping(options.Filters, tableName)

// Add custom SQL WHERE if provided
if options.CustomSQLWhere != "" {
@@ -2651,19 +2669,86 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
var result []struct {
RN int64 `bun:"rn"`
}
logger.Debug("[FetchRowNumber] BEFORE Query call - about to execute raw query")
err := h.db.Query(ctx, &result, queryStr, pkValue)
logger.Debug("[FetchRowNumber] AFTER Query call - query completed with %d results, err: %v", len(result), err)
if err != nil {
return 0, fmt.Errorf("failed to fetch row number: %w", err)
}

if len(result) == 0 {
return 0, fmt.Errorf("no row found for primary key %s", pkValue)
whereInfo := "none"
if whereSQL != "" {
whereInfo = whereSQL
}
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
}

return result[0].RN, nil
}

// buildFilterSQL converts a filter to SQL WHERE clause string
// buildWhereClauseWithORGrouping builds a WHERE clause from filters with proper OR grouping
// Groups consecutive OR filters together to ensure proper SQL precedence
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
func (h *Handler) buildWhereClauseWithORGrouping(filters []common.FilterOption, tableName string) string {
if len(filters) == 0 {
return ""
}

var groups []string
i := 0

for i < len(filters) {
// Check if this starts an OR group (next filter has OR logic)
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")

if startORGroup {
// Collect all consecutive filters that are OR'd together
orGroup := []string{}

// Add current filter
filterSQL := h.buildFilterSQL(&filters[i], tableName)
if filterSQL != "" {
orGroup = append(orGroup, filterSQL)
}

// Collect remaining OR filters
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
filterSQL := h.buildFilterSQL(&filters[j], tableName)
if filterSQL != "" {
orGroup = append(orGroup, filterSQL)
}
j++
}

// Group OR filters with parentheses
if len(orGroup) > 0 {
if len(orGroup) == 1 {
groups = append(groups, orGroup[0])
} else {
groups = append(groups, "("+strings.Join(orGroup, " OR ")+")")
}
}
i = j
} else {
// Single filter with AND logic (or first filter)
filterSQL := h.buildFilterSQL(&filters[i], tableName)
if filterSQL != "" {
groups = append(groups, filterSQL)
}
i++
}
}

if len(groups) == 0 {
return ""
}

return "WHERE " + strings.Join(groups, " AND ")
}
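The OR-grouping rule above can be exercised in isolation. The sketch below is a minimal stand-in: `Filter` with a pre-rendered `SQL` predicate replaces `common.FilterOption` plus `buildFilterSQL` (both simplifications are hypothetical), but the grouping loop has the same shape, so consecutive OR filters are parenthesized and groups are joined with AND:

```go
package main

import (
	"fmt"
	"strings"
)

// Filter is a simplified stand-in for common.FilterOption.
type Filter struct {
	SQL           string // pre-rendered predicate, e.g. "a = 1"
	LogicOperator string // "AND" (default) or "OR"
}

// groupWhere mirrors buildWhereClauseWithORGrouping: consecutive OR
// filters are grouped in parentheses, and groups are joined with AND.
func groupWhere(filters []Filter) string {
	var groups []string
	i := 0
	for i < len(filters) {
		if i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR") {
			// Collect the current filter plus all consecutive OR filters.
			or := []string{filters[i].SQL}
			j := i + 1
			for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
				or = append(or, filters[j].SQL)
				j++
			}
			groups = append(groups, "("+strings.Join(or, " OR ")+")")
			i = j
		} else {
			groups = append(groups, filters[i].SQL)
			i++
		}
	}
	if len(groups) == 0 {
		return ""
	}
	return "WHERE " + strings.Join(groups, " AND ")
}

func main() {
	fmt.Println(groupWhere([]Filter{
		{SQL: "a = 1"},
		{SQL: "b = 2", LogicOperator: "OR"},
		{SQL: "c = 3", LogicOperator: "OR"},
		{SQL: "d = 4", LogicOperator: "AND"},
	}))
	// WHERE (a = 1 OR b = 2 OR c = 3) AND d = 4
}
```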
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)

@@ -280,9 +280,30 @@ type BunRouterHandler interface {
Handle(method, path string, handler bunrouter.HandlerFunc)
}

// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
if authMiddleware == nil {
return handler
}

return func(w http.ResponseWriter, req bunrouter.Request) error {
// Create an http.Handler that calls the bunrouter handler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = handler(w, req)
})

// Wrap with auth middleware and execute
wrappedHandler := authMiddleware(httpHandler)
wrappedHandler.ServeHTTP(w, req.Request)

return nil
}
}

// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
// Accepts bunrouter.Router or bunrouter.Group
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
// authMiddleware is optional - if provided, routes will be protected with the middleware
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {

// CORS config
corsConfig := common.DefaultCORSConfig()
@@ -313,7 +334,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
currentEntity := entity

// GET and POST for /{schema}/{entity}
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -324,9 +345,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))

r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -337,10 +359,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))

// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -352,9 +375,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))

r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -366,9 +390,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))

r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
putEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -380,9 +405,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("PUT", entityWithIDPath, wrapBunRouterHandler(putEntityWithIDHandler, authMiddleware))

r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
patchEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -394,9 +420,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("PATCH", entityWithIDPath, wrapBunRouterHandler(patchEntityWithIDHandler, authMiddleware))

r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
deleteEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -408,10 +435,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("DELETE", entityWithIDPath, wrapBunRouterHandler(deleteEntityWithIDHandler, authMiddleware))

// Metadata endpoint
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
metadataHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
@@ -422,9 +450,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {

handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", metadataPath, wrapBunRouterHandler(metadataHandler, authMiddleware))

// OPTIONS route without ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
@@ -441,6 +471,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
})

// OPTIONS route with ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
@@ -466,8 +497,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
// Create bunrouter
bunRouter := bunrouter.New()

// Setup routes
SetupBunRouterRoutes(bunRouter, handler)
// Setup routes without authentication
SetupBunRouterRoutes(bunRouter, handler, nil)

// Start server
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
@@ -487,7 +518,7 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
apiGroup := bunRouter.NewGroup("/api")

// Setup RestHeadSpec routes on the group - routes will be under /api
SetupBunRouterRoutes(apiGroup, handler)
SetupBunRouterRoutes(apiGroup, handler, nil)

// Start server
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
527
pkg/security/OAUTH2.md
Normal file
@@ -0,0 +1,527 @@
|
||||
# OAuth2 Authentication Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The security package provides OAuth2 authentication support for any OAuth2-compliant provider including Google, GitHub, Microsoft, Facebook, and custom providers.
|
||||
|
||||
## Features
|
||||
|
||||
- **Universal OAuth2 Support**: Works with any OAuth2 provider
|
||||
- **Pre-configured Providers**: Google, GitHub, Microsoft, Facebook
|
||||
- **Multi-Provider Support**: Use all OAuth2 providers simultaneously
|
||||
- **Custom Providers**: Easy configuration for any OAuth2 service
|
||||
- **Session Management**: Database-backed session storage
|
||||
- **Token Refresh**: Automatic token refresh support
|
||||
- **State Validation**: Built-in CSRF protection
|
||||
- **User Auto-Creation**: Automatically creates users on first login
|
||||
- **Unified Authentication**: OAuth2 and traditional auth share same session storage
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Database Setup
|
||||
|
||||
```sql
|
||||
-- Run the schema from database_schema.sql
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password VARCHAR(255),
|
||||
user_level INTEGER DEFAULT 0,
|
||||
roles VARCHAR(500),
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
remote_id VARCHAR(255),
|
||||
auth_provider VARCHAR(50)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
token_type VARCHAR(50) DEFAULT 'Bearer',
|
||||
auth_provider VARCHAR(50)
|
||||
);
|
||||
|
||||
-- OAuth2 stored procedures (7 functions)
|
||||
-- See database_schema.sql for full implementation
|
||||
```
|
||||
|
||||
### 2. Google OAuth2
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
// Create authenticator
|
||||
oauth2Auth := security.NewGoogleAuthenticator(
|
||||
"your-google-client-id",
|
||||
"your-google-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Login route - redirects to Google
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL(state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Callback route - handles Google response
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
```
|
||||
|
||||
### 3. GitHub OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewGitHubAuthenticator(
|
||||
"your-github-client-id",
|
||||
"your-github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Same routes pattern as Google
|
||||
router.HandleFunc("/auth/github/login", ...)
|
||||
router.HandleFunc("/auth/github/callback", ...)
|
||||
```
|
||||
|
||||
### 4. Microsoft OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewMicrosoftAuthenticator(
|
||||
"your-microsoft-client-id",
|
||||
"your-microsoft-client-secret",
|
||||
"http://localhost:8080/auth/microsoft/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Facebook OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewFacebookAuthenticator(
|
||||
"your-facebook-client-id",
|
||||
"your-facebook-client-secret",
|
||||
"http://localhost:8080/auth/facebook/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
## Custom OAuth2 Provider
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "your-client-id",
|
||||
ClientSecret: "your-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://your-provider.com/oauth/authorize",
|
||||
TokenURL: "https://your-provider.com/oauth/token",
|
||||
UserInfoURL: "https://your-provider.com/oauth/userinfo",
|
||||
DB: db,
|
||||
ProviderName: "custom",
|
||||
|
||||
// Optional: Custom user info parser
|
||||
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
|
||||
return &security.UserContext{
|
||||
UserName: userInfo["username"].(string),
|
||||
Email: userInfo["email"].(string),
|
||||
RemoteID: userInfo["id"].(string),
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
Claims: userInfo,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Protected Routes
|
||||
|
||||
```go
|
||||
// Create security provider
|
||||
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := security.NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
|
||||
// Apply middleware to protected routes
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(security.NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
```
|
||||
|
||||
## Token Refresh
|
||||
|
||||
OAuth2 access tokens expire after a period of time. Use the refresh token to obtain a new access token without requiring the user to log in again.
|
||||
|
||||
```go
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google", "github", etc.
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Default to google if not specified
|
||||
if req.Provider == "" {
|
||||
req.Provider = "google"
|
||||
}
|
||||
|
||||
// Use OAuth2-specific refresh method
|
||||
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- The refresh token is returned in the `LoginResponse.RefreshToken` field after successful OAuth2 callback
|
||||
- Store the refresh token securely on the client side
|
||||
- Each provider must be configured with the appropriate scopes to receive a refresh token (e.g., `access_type=offline` for Google)
|
||||
- The `OAuth2RefreshToken` method requires the provider name to identify which OAuth2 provider to use for refreshing
|
||||
|
||||
## Logout
|
||||
|
||||
```go
|
||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
|
||||
oauth2Auth.Logout(r.Context(), security.LogoutRequest{
|
||||
Token: userCtx.SessionID,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
```
|
||||
|
||||
## Multi-Provider Setup

```go
// Single DatabaseAuthenticator with ALL OAuth2 providers
auth := security.NewDatabaseAuthenticator(db).
	WithOAuth2(security.OAuth2Config{
		ClientID:     "google-client-id",
		ClientSecret: "google-client-secret",
		RedirectURL:  "http://localhost:8080/auth/google/callback",
		Scopes:       []string{"openid", "profile", "email"},
		AuthURL:      "https://accounts.google.com/o/oauth2/auth",
		TokenURL:     "https://oauth2.googleapis.com/token",
		UserInfoURL:  "https://www.googleapis.com/oauth2/v2/userinfo",
		ProviderName: "google",
	}).
	WithOAuth2(security.OAuth2Config{
		ClientID:     "github-client-id",
		ClientSecret: "github-client-secret",
		RedirectURL:  "http://localhost:8080/auth/github/callback",
		Scopes:       []string{"user:email"},
		AuthURL:      "https://github.com/login/oauth/authorize",
		TokenURL:     "https://github.com/login/oauth/access_token",
		UserInfoURL:  "https://api.github.com/user",
		ProviderName: "github",
	})

// Get list of configured providers
providers := auth.OAuth2GetProviders() // ["google", "github"]

// Google routes
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
	state, _ := auth.OAuth2GenerateState()
	authURL, _ := auth.OAuth2GetAuthURL("google", state)
	http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})

router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
	loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google",
		r.URL.Query().Get("code"), r.URL.Query().Get("state"))
	// ... handle response
})

// GitHub routes
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
	state, _ := auth.OAuth2GenerateState()
	authURL, _ := auth.OAuth2GetAuthURL("github", state)
	http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})

router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
	loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github",
		r.URL.Query().Get("code"), r.URL.Query().Get("state"))
	// ... handle response
})

// Use the same authenticator for protected routes - works for ALL providers
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
```
## Configuration Options

### OAuth2Config Fields

| Field | Type | Description |
|-------|------|-------------|
| ClientID | string | OAuth2 client ID from provider |
| ClientSecret | string | OAuth2 client secret |
| RedirectURL | string | Callback URL registered with provider |
| Scopes | []string | OAuth2 scopes to request |
| AuthURL | string | Provider's authorization endpoint |
| TokenURL | string | Provider's token endpoint |
| UserInfoURL | string | Provider's user info endpoint |
| DB | *sql.DB | Database connection for sessions |
| UserInfoParser | func | Custom parser for user info (optional) |
| StateValidator | func | Custom state validator (optional) |
| ProviderName | string | Provider name for logging (optional) |
## User Info Parsing

The default parser extracts these standard fields:
- `sub` → RemoteID
- `email` → Email, UserName
- `name` → UserName
- `login` → UserName (GitHub)

Custom parser example:

```go
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
	// Extract custom fields
	ctx := &security.UserContext{
		UserName:  userInfo["preferred_username"].(string),
		Email:     userInfo["email"].(string),
		RemoteID:  userInfo["sub"].(string),
		UserLevel: 1,
		Roles:     []string{"user"},
		Claims:    userInfo, // Store all claims
	}

	// Add custom roles based on provider data
	if groups, ok := userInfo["groups"].([]any); ok {
		for _, g := range groups {
			ctx.Roles = append(ctx.Roles, g.(string))
		}
	}

	return ctx, nil
}
```
## Security Best Practices

1. **Always use HTTPS in production**
   ```go
   http.SetCookie(w, &http.Cookie{
       Secure:   true,                 // Only send over HTTPS
       HttpOnly: true,                 // Prevent XSS access
       SameSite: http.SameSiteLaxMode, // CSRF protection
   })
   ```

2. **Store secrets securely**
   ```go
   clientID := os.Getenv("GOOGLE_CLIENT_ID")
   clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
   ```

3. **Validate redirect URLs**
   - Only register trusted redirect URLs with OAuth2 providers
   - Never accept a redirect URL from request parameters

4. **State parameter**
   - Automatically generated with cryptographic randomness
   - One-time use and expires after 10 minutes
   - Prevents CSRF attacks

5. **Session expiration**
   - OAuth2 sessions automatically expire based on token expiry
   - Clean up expired sessions periodically:
   ```sql
   DELETE FROM user_sessions WHERE expires_at < NOW();
   ```
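The state tokens described in item 4 can be generated with `crypto/rand`; a minimal, self-contained sketch (an illustration only, not the package's `OAuth2GenerateState` implementation, whose internals are not shown here):

```go
package main

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
)

// generateState returns a URL-safe random token suitable for the
// OAuth2 state parameter (CSRF protection).
func generateState() (string, error) {
	b := make([]byte, 32) // 256 bits of entropy
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

func main() {
	s1, _ := generateState()
	s2, _ := generateState()
	fmt.Println(len(s1), s1 != s2) // prints: 43 true
}
```

The unpadded URL-safe base64 alphabet keeps the value safe to round-trip through query parameters.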
## Implementation Details

All database operations use stored procedures for consistency and security:
- `resolvespec_oauth_getorcreateuser` - Find or create OAuth2 user
- `resolvespec_oauth_createsession` - Create OAuth2 session
- `resolvespec_oauth_getsession` - Validate and retrieve session
- `resolvespec_oauth_deletesession` - Logout/delete session
- `resolvespec_oauth_getrefreshtoken` - Get session by refresh token
- `resolvespec_oauth_updaterefreshtoken` - Update tokens after refresh
- `resolvespec_oauth_getuser` - Get user data by ID
## Provider Setup Guides

### Google

1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create a new project or select an existing one
3. Enable the Google+ API
4. Create OAuth 2.0 credentials
5. Add authorized redirect URI: `http://localhost:8080/auth/google/callback`
6. Copy the Client ID and Client Secret

### GitHub

1. Go to [GitHub Developer Settings](https://github.com/settings/developers)
2. Click "New OAuth App"
3. Set Homepage URL: `http://localhost:8080`
4. Set Authorization callback URL: `http://localhost:8080/auth/github/callback`
5. Copy the Client ID and Client Secret

### Microsoft

1. Go to [Azure Portal](https://portal.azure.com/)
2. Register a new application in Azure AD
3. Add redirect URI: `http://localhost:8080/auth/microsoft/callback`
4. Create a client secret
5. Copy the Application (client) ID and secret value

### Facebook

1. Go to [Facebook Developers](https://developers.facebook.com/)
2. Create a new app
3. Add the Facebook Login product
4. Set Valid OAuth Redirect URIs: `http://localhost:8080/auth/facebook/callback`
5. Copy the App ID and App Secret
## Troubleshooting

### "redirect_uri_mismatch" error
- Ensure the redirect URL in code matches the provider configuration exactly
- Include protocol (http/https), domain, port, and path

### "invalid_client" error
- Verify the Client ID and Client Secret are correct
- Check whether the credentials are for the correct environment (dev/prod)

### "invalid_grant" error during token exchange
- State parameter validation failed
- The authorization code might have expired or already been used
- Check server time synchronization

### User not created after successful OAuth2 login
- Check database constraints (username/email must be unique)
- Verify the UserInfoParser extracts the required fields
- Check database logs for constraint violations
## Testing

```go
func TestOAuth2Flow(t *testing.T) {
	// Mock database
	db, _, _ := sqlmock.New()

	oauth2Auth := security.NewGoogleAuthenticator(
		"test-client-id",
		"test-client-secret",
		"http://localhost/callback",
		db,
	)

	// Test state generation
	state, err := oauth2Auth.OAuth2GenerateState()
	assert.NoError(t, err)
	assert.NotEmpty(t, state)

	// Test auth URL generation
	authURL, err := oauth2Auth.OAuth2GetAuthURL("google", state)
	assert.NoError(t, err)
	assert.Contains(t, authURL, "accounts.google.com")
	assert.Contains(t, authURL, state)
}
```
## API Reference

### DatabaseAuthenticator with OAuth2

| Method | Description |
|--------|-------------|
| WithOAuth2(cfg) | Adds OAuth2 provider (can be called multiple times, returns *DatabaseAuthenticator) |
| OAuth2GetAuthURL(provider, state) | Returns OAuth2 authorization URL for specified provider |
| OAuth2GenerateState() | Generates random state for CSRF protection |
| OAuth2HandleCallback(ctx, provider, code, state) | Exchanges code for token and creates session |
| OAuth2RefreshToken(ctx, refreshToken, provider) | Refreshes expired access token using refresh token |
| OAuth2GetProviders() | Returns list of configured OAuth2 provider names |
| Login(ctx, req) | Standard username/password login |
| Logout(ctx, req) | Invalidates session (works for both OAuth2 and regular sessions) |
| Authenticate(r) | Validates session token from request (works for both OAuth2 and regular sessions) |

### Pre-configured Constructors

- `NewGoogleAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewGitHubAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewFacebookAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewMultiProviderAuthenticator(db, configs)` - Multiple providers at once

All return `*DatabaseAuthenticator` with OAuth2 pre-configured.

For multiple providers, use `WithOAuth2()` multiple times or `NewMultiProviderAuthenticator()`.
## Examples

Complete working examples are available in `oauth2_examples.go`:
- Basic Google OAuth2
- GitHub OAuth2
- Custom provider
- Multi-provider setup
- Token refresh
- Logout flow
- Complete integration with security middleware
281
pkg/security/OAUTH2_REFRESH_QUICK_REFERENCE.md
Normal file
@@ -0,0 +1,281 @@
# OAuth2 Refresh Token - Quick Reference

## Quick Setup (3 Steps)

### 1. Initialize Authenticator
```go
auth := security.NewGoogleAuthenticator(
	"client-id",
	"client-secret",
	"http://localhost:8080/auth/google/callback",
	db,
)
```

### 2. OAuth2 Login Flow
```go
// Login - Redirect to Google
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
	state, _ := auth.OAuth2GenerateState()
	authURL, _ := auth.OAuth2GetAuthURL("google", state)
	http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})

// Callback - Store tokens
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
	loginResp, _ := auth.OAuth2HandleCallback(
		r.Context(),
		"google",
		r.URL.Query().Get("code"),
		r.URL.Query().Get("state"),
	)

	// Save refresh_token on the client
	// loginResp.RefreshToken - Store this securely!
	// loginResp.Token - Session token for API calls
})
```
### 3. Refresh Endpoint
```go
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
	var req struct {
		RefreshToken string `json:"refresh_token"`
	}
	json.NewDecoder(r.Body).Decode(&req)

	// Refresh token
	loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	json.NewEncoder(w).Encode(loginResp)
})
```

---
## Multi-Provider Example

```go
// Configure multiple providers
auth := security.NewDatabaseAuthenticator(db).
	WithOAuth2(security.OAuth2Config{
		ProviderName: "google",
		ClientID:     "google-client-id",
		ClientSecret: "google-secret",
		RedirectURL:  "http://localhost:8080/auth/google/callback",
		Scopes:       []string{"openid", "profile", "email"},
		AuthURL:      "https://accounts.google.com/o/oauth2/auth",
		TokenURL:     "https://oauth2.googleapis.com/token",
		UserInfoURL:  "https://www.googleapis.com/oauth2/v2/userinfo",
	}).
	WithOAuth2(security.OAuth2Config{
		ProviderName: "github",
		ClientID:     "github-client-id",
		ClientSecret: "github-secret",
		RedirectURL:  "http://localhost:8080/auth/github/callback",
		Scopes:       []string{"user:email"},
		AuthURL:      "https://github.com/login/oauth/authorize",
		TokenURL:     "https://github.com/login/oauth/access_token",
		UserInfoURL:  "https://api.github.com/user",
	})

// Refresh with provider selection
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
	var req struct {
		RefreshToken string `json:"refresh_token"`
		Provider     string `json:"provider"` // "google" or "github"
	}
	json.NewDecoder(r.Body).Decode(&req)

	loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	json.NewEncoder(w).Encode(loginResp)
})
```

---
## Client-Side JavaScript

```javascript
// Automatic token refresh on 401
async function apiCall(url) {
  let response = await fetch(url, {
    headers: {
      'Authorization': 'Bearer ' + localStorage.getItem('access_token')
    }
  });

  // Token expired - refresh it
  if (response.status === 401) {
    await refreshToken();

    // Retry request with new token
    response = await fetch(url, {
      headers: {
        'Authorization': 'Bearer ' + localStorage.getItem('access_token')
      }
    });
  }

  return response.json();
}

async function refreshToken() {
  const response = await fetch('/auth/refresh', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      refresh_token: localStorage.getItem('refresh_token'),
      provider: localStorage.getItem('provider')
    })
  });

  if (response.ok) {
    const data = await response.json();
    localStorage.setItem('access_token', data.token);
    localStorage.setItem('refresh_token', data.refresh_token);
  } else {
    // Refresh failed - redirect to login
    window.location.href = '/login';
  }
}
```

---
## API Methods

| Method | Parameters | Returns |
|--------|-----------|---------|
| `OAuth2RefreshToken` | `ctx, refreshToken, provider` | `*LoginResponse, error` |
| `OAuth2HandleCallback` | `ctx, provider, code, state` | `*LoginResponse, error` |
| `OAuth2GetAuthURL` | `provider, state` | `string, error` |
| `OAuth2GenerateState` | none | `string, error` |
| `OAuth2GetProviders` | none | `[]string` |

---
## LoginResponse Structure

```go
type LoginResponse struct {
	Token        string       // New session token for API calls
	RefreshToken string       // Refresh token (store securely)
	User         *UserContext // User information
	ExpiresIn    int64        // Seconds until token expires
}
```

---
## Database Stored Procedures

- `resolvespec_oauth_getrefreshtoken(refresh_token)` - Get session by refresh token
- `resolvespec_oauth_updaterefreshtoken(update_data)` - Update tokens after refresh
- `resolvespec_oauth_getuser(user_id)` - Get user data

All procedures return: `{p_success bool, p_error text, p_data jsonb}`
---
## Common Errors

| Error | Cause | Solution |
|-------|-------|----------|
| `invalid or expired refresh token` | Token revoked/expired | Re-authenticate user |
| `OAuth2 provider 'xxx' not found` | Provider not configured | Add with `WithOAuth2()` |
| `failed to refresh token with provider` | Provider rejected request | Check credentials, re-auth user |

---
## Security Checklist

- [ ] Use HTTPS for all OAuth2 endpoints
- [ ] Store refresh tokens securely (HttpOnly cookies or encrypted storage)
- [ ] Set cookie flags: `HttpOnly`, `Secure`, `SameSite=Strict`
- [ ] Implement rate limiting on the refresh endpoint
- [ ] Log refresh attempts for audit
- [ ] Rotate tokens on refresh
- [ ] Revoke old sessions after successful refresh
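The rate-limiting item can be sketched with a small in-memory token bucket using only the standard library (the `bucket` helper below is hypothetical; a production deployment might prefer a shared store or a library such as `golang.org/x/time/rate`):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// bucket is a simple token bucket: up to cap tokens, refilled at rate/sec.
type bucket struct {
	mu     sync.Mutex
	tokens float64
	cap    float64
	rate   float64 // tokens added per second
	last   time.Time
}

func newBucket(capacity, rate float64) *bucket {
	return &bucket{tokens: capacity, cap: capacity, rate: rate, last: time.Now()}
}

// allow consumes one token if available; otherwise the caller should
// reject the refresh attempt (e.g., with HTTP 429).
func (b *bucket) allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	now := time.Now()
	b.tokens += now.Sub(b.last).Seconds() * b.rate
	if b.tokens > b.cap {
		b.tokens = b.cap
	}
	b.last = now
	if b.tokens >= 1 {
		b.tokens--
		return true
	}
	return false
}

func main() {
	// Allow 3 refresh attempts, refilling one token per minute.
	limiter := newBucket(3, 1.0/60)
	for i := 0; i < 5; i++ {
		fmt.Println(limiter.allow())
	}
}
```

The first three calls succeed; once the bucket is empty, further attempts are rejected until tokens refill. In practice one bucket per user or per IP is kept in a map.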
---
## Testing

```bash
# 1. Login and get refresh token
curl http://localhost:8080/auth/google/login
# Follow OAuth2 flow, get refresh_token from callback response

# 2. Refresh token
curl -X POST http://localhost:8080/auth/refresh \
  -H "Content-Type: application/json" \
  -d '{"refresh_token":"ya29.xxx","provider":"google"}'

# 3. Use new token
curl http://localhost:8080/api/protected \
  -H "Authorization: Bearer sess_abc123..."
```

---
## Pre-configured Providers

```go
// Google
auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)

// GitHub
auth := security.NewGitHubAuthenticator(clientID, secret, redirectURL, db)

// Microsoft
auth := security.NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)

// Facebook
auth := security.NewFacebookAuthenticator(clientID, secret, redirectURL, db)

// All providers at once
auth := security.NewMultiProviderAuthenticator(db, map[string]security.OAuth2Config{
	"google": {...},
	"github": {...},
})
```

---
## Provider-Specific Notes

### Google
- Add `access_type=offline` to get a refresh token
- Add `prompt=consent` to force the consent screen
```go
authURL += "&access_type=offline&prompt=consent"
```

### GitHub
- Refresh tokens are not always provided
- May need to request the `offline_access` scope

### Microsoft
- Use the `offline_access` scope for a refresh token

### Facebook
- Tokens expire after 60 days by default
- Check app settings for the token expiration policy

---
## Complete Example

See `/pkg/security/oauth2_examples.go` line 250 for a full working example.

For detailed documentation see `/pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md`.
495
pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,495 @@
# OAuth2 Refresh Token Implementation

## Overview

OAuth2 refresh token functionality is **fully implemented** in the ResolveSpec security package. It allows refreshing expired access tokens without requiring users to re-authenticate.

## Implementation Status: ✅ COMPLETE

### Components Implemented

1. **✅ Database Schema** - Tables and stored procedures
2. **✅ Go Methods** - OAuth2RefreshToken implementation
3. **✅ Thread Safety** - Mutex protection for provider map
4. **✅ Examples** - Working code examples
5. **✅ Documentation** - Complete API reference

---
## 1. Database Schema

### Tables Modified

```sql
-- user_sessions table with OAuth2 token fields
CREATE TABLE IF NOT EXISTS user_sessions (
    id SERIAL PRIMARY KEY,
    session_token VARCHAR(500) NOT NULL UNIQUE,
    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    expires_at TIMESTAMP NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    ip_address VARCHAR(45),
    user_agent TEXT,
    access_token TEXT,          -- OAuth2 access token
    refresh_token TEXT,         -- OAuth2 refresh token
    token_type VARCHAR(50),     -- "Bearer", etc.
    auth_provider VARCHAR(50)   -- "google", "github", etc.
);
```

### Stored Procedures

**`resolvespec_oauth_getrefreshtoken(p_refresh_token)`**
- Gets OAuth2 session data by refresh token
- Returns: `{user_id, access_token, token_type, expiry}`
- Location: `database_schema.sql:714`

**`resolvespec_oauth_updaterefreshtoken(p_update_data)`**
- Updates session with new tokens after refresh
- Input: `{user_id, old_refresh_token, new_session_token, new_access_token, new_refresh_token, expires_at}`
- Location: `database_schema.sql:752`

**`resolvespec_oauth_getuser(p_user_id)`**
- Gets user data by ID for building UserContext
- Location: `database_schema.sql:791`
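Since `resolvespec_oauth_getrefreshtoken` looks sessions up by `refresh_token`, that column benefits from an index; a possible addition (an assumption for illustration — no such index is shown in the schema above):

```sql
-- Hypothetical: speeds up session lookup by refresh token
CREATE INDEX IF NOT EXISTS idx_user_sessions_refresh_token
    ON user_sessions (refresh_token);
```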
---
## 2. Go Implementation

### Method Signature

```go
func (a *DatabaseAuthenticator) OAuth2RefreshToken(
	ctx context.Context,
	refreshToken string,
	providerName string,
) (*LoginResponse, error)
```

**Location:** `pkg/security/oauth2_methods.go:375`
### Implementation Flow

```
1. Validate provider exists
   ├─ getOAuth2Provider(providerName) with RLock
   └─ Return error if provider not configured

2. Get session from database
   ├─ Call resolvespec_oauth_getrefreshtoken(refreshToken)
   └─ Parse session data {user_id, access_token, token_type, expiry}

3. Refresh token with OAuth2 provider
   ├─ Create oauth2.Token from stored data
   ├─ Use provider.config.TokenSource(ctx, oldToken)
   └─ Call tokenSource.Token() to get new token

4. Generate new session token
   └─ Use OAuth2GenerateState() for secure random token

5. Update database
   ├─ Call resolvespec_oauth_updaterefreshtoken()
   └─ Store new session_token, access_token, refresh_token

6. Get user data
   ├─ Call resolvespec_oauth_getuser(user_id)
   └─ Build UserContext

7. Return LoginResponse
   └─ {Token, RefreshToken, User, ExpiresIn}
```
### Thread Safety

**Mutex Protection:** All access to the `oauth2Providers` map is protected with a `sync.RWMutex`.

```go
type DatabaseAuthenticator struct {
	oauth2Providers      map[string]*OAuth2Provider
	oauth2ProvidersMutex sync.RWMutex // Thread-safe access
}

// Read operations use RLock
func (a *DatabaseAuthenticator) getOAuth2Provider(name string) {
	a.oauth2ProvidersMutex.RLock()
	defer a.oauth2ProvidersMutex.RUnlock()
	// ... access map
}

// Write operations use Lock
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) {
	a.oauth2ProvidersMutex.Lock()
	defer a.oauth2ProvidersMutex.Unlock()
	// ... modify map
}
```

---
## 3. Usage Examples

### Single Provider (Google)

```go
package main

import (
	"database/sql"
	"encoding/json"
	"net/http"

	"github.com/bitechdev/ResolveSpec/pkg/security"
	"github.com/gorilla/mux"
)

func main() {
	db, _ := sql.Open("postgres", "connection-string")

	// Create Google OAuth2 authenticator
	auth := security.NewGoogleAuthenticator(
		"your-client-id",
		"your-client-secret",
		"http://localhost:8080/auth/google/callback",
		db,
	)

	router := mux.NewRouter()

	// Token refresh endpoint
	router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			RefreshToken string `json:"refresh_token"`
		}
		json.NewDecoder(r.Body).Decode(&req)

		// Refresh token (provider name defaults to "google")
		loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
		if err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}

		// Set new session cookie
		http.SetCookie(w, &http.Cookie{
			Name:     "session_token",
			Value:    loginResp.Token,
			Path:     "/",
			MaxAge:   int(loginResp.ExpiresIn),
			HttpOnly: true,
			Secure:   true,
		})

		json.NewEncoder(w).Encode(loginResp)
	})

	http.ListenAndServe(":8080", router)
}
```
### Multi-Provider Setup

```go
// Single authenticator with multiple OAuth2 providers
auth := security.NewDatabaseAuthenticator(db).
	WithOAuth2(security.OAuth2Config{
		ClientID:     "google-client-id",
		ClientSecret: "google-client-secret",
		RedirectURL:  "http://localhost:8080/auth/google/callback",
		Scopes:       []string{"openid", "profile", "email"},
		AuthURL:      "https://accounts.google.com/o/oauth2/auth",
		TokenURL:     "https://oauth2.googleapis.com/token",
		UserInfoURL:  "https://www.googleapis.com/oauth2/v2/userinfo",
		ProviderName: "google",
	}).
	WithOAuth2(security.OAuth2Config{
		ClientID:     "github-client-id",
		ClientSecret: "github-client-secret",
		RedirectURL:  "http://localhost:8080/auth/github/callback",
		Scopes:       []string{"user:email"},
		AuthURL:      "https://github.com/login/oauth/authorize",
		TokenURL:     "https://github.com/login/oauth/access_token",
		UserInfoURL:  "https://api.github.com/user",
		ProviderName: "github",
	})

// Refresh endpoint with provider selection
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
	var req struct {
		RefreshToken string `json:"refresh_token"`
		Provider     string `json:"provider"` // "google" or "github"
	}
	json.NewDecoder(r.Body).Decode(&req)

	// Refresh with specific provider
	loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	json.NewEncoder(w).Encode(loginResp)
})
```
### Client-Side Usage

```javascript
// JavaScript client example
async function refreshAccessToken() {
  const refreshToken = localStorage.getItem('refresh_token');
  const provider = localStorage.getItem('auth_provider'); // "google", "github", etc.

  const response = await fetch('/auth/refresh', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      refresh_token: refreshToken,
      provider: provider
    })
  });

  if (response.ok) {
    const data = await response.json();

    // Store new tokens
    localStorage.setItem('access_token', data.token);
    localStorage.setItem('refresh_token', data.refresh_token);

    console.log('Token refreshed successfully');
    return data.token;
  } else {
    // Refresh failed - redirect to login
    window.location.href = '/login';
  }
}

// Automatically refresh token when API returns 401
async function apiCall(endpoint) {
  let response = await fetch(endpoint, {
    headers: {
      'Authorization': 'Bearer ' + localStorage.getItem('access_token')
    }
  });

  if (response.status === 401) {
    // Token expired - try refresh
    const newToken = await refreshAccessToken();

    // Retry with new token
    response = await fetch(endpoint, {
      headers: {
        'Authorization': 'Bearer ' + newToken
      }
    });
  }

  return response.json();
}
```

---
## 4. API Reference

### DatabaseAuthenticator Methods

| Method | Signature | Description |
|--------|-----------|-------------|
| `OAuth2RefreshToken` | `(ctx, refreshToken, provider) (*LoginResponse, error)` | Refreshes expired OAuth2 access token |
| `WithOAuth2` | `(cfg OAuth2Config) *DatabaseAuthenticator` | Adds OAuth2 provider (chainable) |
| `OAuth2GetAuthURL` | `(provider, state) (string, error)` | Gets authorization URL |
| `OAuth2HandleCallback` | `(ctx, provider, code, state) (*LoginResponse, error)` | Handles OAuth2 callback |
| `OAuth2GenerateState` | `() (string, error)` | Generates CSRF state token |
| `OAuth2GetProviders` | `() []string` | Lists configured providers |

### LoginResponse Structure

```go
type LoginResponse struct {
	Token        string       // New session token
	RefreshToken string       // New refresh token (may be same as input)
	User         *UserContext // User information
	ExpiresIn    int64        // Seconds until expiration
}

type UserContext struct {
	UserID    int            // Database user ID
	UserName  string         // Username
	Email     string         // Email address
	UserLevel int            // Permission level
	SessionID string         // Session token
	RemoteID  string         // OAuth2 provider user ID
	Roles     []string       // User roles
	Claims    map[string]any // Additional claims
}
```

---
## 5. Important Notes

### Provider Configuration

**For Google:** Add `access_type=offline` to get a refresh token on first login:

```go
auth := security.NewGoogleAuthenticator(clientID, clientSecret, redirectURL, db)
// When generating the auth URL, add the access_type parameter
authURL, _ := auth.OAuth2GetAuthURL("google", state)
authURL += "&access_type=offline&prompt=consent"
```

**For GitHub:** Refresh tokens are not always provided. Check the provider documentation.

### Token Storage

- Store refresh tokens securely on the client (localStorage, secure cookie, etc.)
- Never log refresh tokens
- Refresh tokens are long-lived (days/months, depending on the provider)
- Access tokens are short-lived (minutes/hours)
### Error Handling

Common errors:
- `"invalid or expired refresh token"` - Token expired or revoked
- `"OAuth2 provider 'xxx' not found"` - Provider not configured
- `"failed to refresh token with provider"` - Provider rejected refresh request

### Security Best Practices

1. **Always use HTTPS** for token transmission
2. **Store refresh tokens securely** on the client
3. **Set appropriate cookie flags**: `HttpOnly`, `Secure`, `SameSite`
4. **Implement token rotation** - issue a new refresh token on each refresh
5. **Revoke old tokens** after a successful refresh
6. **Rate limit** refresh endpoints
7. **Log refresh attempts** for an audit trail

---
## 6. Testing

### Manual Test Flow

1. **Initial Login:**
   ```bash
   curl http://localhost:8080/auth/google/login
   # Follow redirect to Google
   # Returns to callback with LoginResponse containing refresh_token
   ```

2. **Wait for Token Expiry (or manually expire it in the DB)**

3. **Refresh Token:**
   ```bash
   curl -X POST http://localhost:8080/auth/refresh \
     -H "Content-Type: application/json" \
     -d '{
       "refresh_token": "ya29.a0AfH6SMB...",
       "provider": "google"
     }'

   # Response:
   {
     "token": "sess_abc123...",
     "refresh_token": "ya29.a0AfH6SMB...",
     "user": {
       "user_id": 1,
       "user_name": "john_doe",
       "email": "john@example.com",
       "session_id": "sess_abc123..."
     },
     "expires_in": 3600
   }
   ```

4. **Use New Token:**
   ```bash
   curl http://localhost:8080/api/protected \
     -H "Authorization: Bearer sess_abc123..."
   ```
### Database Verification
|
||||
|
||||
```sql
|
||||
-- Check session with refresh token
|
||||
SELECT session_token, user_id, expires_at, refresh_token, auth_provider
|
||||
FROM user_sessions
|
||||
WHERE refresh_token = 'ya29.a0AfH6SMB...';
|
||||
|
||||
-- Verify token was updated after refresh
|
||||
SELECT session_token, access_token, refresh_token,
|
||||
expires_at, last_activity_at
|
||||
FROM user_sessions
|
||||
WHERE user_id = 1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1;
|
||||
```
|
||||
|
||||
---

## 7. Troubleshooting

### "Refresh token not found or expired"

**Cause:** The refresh token doesn't exist in the database, or the session expired.

**Solution:**
- Check that the initial OAuth2 login stored the refresh token
- Verify the provider returns a refresh token (some require `access_type=offline`)
- Check that the session hasn't been deleted from the database

### "Failed to refresh token with provider"

**Cause:** The OAuth2 provider rejected the refresh request.

**Possible reasons:**
- Refresh token was revoked by the user
- OAuth2 app credentials changed
- Network connectivity issues
- Provider rate limiting

**Solution:**
- Re-authenticate the user (full OAuth2 flow)
- Check the provider dashboard for app status
- Verify the client credentials are correct

### "OAuth2 provider 'xxx' not found"

**Cause:** Provider not registered with `WithOAuth2()`

**Solution:**
```go
// Make sure the provider is configured
auth := security.NewDatabaseAuthenticator(db).
	WithOAuth2(security.OAuth2Config{
		ProviderName: "google", // This name must match the refresh call
		// ... other config
	})

// Then use the same name in refresh
auth.OAuth2RefreshToken(ctx, token, "google") // Must match ProviderName
```

---

## 8. Complete Working Example

See `pkg/security/oauth2_examples.go:250` for a full working example with token refresh.

---

## Summary

OAuth2 refresh token functionality is **production-ready** with:

- ✅ Complete database schema with stored procedures
- ✅ Thread-safe Go implementation with mutex protection
- ✅ Multi-provider support (Google, GitHub, Microsoft, Facebook, custom)
- ✅ Comprehensive error handling
- ✅ Working code examples
- ✅ Full API documentation
- ✅ Security best practices implemented

**No additional implementation needed - the feature is complete and functional.**
208	pkg/security/PASSKEY_QUICK_REFERENCE.md	Normal file
@@ -0,0 +1,208 @@
# Passkey Authentication Quick Reference

## Overview
Passkey authentication (WebAuthn/FIDO2) is now integrated into the DatabaseAuthenticator. This provides passwordless authentication using biometrics, security keys, or device credentials.

## Setup

### Database Schema
Run the passkey SQL schema (in database_schema.sql):
- Creates `user_passkey_credentials` table
- Adds stored procedures for passkey operations

### Go Code
```go
// Create passkey provider
passkeyProvider := security.NewDatabasePasskeyProvider(db,
	security.DatabasePasskeyProviderOptions{
		RPID:     "example.com",
		RPName:   "Example App",
		RPOrigin: "https://example.com",
		Timeout:  60000,
	})

// Create authenticator with passkey support
auth := security.NewDatabaseAuthenticatorWithOptions(db,
	security.DatabaseAuthenticatorOptions{
		PasskeyProvider: passkeyProvider,
	})

// Or add passkey to an existing authenticator
auth = security.NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
```

## Registration Flow

### Backend - Step 1: Begin Registration
```go
options, err := auth.BeginPasskeyRegistration(ctx,
	security.PasskeyBeginRegistrationRequest{
		UserID:      1,
		Username:    "alice",
		DisplayName: "Alice Smith",
	})
// Send options to client as JSON
```

### Frontend - Step 2: Create Credential
```javascript
// Convert options from server
options.challenge = base64ToArrayBuffer(options.challenge);
options.user.id = base64ToArrayBuffer(options.user.id);

// Create credential
const credential = await navigator.credentials.create({
  publicKey: options
});

// Send credential back to server
```

### Backend - Step 3: Complete Registration
```go
credential, err := auth.CompletePasskeyRegistration(ctx,
	security.PasskeyRegisterRequest{
		UserID:            1,
		Response:          clientResponse,
		ExpectedChallenge: storedChallenge,
		CredentialName:    "My iPhone",
	})
```

## Authentication Flow

### Backend - Step 1: Begin Authentication
```go
options, err := auth.BeginPasskeyAuthentication(ctx,
	security.PasskeyBeginAuthenticationRequest{
		Username: "alice", // Optional for resident key
	})
// Send options to client as JSON
```

### Frontend - Step 2: Get Credential
```javascript
// Convert options from server
options.challenge = base64ToArrayBuffer(options.challenge);

// Get credential
const credential = await navigator.credentials.get({
  publicKey: options
});

// Send assertion back to server
```

### Backend - Step 3: Complete Authentication
```go
loginResponse, err := auth.LoginWithPasskey(ctx,
	security.PasskeyLoginRequest{
		Response:          clientAssertion,
		ExpectedChallenge: storedChallenge,
		Claims: map[string]any{
			"ip_address": "192.168.1.1",
			"user_agent": "Mozilla/5.0...",
		},
	})
// Returns session token and user info
```

## Credential Management

### List Credentials
```go
credentials, err := auth.GetPasskeyCredentials(ctx, userID)
```

### Update Credential Name
```go
err := auth.UpdatePasskeyCredentialName(ctx, userID, credentialID, "New Name")
```

### Delete Credential
```go
err := auth.DeletePasskeyCredential(ctx, userID, credentialID)
```

## HTTP Endpoints Example

### POST /api/passkey/register/begin
Request: `{user_id, username, display_name}`
Response: PasskeyRegistrationOptions

### POST /api/passkey/register/complete
Request: `{user_id, response, credential_name}`
Response: PasskeyCredential

### POST /api/passkey/login/begin
Request: `{username}` (optional)
Response: PasskeyAuthenticationOptions

### POST /api/passkey/login/complete
Request: `{response}`
Response: LoginResponse with session token

### GET /api/passkey/credentials
Response: Array of PasskeyCredential

### DELETE /api/passkey/credentials/{id}
Request: `{credential_id}`
Response: 204 No Content

## Database Stored Procedures

- `resolvespec_passkey_store_credential` - Store new credential
- `resolvespec_passkey_get_credential` - Get credential by ID
- `resolvespec_passkey_get_user_credentials` - Get all user credentials
- `resolvespec_passkey_update_counter` - Update sign counter (clone detection)
- `resolvespec_passkey_delete_credential` - Delete credential
- `resolvespec_passkey_update_name` - Update credential name
- `resolvespec_passkey_get_credentials_by_username` - Get credentials for login

## Security Features

- **Clone Detection**: Sign counter validation detects credential cloning
- **Attestation Support**: Stores attestation type (none, indirect, direct)
- **Transport Options**: Tracks authenticator transports (usb, nfc, ble, internal)
- **Backup State**: Tracks if a credential is backed up/synced
- **User Verification**: Supports preferred/required user verification

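The clone-detection rule behind the sign counter is simple: a counter-bearing authenticator must report a strictly increasing value on each assertion, so a value that does not increase suggests a cloned credential. A standalone sketch of the check (the function name is illustrative, not the package's API):

```go
package main

import "fmt"

// counterLooksCloned reports whether a new WebAuthn signature counter
// indicates a possible cloned credential. Per the WebAuthn spec, a
// credential that implements counters must report strictly increasing
// values; both values being zero means counters are not supported.
func counterLooksCloned(stored, received uint32) bool {
	if received == 0 && stored == 0 {
		return false // authenticator does not implement a counter
	}
	return received <= stored
}

func main() {
	fmt.Println(counterLooksCloned(5, 6)) // normal increment -> false
	fmt.Println(counterLooksCloned(5, 5)) // replayed counter -> true
	fmt.Println(counterLooksCloned(5, 3)) // regression -> true
}
```

On a positive result, typical responses are rejecting the assertion, flagging the credential, or forcing re-registration.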
## Important Notes

1. **WebAuthn Library**: The current implementation is simplified. For production, use a proper WebAuthn library like `github.com/go-webauthn/webauthn` for full verification.

2. **Challenge Storage**: Store challenges securely in session/cache. Never expose challenges to the client beyond the initial request.

3. **HTTPS Required**: Passkeys only work over HTTPS (except localhost).

4. **Browser Support**: Check browser compatibility for the WebAuthn API.

5. **Relying Party ID**: Must match your domain exactly.

## Client-Side Helper Functions

```javascript
function base64ToArrayBuffer(base64) {
  const binary = atob(base64);
  const bytes = new Uint8Array(binary.length);
  for (let i = 0; i < binary.length; i++) {
    bytes[i] = binary.charCodeAt(i);
  }
  return bytes.buffer;
}

function arrayBufferToBase64(buffer) {
  const bytes = new Uint8Array(buffer);
  let binary = '';
  for (let i = 0; i < bytes.length; i++) {
    binary += String.fromCharCode(bytes[i]);
  }
  return btoa(binary);
}
```

## Testing

Run tests: `go test -v ./pkg/security -run Passkey`

All passkey functionality includes comprehensive tests using sqlmock.
@@ -7,15 +7,16 @@
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
// OR: auth := security.NewJWTAuthenticator("secret-key", db)
// OR: auth := security.NewHeaderAuthenticator()
// OR: auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db) // OAuth2

colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)

// Step 2: Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)

// Step 3: Setup and apply middleware
securityList := security.SetupSecurityProvider(handler, provider)
securityList, _ := security.SetupSecurityProvider(handler, provider)
router.Use(security.NewAuthMiddleware(securityList))
router.Use(security.SetSecurityMiddleware(securityList))
```
@@ -30,6 +31,7 @@ router.Use(security.SetSecurityMiddleware(securityList))
```go
// DatabaseAuthenticator uses these stored procedures:
resolvespec_login(jsonb)                // Login with credentials
resolvespec_register(jsonb)             // Register new user
resolvespec_logout(jsonb)               // Invalidate session
resolvespec_session(text, text)         // Validate session token
resolvespec_session_update(text, jsonb) // Update activity timestamp
@@ -502,10 +504,31 @@ func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema,

---

## Login/Logout Endpoints
## Login/Logout/Register Endpoints

```go
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
	// Register
	router.HandleFunc("/auth/register", func(w http.ResponseWriter, r *http.Request) {
		var req security.RegisterRequest
		json.NewDecoder(r.Body).Decode(&req)

		// Check if the provider supports registration
		registrable, ok := securityList.Provider().(security.Registrable)
		if !ok {
			http.Error(w, "Registration not supported", http.StatusNotImplemented)
			return
		}

		resp, err := registrable.Register(r.Context(), req)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		json.NewEncoder(w).Encode(resp)
	}).Methods("POST")

	// Login
	router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
		var req security.LoginRequest
@@ -707,6 +730,7 @@ meta, ok := security.GetUserMeta(ctx)

| File | Description |
|------|-------------|
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
| `OAUTH2.md` | **OAuth2 Guide** - Google, GitHub, Microsoft, Facebook, custom providers |
| `examples.go` | Working provider implementations to copy |
| `setup_example.go` | 6 complete integration examples |
| `README.md` | Architecture overview and migration guide |

@@ -6,6 +6,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic

- ✅ **Interface-Based** - Type-safe providers instead of callbacks
- ✅ **Login/Logout Support** - Built-in authentication lifecycle
- ✅ **Two-Factor Authentication (2FA)** - Optional TOTP support for enhanced security
- ✅ **Composable** - Mix and match different providers
- ✅ **No Global State** - Each handler has its own security configuration
- ✅ **Testable** - Easy to mock and test

@@ -212,6 +213,23 @@ auth := security.NewJWTAuthenticator("secret-key", db)
// Note: Requires JWT library installation for token signing/verification
```

**TwoFactorAuthenticator** - Wraps any authenticator with TOTP 2FA:
```go
baseAuth := security.NewDatabaseAuthenticator(db)

// Use the in-memory provider (for testing)
tfaProvider := security.NewMemoryTwoFactorProvider(nil)

// Or use the database provider (for production)
tfaProvider = security.NewDatabaseTwoFactorProvider(db, nil)
// Requires: users table with totp fields, user_totp_backup_codes table
// Requires: resolvespec_totp_* stored procedures (see totp_database_schema.sql)

auth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
// Supports: TOTP codes, backup codes, QR code generation
// Compatible with Google Authenticator, Microsoft Authenticator, Authy, etc.
```

### Column Security Providers

**DatabaseColumnSecurityProvider** - Loads rules from the database:
@@ -334,7 +352,182 @@ func handleRefresh(securityList *security.SecurityList) http.HandlerFunc {
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}
}
```

## Two-Factor Authentication (2FA)

### Overview

- **Optional per-user** - Enable/disable 2FA individually
- **TOTP standard** - Compatible with Google Authenticator, Microsoft Authenticator, Authy, 1Password, etc.
- **Configurable** - SHA1/SHA256/SHA512, 6/8 digits, custom time periods
- **Backup codes** - One-time recovery codes with secure hashing
- **Clock skew** - Handles time differences between client/server

### Setup

```go
// 1. Wrap an existing authenticator with 2FA support
baseAuth := security.NewDatabaseAuthenticator(db)
tfaProvider := security.NewMemoryTwoFactorProvider(nil) // Use a custom DB implementation in production
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)

// 2. Use as a normal authenticator
provider := security.NewCompositeSecurityProvider(tfaAuth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
```

### Enable 2FA for User

```go
// 1. Initiate 2FA setup
secret, err := tfaAuth.Setup2FA(userID, "MyApp", "user@example.com")
// Returns: secret.Secret, secret.QRCodeURL, secret.BackupCodes

// 2. User scans the QR code with an authenticator app
// Display secret.QRCodeURL as a QR code image

// 3. User enters the verification code from the app
code := "123456" // From authenticator app
err = tfaAuth.Enable2FA(userID, secret.Secret, code)
// 2FA is now enabled for this user

// 4. Store backup codes securely and show them to the user once
// Display: secret.BackupCodes (10 codes)
```

### Login Flow with 2FA

```go
// 1. User provides credentials
req := security.LoginRequest{
	Username: "user@example.com",
	Password: "password",
}

resp, err := tfaAuth.Login(ctx, req)

// 2. Check if 2FA is required
if resp.Requires2FA {
	// Prompt the user for a 2FA code
	code := getUserInput() // From authenticator app or backup code

	// 3. Login again with the 2FA code
	req.TwoFactorCode = code
	resp, err = tfaAuth.Login(ctx, req)

	// 4. Success - token is returned
	token := resp.Token
}
```

### Manage 2FA

```go
// Disable 2FA
err := tfaAuth.Disable2FA(userID)

// Regenerate backup codes
newCodes, err := tfaAuth.RegenerateBackupCodes(userID, 10)

// Check status
has2FA, err := tfaProvider.Get2FAStatus(userID)
```

### Custom 2FA Storage

**Option 1: Use DatabaseTwoFactorProvider (Recommended)**

```go
// Uses PostgreSQL stored procedures for all operations
db := setupDatabase()

// Run migrations from totp_database_schema.sql:
// - Add totp_secret, totp_enabled, totp_enabled_at to the users table
// - Create the user_totp_backup_codes table
// - Create the resolvespec_totp_* stored procedures

tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
```

**Option 2: Implement a Custom Provider**

Implement `TwoFactorAuthProvider` for custom storage:

```go
type DBTwoFactorProvider struct {
	db *gorm.DB
}

func (p *DBTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
	// Store the secret and hashed backup codes in the database
	return p.db.Exec("UPDATE users SET totp_secret = ?, backup_codes = ? WHERE id = ?",
		secret, hashCodes(backupCodes), userID).Error
}

func (p *DBTwoFactorProvider) Get2FASecret(userID int) (string, error) {
	var secret string
	err := p.db.Raw("SELECT totp_secret FROM users WHERE id = ?", userID).Scan(&secret).Error
	return secret, err
}

// Implement the remaining methods: Generate2FASecret, Validate2FACode, Disable2FA,
// Get2FAStatus, GenerateBackupCodes, ValidateBackupCode
```

### Configuration

```go
config := &security.TwoFactorConfig{
	Algorithm:  "SHA256", // SHA1, SHA256, SHA512
	Digits:     8,        // 6 or 8
	Period:     30,       // Seconds per code
	SkewWindow: 2,        // Accept codes ±2 periods
}

totp := security.NewTOTPGenerator(config)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, config)
```

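Under the hood, a TOTP code is just HOTP (RFC 4226) keyed by the current Unix time divided by `Period`. A stdlib-only sketch of the algorithm, as an illustration of the technique rather than the package's actual `NewTOTPGenerator` code; it reproduces the RFC 4226 test vectors:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/binary"
	"fmt"
)

// hotp computes an RFC 4226 one-time code for a counter value.
func hotp(secret []byte, counter uint64, digits int) string {
	var msg [8]byte
	binary.BigEndian.PutUint64(msg[:], counter)

	mac := hmac.New(sha1.New, secret)
	mac.Write(msg[:])
	sum := mac.Sum(nil)

	// Dynamic truncation (RFC 4226, section 5.3).
	offset := int(sum[len(sum)-1] & 0x0f)
	code := binary.BigEndian.Uint32(sum[offset:offset+4]) & 0x7fffffff

	mod := uint32(1)
	for i := 0; i < digits; i++ {
		mod *= 10
	}
	return fmt.Sprintf("%0*d", digits, code%mod)
}

// totp derives the HOTP counter from Unix time and the period.
func totp(secret []byte, unixTime, period int64, digits int) string {
	return hotp(secret, uint64(unixTime/period), digits)
}

func main() {
	secret := []byte("12345678901234567890") // RFC 4226 test secret

	fmt.Println(hotp(secret, 0, 6)) // RFC 4226 vector: 755224
	fmt.Println(hotp(secret, 1, 6)) // RFC 4226 vector: 287082

	// TOTP at T=59s with a 30s period uses counter 1, matching the vector above.
	fmt.Println(totp(secret, 59, 30, 6))
}
```

`SkewWindow: 2` in the config above means the verifier also accepts codes for counters within ±2 of the current one, absorbing client/server clock drift.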
### API Response Structure

```go
// LoginResponse with 2FA
type LoginResponse struct {
	Token              string           `json:"token"`
	Requires2FA        bool             `json:"requires_2fa"`
	TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"`
	User               *UserContext     `json:"user"`
}

// TwoFactorSecret for setup
type TwoFactorSecret struct {
	Secret      string   `json:"secret"`       // Base32 encoded
	QRCodeURL   string   `json:"qr_code_url"`  // otpauth://totp/...
	BackupCodes []string `json:"backup_codes"` // 10 recovery codes
}

// UserContext includes 2FA status
type UserContext struct {
	UserID           int  `json:"user_id"`
	TwoFactorEnabled bool `json:"two_factor_enabled"`
	// ... other fields
}
```

### Security Best Practices

- **Store secrets encrypted** - Never store TOTP secrets in plain text
- **Hash backup codes** - Use SHA-256 before storing
- **Rate limit** - Limit 2FA verification attempts
- **Require password** - Always verify the password before disabling 2FA
- **Show backup codes once** - Display only during setup/regeneration
- **Log 2FA events** - Track enable/disable/failed attempts
- **Mark codes as used** - Backup codes are single-use only

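The "hash backup codes" and "mark codes as used" items combine into a small verify-and-consume routine. A sketch using SHA-256 and a constant-time comparison; the `consumeBackupCode` helper and its map-based store are illustrative, not the package's implementation:

```go
package main

import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
)

// hashCode hashes a backup code so the plain text never hits storage.
func hashCode(code string) string {
	sum := sha256.Sum256([]byte(code))
	return hex.EncodeToString(sum[:])
}

// consumeBackupCode returns true and marks the stored hash as used when
// the presented code matches an unused backup code.
func consumeBackupCode(stored map[string]bool, code string) bool {
	h := hashCode(code)
	for storedHash, used := range stored {
		if used {
			continue
		}
		if subtle.ConstantTimeCompare([]byte(h), []byte(storedHash)) == 1 {
			stored[storedHash] = true // single-use: mark as consumed
			return true
		}
	}
	return false
}

func main() {
	stored := map[string]bool{hashCode("rescue-1234"): false}

	fmt.Println(consumeBackupCode(stored, "rescue-1234")) // valid, now consumed
	fmt.Println(consumeBackupCode(stored, "rescue-1234")) // already used
	fmt.Println(consumeBackupCode(stored, "wrong-code"))  // no match
}
```

In the database-backed provider this would be an `UPDATE ... SET used = true` guarded by the same transaction that verified the hash.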
		json.NewEncoder(w).Encode(resp)
	} else {
		http.Error(w, "Refresh not supported", http.StatusNotImplemented)

File diff suppressed because it is too large
@@ -7,33 +7,48 @@ import (

// UserContext holds authenticated user information
type UserContext struct {
	UserID     int            `json:"user_id"`
	UserName   string         `json:"user_name"`
	UserLevel  int            `json:"user_level"`
	SessionID  string         `json:"session_id"`
	SessionRID int64          `json:"session_rid"`
	RemoteID   string         `json:"remote_id"`
	Roles      []string       `json:"roles"`
	Email      string         `json:"email"`
	Claims     map[string]any `json:"claims"`
	Meta       map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
	UserID           int            `json:"user_id"`
	UserName         string         `json:"user_name"`
	UserLevel        int            `json:"user_level"`
	SessionID        string         `json:"session_id"`
	SessionRID       int64          `json:"session_rid"`
	RemoteID         string         `json:"remote_id"`
	Roles            []string       `json:"roles"`
	Email            string         `json:"email"`
	Claims           map[string]any `json:"claims"`
	Meta             map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
	TwoFactorEnabled bool           `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
}

// LoginRequest contains credentials for login
type LoginRequest struct {
	Username string         `json:"username"`
	Password string         `json:"password"`
	Claims   map[string]any `json:"claims"` // Additional login data
	Meta     map[string]any `json:"meta"`   // Additional metadata to be set on user context
	Username      string         `json:"username"`
	Password      string         `json:"password"`
	TwoFactorCode string         `json:"two_factor_code,omitempty"` // TOTP or backup code
	Claims        map[string]any `json:"claims"` // Additional login data
	Meta          map[string]any `json:"meta"`   // Additional metadata to be set on user context
}

// RegisterRequest contains information for new user registration
type RegisterRequest struct {
	Username  string         `json:"username"`
	Password  string         `json:"password"`
	Email     string         `json:"email"`
	UserLevel int            `json:"user_level"`
	Roles     []string       `json:"roles"`
	Claims    map[string]any `json:"claims"` // Additional registration data
	Meta      map[string]any `json:"meta"`   // Additional metadata
}

// LoginResponse contains the result of a login attempt
type LoginResponse struct {
	Token        string         `json:"token"`
	RefreshToken string         `json:"refresh_token"`
	User         *UserContext   `json:"user"`
	ExpiresIn    int64          `json:"expires_in"` // Token expiration in seconds
	Meta         map[string]any `json:"meta"`       // Additional metadata to be set on user context
	Token              string           `json:"token"`
	RefreshToken       string           `json:"refresh_token"`
	User               *UserContext     `json:"user"`
	ExpiresIn          int64            `json:"expires_in"` // Token expiration in seconds
	Requires2FA        bool             `json:"requires_2fa"` // True if a 2FA code is required
	TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"` // Present when setting up 2FA
	Meta               map[string]any   `json:"meta"` // Additional metadata to be set on user context
}

// LogoutRequest contains information for logout
@@ -55,6 +70,12 @@ type Authenticator interface {
	Authenticate(r *http.Request) (*UserContext, error)
}

// Registrable allows providers to support user registration
type Registrable interface {
	// Register creates a new user account
	Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error)
}

// ColumnSecurityProvider handles column-level security (masking/hiding)
type ColumnSecurityProvider interface {
	// GetColumnSecurity loads column security rules for a user and entity

615	pkg/security/oauth2_examples.go	Normal file
@@ -0,0 +1,615 @@
package security

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/gorilla/mux"
)

// Example: OAuth2 Authentication with Google
func ExampleOAuth2Google() {
	db, _ := sql.Open("postgres", "connection-string")

	// Create OAuth2 authenticator for Google
	oauth2Auth := NewGoogleAuthenticator(
		"your-client-id",
		"your-client-secret",
		"http://localhost:8080/auth/google/callback",
		db,
	)

	router := mux.NewRouter()

	// Login endpoint - redirects to Google
	router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
		state, _ := oauth2Auth.OAuth2GenerateState()
		authURL, _ := oauth2Auth.OAuth2GetAuthURL("google", state)
		http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
	})

	// Callback endpoint - handles Google response
	router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		state := r.URL.Query().Get("state")

		loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "google", code, state)
		if err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}

		// Set session cookie
		http.SetCookie(w, &http.Cookie{
			Name:     "session_token",
			Value:    loginResp.Token,
			Path:     "/",
			MaxAge:   int(loginResp.ExpiresIn),
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteLaxMode,
		})

		// Return user info as JSON
		_ = json.NewEncoder(w).Encode(loginResp)
	})

	_ = http.ListenAndServe(":8080", router)
}

// Example: OAuth2 Authentication with GitHub
func ExampleOAuth2GitHub() {
	db, _ := sql.Open("postgres", "connection-string")

	oauth2Auth := NewGitHubAuthenticator(
		"your-github-client-id",
		"your-github-client-secret",
		"http://localhost:8080/auth/github/callback",
		db,
	)

	router := mux.NewRouter()

	router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
		state, _ := oauth2Auth.OAuth2GenerateState()
		authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
		http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
	})

	router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		state := r.URL.Query().Get("state")

		loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
		if err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}

		_ = json.NewEncoder(w).Encode(loginResp)
	})

	_ = http.ListenAndServe(":8080", router)
}

// Example: Custom OAuth2 Provider
func ExampleOAuth2Custom() {
	db, _ := sql.Open("postgres", "connection-string")

	// Custom OAuth2 provider configuration
	oauth2Auth := NewDatabaseAuthenticator(db).WithOAuth2(OAuth2Config{
		ClientID:     "your-client-id",
		ClientSecret: "your-client-secret",
		RedirectURL:  "http://localhost:8080/auth/callback",
		Scopes:       []string{"openid", "profile", "email"},
		AuthURL:      "https://your-provider.com/oauth/authorize",
		TokenURL:     "https://your-provider.com/oauth/token",
		UserInfoURL:  "https://your-provider.com/oauth/userinfo",
		ProviderName: "custom-provider",

		// Custom user info parser
		UserInfoParser: func(userInfo map[string]any) (*UserContext, error) {
			// Extract custom fields from your provider
			return &UserContext{
				UserName:  userInfo["username"].(string),
				Email:     userInfo["email"].(string),
				RemoteID:  userInfo["id"].(string),
				UserLevel: 1,
				Roles:     []string{"user"},
				Claims:    userInfo,
			}, nil
		},
	})

	router := mux.NewRouter()

	router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
		state, _ := oauth2Auth.OAuth2GenerateState()
		authURL, _ := oauth2Auth.OAuth2GetAuthURL("custom-provider", state)
		http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
	})

	router.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		state := r.URL.Query().Get("state")

		loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "custom-provider", code, state)
		if err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}

		_ = json.NewEncoder(w).Encode(loginResp)
	})

	_ = http.ListenAndServe(":8080", router)
}

// Example: Multi-Provider OAuth2 with Security Integration
|
||||
func ExampleOAuth2MultiProvider() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create OAuth2 authenticators for multiple providers
|
||||
googleAuth := NewGoogleAuthenticator(
|
||||
"google-client-id",
|
||||
"google-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
githubAuth := NewGitHubAuthenticator(
|
||||
"github-client-id",
|
||||
"github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Create column and row security providers
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Google OAuth2 routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := googleAuth.OAuth2GenerateState()
|
||||
authURL, _ := googleAuth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := googleAuth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
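Each callback handler above builds the same session cookie from the LoginResponse. A minimal standalone sketch of that construction (the `sessionCookie` helper is illustrative, not part of the package):

```go
package main

import (
	"fmt"
	"net/http"
)

// sessionCookie sketches the cookie built in the callback handlers:
// MaxAge comes from the LoginResponse's ExpiresIn (seconds), and
// HttpOnly keeps the token away from client-side scripts.
func sessionCookie(token string, expiresIn int64) *http.Cookie {
	return &http.Cookie{
		Name:     "session_token",
		Value:    token,
		Path:     "/",
		MaxAge:   int(expiresIn),
		HttpOnly: true,
	}
}

func main() {
	c := sessionCookie("abc123", 3600)
	fmt.Println(c.String())
}
```

For production use, consider also setting `Secure: true` and `SameSite`, as the token-refresh example below does.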
||||
|
||||
// GitHub OAuth2 routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := githubAuth.OAuth2GenerateState()
|
||||
authURL, _ := githubAuth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := githubAuth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
	// Use either authenticator for protected routes - both validate sessions from the shared user_sessions table
|
||||
provider, _ := NewCompositeSecurityProvider(googleAuth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
// Protected route with authentication
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 with Token Refresh
|
||||
func ExampleOAuth2TokenRefresh() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Refresh token endpoint
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
			Provider string `json:"provider"` // must match a registered provider ("google" in this example)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Default to google if not specified
|
||||
if req.Provider == "" {
|
||||
req.Provider = "google"
|
||||
}
|
||||
|
||||
// Use OAuth2-specific refresh method
|
||||
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 Logout
|
||||
func ExampleOAuth2Logout() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("Authorization")
|
||||
if token == "" {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err == nil {
|
||||
token = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
			// Resolve the user from the request's session token
|
||||
userCtx, err := oauth2Auth.Authenticate(r)
|
||||
if err == nil {
|
||||
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||
Token: token,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Clear cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Logged out successfully"))
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: Complete OAuth2 Integration with Database Setup
|
||||
func ExampleOAuth2Complete() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create tables (run once)
|
||||
setupOAuth2Tables(db)
|
||||
|
||||
// Create OAuth2 authenticator
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Create security providers
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Public routes
|
||||
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("Welcome! <a href='/auth/google/login'>Login with Google</a>"))
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
		authURL, _ := oauth2Auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
		loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Protected routes
|
||||
protectedRouter := router.PathPrefix("/").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/dashboard", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_, _ = fmt.Fprintf(w, "Welcome, %s! Your email: %s", userCtx.UserName, userCtx.Email)
|
||||
})
|
||||
|
||||
protectedRouter.HandleFunc("/api/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
protectedRouter.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||
Token: userCtx.SessionID,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
func setupOAuth2Tables(db *sql.DB) {
|
||||
// Create tables from database_schema.sql
|
||||
// This is a helper function - in production, use migrations
|
||||
ctx := context.Background()
|
||||
|
||||
// Create users table if not exists
|
||||
_, _ = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password VARCHAR(255),
|
||||
user_level INTEGER DEFAULT 0,
|
||||
roles VARCHAR(500),
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
remote_id VARCHAR(255),
|
||||
auth_provider VARCHAR(50)
|
||||
)
|
||||
`)
|
||||
|
||||
// Create user_sessions table (used for both regular and OAuth2 sessions)
|
||||
_, _ = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
token_type VARCHAR(50) DEFAULT 'Bearer',
|
||||
auth_provider VARCHAR(50)
|
||||
)
|
||||
`)
|
||||
}
|
||||
|
||||
// Example: All OAuth2 Providers at Once
|
||||
func ExampleOAuth2AllProviders() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create authenticator with ALL OAuth2 providers
|
||||
auth := NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "microsoft-client-id",
|
||||
ClientSecret: "microsoft-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/microsoft/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||
ProviderName: "microsoft",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "facebook-client-id",
|
||||
ClientSecret: "facebook-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/facebook/callback",
|
||||
Scopes: []string{"email"},
|
||||
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
|
||||
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
|
||||
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
|
||||
ProviderName: "facebook",
|
||||
})
|
||||
|
||||
// Get list of configured providers
|
||||
providers := auth.OAuth2GetProviders()
|
||||
fmt.Printf("Configured OAuth2 providers: %v\n", providers)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Google routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// GitHub routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Microsoft routes
|
||||
router.HandleFunc("/auth/microsoft/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("microsoft", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/microsoft/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "microsoft", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Facebook routes
|
||||
router.HandleFunc("/auth/facebook/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("facebook", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/facebook/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "facebook", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Create security list for protected routes
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
// Protected routes work for ALL OAuth2 providers + regular sessions
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
579
pkg/security/oauth2_methods.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth2Config contains configuration for OAuth2 authentication
|
||||
type OAuth2Config struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
RedirectURL string
|
||||
Scopes []string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
ProviderName string
|
||||
|
||||
// Optional: Custom user info parser
|
||||
	// If not provided, defaultOAuth2UserInfoParser is used (sub, email, name, login)
|
||||
UserInfoParser func(userInfo map[string]any) (*UserContext, error)
|
||||
}
|
||||
|
||||
// OAuth2Provider holds configuration and state for a single OAuth2 provider
|
||||
type OAuth2Provider struct {
|
||||
config *oauth2.Config
|
||||
userInfoURL string
|
||||
userInfoParser func(userInfo map[string]any) (*UserContext, error)
|
||||
providerName string
|
||||
states map[string]time.Time // state -> expiry time
|
||||
statesMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// WithOAuth2 configures OAuth2 support for the DatabaseAuthenticator
|
||||
// Can be called multiple times to add multiple OAuth2 providers
|
||||
// Returns the same DatabaseAuthenticator instance for method chaining
|
||||
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) *DatabaseAuthenticator {
|
||||
if cfg.ProviderName == "" {
|
||||
cfg.ProviderName = "oauth2"
|
||||
}
|
||||
|
||||
if cfg.UserInfoParser == nil {
|
||||
cfg.UserInfoParser = defaultOAuth2UserInfoParser
|
||||
}
|
||||
|
||||
provider := &OAuth2Provider{
|
||||
config: &oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
RedirectURL: cfg.RedirectURL,
|
||||
Scopes: cfg.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: cfg.AuthURL,
|
||||
TokenURL: cfg.TokenURL,
|
||||
},
|
||||
},
|
||||
userInfoURL: cfg.UserInfoURL,
|
||||
userInfoParser: cfg.UserInfoParser,
|
||||
providerName: cfg.ProviderName,
|
||||
states: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// Initialize providers map if needed
|
||||
a.oauth2ProvidersMutex.Lock()
|
||||
if a.oauth2Providers == nil {
|
||||
a.oauth2Providers = make(map[string]*OAuth2Provider)
|
||||
}
|
||||
|
||||
// Register provider
|
||||
a.oauth2Providers[cfg.ProviderName] = provider
|
||||
a.oauth2ProvidersMutex.Unlock()
|
||||
|
||||
// Start state cleanup goroutine for this provider
|
||||
go provider.cleanupStates()
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// OAuth2GetAuthURL returns the OAuth2 authorization URL for redirecting users
|
||||
func (a *DatabaseAuthenticator) OAuth2GetAuthURL(providerName, state string) (string, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Store state for validation
|
||||
provider.statesMutex.Lock()
|
||||
provider.states[state] = time.Now().Add(10 * time.Minute)
|
||||
provider.statesMutex.Unlock()
|
||||
|
||||
return provider.config.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// OAuth2GenerateState generates a random state string for CSRF protection
|
||||
func (a *DatabaseAuthenticator) OAuth2GenerateState() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
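The state value doubles as a CSRF token and (later in the file) as a session token. A minimal standalone sketch of the same construction, assuming only the stdlib (`generateState` is an illustrative copy, not the package API): 32 bytes from `crypto/rand`, base64url-encoded, which always yields a 44-character URL-safe string.

```go
package main

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
)

// generateState mirrors OAuth2GenerateState: 32 random bytes,
// base64url-encoded with padding.
func generateState() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(b), nil
}

func main() {
	s, _ := generateState()
	fmt.Println(len(s)) // 32 bytes -> 44 base64url characters (with padding)
}
```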
||||
|
||||
// OAuth2HandleCallback handles the OAuth2 callback and exchanges code for token
|
||||
func (a *DatabaseAuthenticator) OAuth2HandleCallback(ctx context.Context, providerName, code, state string) (*LoginResponse, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate state
|
||||
if !provider.validateState(state) {
|
||||
return nil, fmt.Errorf("invalid state parameter")
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := provider.config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Fetch user info
|
||||
client := provider.config.Client(ctx, token)
|
||||
resp, err := client.Get(provider.userInfoURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read user info: %w", err)
|
||||
}
|
||||
|
||||
var userInfo map[string]any
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info: %w", err)
|
||||
}
|
||||
|
||||
// Parse user info
|
||||
userCtx, err := provider.userInfoParser(userInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||
}
|
||||
|
||||
// Get or create user in database
|
||||
userID, err := a.oauth2GetOrCreateUser(ctx, userCtx, providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get or create user: %w", err)
|
||||
}
|
||||
userCtx.UserID = userID
|
||||
|
||||
// Create session token
|
||||
sessionToken, err := a.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session token: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
if token.Expiry.After(time.Now()) {
|
||||
expiresAt = token.Expiry
|
||||
}
|
||||
|
||||
// Store session in database
|
||||
err = a.oauth2CreateSession(ctx, sessionToken, userCtx.UserID, token, expiresAt, providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
userCtx.SessionID = sessionToken
|
||||
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
User: userCtx,
|
||||
ExpiresIn: int64(time.Until(expiresAt).Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OAuth2GetProviders returns list of configured OAuth2 provider names
|
||||
func (a *DatabaseAuthenticator) OAuth2GetProviders() []string {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
|
||||
if a.oauth2Providers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers := make([]string, 0, len(a.oauth2Providers))
|
||||
for name := range a.oauth2Providers {
|
||||
providers = append(providers, name)
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// getOAuth2Provider retrieves a registered OAuth2 provider by name
|
||||
func (a *DatabaseAuthenticator) getOAuth2Provider(providerName string) (*OAuth2Provider, error) {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
|
||||
if a.oauth2Providers == nil {
|
||||
return nil, fmt.Errorf("OAuth2 not configured - call WithOAuth2() first")
|
||||
}
|
||||
|
||||
provider, ok := a.oauth2Providers[providerName]
|
||||
if !ok {
|
||||
		// Build provider list inline; calling OAuth2GetProviders here would re-acquire the read lock
|
||||
providerNames := make([]string, 0, len(a.oauth2Providers))
|
||||
for name := range a.oauth2Providers {
|
||||
providerNames = append(providerNames, name)
|
||||
}
|
||||
return nil, fmt.Errorf("OAuth2 provider '%s' not found - available providers: %v", providerName, providerNames)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// oauth2GetOrCreateUser finds or creates a user based on OAuth2 info using stored procedure
|
||||
func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userCtx *UserContext, providerName string) (int, error) {
|
||||
userData := map[string]interface{}{
|
||||
"username": userCtx.UserName,
|
||||
"email": userCtx.Email,
|
||||
"remote_id": userCtx.RemoteID,
|
||||
"user_level": userCtx.UserLevel,
|
||||
"roles": userCtx.Roles,
|
||||
"auth_provider": providerName,
|
||||
}
|
||||
|
||||
userJSON, err := json.Marshal(userData)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to marshal user data: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
|
||||
`, userJSON).Scan(&success, &errMsg, &userID)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return 0, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return 0, fmt.Errorf("failed to get or create user")
|
||||
}
|
||||
|
||||
if userID == nil {
|
||||
return 0, fmt.Errorf("user ID not returned")
|
||||
}
|
||||
|
||||
return *userID, nil
|
||||
}
|
||||
|
||||
// oauth2CreateSession creates a new OAuth2 session using stored procedure
|
||||
func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, sessionToken string, userID int, token *oauth2.Token, expiresAt time.Time, providerName string) error {
|
||||
sessionData := map[string]interface{}{
|
||||
"session_token": sessionToken,
|
||||
"user_id": userID,
|
||||
"access_token": token.AccessToken,
|
||||
"refresh_token": token.RefreshToken,
|
||||
"token_type": token.TokenType,
|
||||
"expires_at": expiresAt,
|
||||
"auth_provider": providerName,
|
||||
}
|
||||
|
||||
sessionJSON, err := json.Marshal(sessionData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal session data: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_createsession($1::jsonb)
|
||||
`, sessionJSON).Scan(&success, &errMsg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to create session")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateState validates state using in-memory storage
|
||||
func (p *OAuth2Provider) validateState(state string) bool {
|
||||
p.statesMutex.Lock()
|
||||
defer p.statesMutex.Unlock()
|
||||
|
||||
expiry, ok := p.states[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(expiry) {
|
||||
delete(p.states, state)
|
||||
return false
|
||||
}
|
||||
|
||||
delete(p.states, state) // One-time use
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupStates removes expired states periodically
|
||||
func (p *OAuth2Provider) cleanupStates() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
p.statesMutex.Lock()
|
||||
now := time.Now()
|
||||
for state, expiry := range p.states {
|
||||
if now.After(expiry) {
|
||||
delete(p.states, state)
|
||||
}
|
||||
}
|
||||
p.statesMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// defaultOAuth2UserInfoParser parses standard OAuth2 user info claims
|
||||
func defaultOAuth2UserInfoParser(userInfo map[string]any) (*UserContext, error) {
|
||||
ctx := &UserContext{
|
||||
Claims: userInfo,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
// Extract standard claims
|
||||
if sub, ok := userInfo["sub"].(string); ok {
|
||||
ctx.RemoteID = sub
|
||||
}
|
||||
if email, ok := userInfo["email"].(string); ok {
|
||||
ctx.Email = email
|
||||
// Use email as username if name not available
|
||||
ctx.UserName = strings.Split(email, "@")[0]
|
||||
}
|
||||
if name, ok := userInfo["name"].(string); ok {
|
||||
ctx.UserName = name
|
||||
}
|
||||
if login, ok := userInfo["login"].(string); ok {
|
||||
ctx.UserName = login // GitHub uses "login"
|
||||
}
|
||||
|
||||
if ctx.UserName == "" {
|
||||
return nil, fmt.Errorf("could not extract username from user info")
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
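The username precedence in the default parser is easiest to see in isolation: the local part of "email" is the fallback, "name" overrides it, and "login" (GitHub's field) wins last. A minimal sketch of just that ordering (`usernameFromClaims` is an illustrative extraction, not a package function):

```go
package main

import (
	"fmt"
	"strings"
)

// usernameFromClaims reproduces the precedence in
// defaultOAuth2UserInfoParser: email local part < name < login.
func usernameFromClaims(claims map[string]any) string {
	var username string
	if email, ok := claims["email"].(string); ok {
		username = strings.Split(email, "@")[0]
	}
	if name, ok := claims["name"].(string); ok {
		username = name
	}
	if login, ok := claims["login"].(string); ok {
		username = login
	}
	return username
}

func main() {
	fmt.Println(usernameFromClaims(map[string]any{"email": "jane@example.com"}))                      // jane
	fmt.Println(usernameFromClaims(map[string]any{"email": "jane@example.com", "login": "janedoe"})) // janedoe
}
```

A custom `UserInfoParser` in OAuth2Config can replace this logic entirely for providers with non-standard claim names.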
||||
|
||||
// OAuth2RefreshToken refreshes an expired OAuth2 access token using the refresh token
|
||||
// Takes the refresh token and returns a new LoginResponse with updated tokens
|
||||
func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshToken, providerName string) (*LoginResponse, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get session by refresh token from database
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getrefreshtoken($1)
|
||||
`, refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired refresh token")
|
||||
}
|
||||
|
||||
// Parse session data
|
||||
var session struct {
|
||||
UserID int `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
if err := json.Unmarshal(sessionData, &session); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse session data: %w", err)
|
||||
}
|
||||
|
||||
// Create oauth2.Token from stored data
|
||||
oldToken := &oauth2.Token{
|
||||
AccessToken: session.AccessToken,
|
||||
TokenType: session.TokenType,
|
||||
RefreshToken: refreshToken,
|
||||
Expiry: session.Expiry,
|
||||
}
|
||||
|
||||
// Use OAuth2 provider to refresh the token
|
||||
tokenSource := provider.config.TokenSource(ctx, oldToken)
|
||||
newToken, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
		return nil, fmt.Errorf("failed to refresh token with provider: %w", err)
	}

	// Generate new session token
	newSessionToken, err := a.OAuth2GenerateState()
	if err != nil {
		return nil, fmt.Errorf("failed to generate new session token: %w", err)
	}

	// Update session in database with new tokens
	updateData := map[string]interface{}{
		"user_id":           session.UserID,
		"old_refresh_token": refreshToken,
		"new_session_token": newSessionToken,
		"new_access_token":  newToken.AccessToken,
		"new_refresh_token": newToken.RefreshToken,
		"expires_at":        newToken.Expiry,
	}

	updateJSON, err := json.Marshal(updateData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal update data: %w", err)
	}

	var updateSuccess bool
	var updateErrMsg *string

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error
		FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
	`, updateJSON).Scan(&updateSuccess, &updateErrMsg)

	if err != nil {
		return nil, fmt.Errorf("failed to update session: %w", err)
	}

	if !updateSuccess {
		if updateErrMsg != nil {
			return nil, fmt.Errorf("%s", *updateErrMsg)
		}
		return nil, fmt.Errorf("failed to update session")
	}

	// Get user data
	var userSuccess bool
	var userErrMsg *string
	var userData []byte

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error, p_data::text
		FROM resolvespec_oauth_getuser($1)
	`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)

	if err != nil {
		return nil, fmt.Errorf("failed to get user data: %w", err)
	}

	if !userSuccess {
		if userErrMsg != nil {
			return nil, fmt.Errorf("%s", *userErrMsg)
		}
		return nil, fmt.Errorf("failed to get user data")
	}

	// Parse user context
	var userCtx UserContext
	if err := json.Unmarshal(userData, &userCtx); err != nil {
		return nil, fmt.Errorf("failed to parse user context: %w", err)
	}

	userCtx.SessionID = newSessionToken

	return &LoginResponse{
		Token:        newSessionToken,
		RefreshToken: newToken.RefreshToken,
		User:         &userCtx,
		ExpiresIn:    int64(time.Until(newToken.Expiry).Seconds()),
	}, nil
}

// Pre-configured OAuth2 factory methods

// NewGoogleAuthenticator creates a DatabaseAuthenticator configured for Google OAuth2
func NewGoogleAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
	auth := NewDatabaseAuthenticator(db)
	return auth.WithOAuth2(OAuth2Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  redirectURL,
		Scopes:       []string{"openid", "profile", "email"},
		AuthURL:      "https://accounts.google.com/o/oauth2/auth",
		TokenURL:     "https://oauth2.googleapis.com/token",
		UserInfoURL:  "https://www.googleapis.com/oauth2/v2/userinfo",
		ProviderName: "google",
	})
}

// NewGitHubAuthenticator creates a DatabaseAuthenticator configured for GitHub OAuth2
func NewGitHubAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
	auth := NewDatabaseAuthenticator(db)
	return auth.WithOAuth2(OAuth2Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  redirectURL,
		Scopes:       []string{"user:email"},
		AuthURL:      "https://github.com/login/oauth/authorize",
		TokenURL:     "https://github.com/login/oauth/access_token",
		UserInfoURL:  "https://api.github.com/user",
		ProviderName: "github",
	})
}

// NewMicrosoftAuthenticator creates a DatabaseAuthenticator configured for Microsoft OAuth2
func NewMicrosoftAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
	auth := NewDatabaseAuthenticator(db)
	return auth.WithOAuth2(OAuth2Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  redirectURL,
		Scopes:       []string{"openid", "profile", "email"},
		AuthURL:      "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
		TokenURL:     "https://login.microsoftonline.com/common/oauth2/v2.0/token",
		UserInfoURL:  "https://graph.microsoft.com/v1.0/me",
		ProviderName: "microsoft",
	})
}

// NewFacebookAuthenticator creates a DatabaseAuthenticator configured for Facebook OAuth2
func NewFacebookAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
	auth := NewDatabaseAuthenticator(db)
	return auth.WithOAuth2(OAuth2Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  redirectURL,
		Scopes:       []string{"email"},
		AuthURL:      "https://www.facebook.com/v12.0/dialog/oauth",
		TokenURL:     "https://graph.facebook.com/v12.0/oauth/access_token",
		UserInfoURL:  "https://graph.facebook.com/me?fields=id,name,email",
		ProviderName: "facebook",
	})
}

// NewMultiProviderAuthenticator creates a DatabaseAuthenticator with all major OAuth2 providers configured
func NewMultiProviderAuthenticator(db *sql.DB, configs map[string]OAuth2Config) *DatabaseAuthenticator {
	auth := NewDatabaseAuthenticator(db)

	//nolint:gocritic // OAuth2Config is copied but kept for API simplicity
	for _, cfg := range configs {
		auth.WithOAuth2(cfg)
	}

	return auth
}
pkg/security/passkey.go (Normal file, 185 lines)
@@ -0,0 +1,185 @@
package security

import (
	"context"
	"encoding/json"
	"time"
)

// PasskeyCredential represents a stored WebAuthn/FIDO2 credential
type PasskeyCredential struct {
	ID              string    `json:"id"`
	UserID          int       `json:"user_id"`
	CredentialID    []byte    `json:"credential_id"`        // Raw credential ID from authenticator
	PublicKey       []byte    `json:"public_key"`           // COSE public key
	AttestationType string    `json:"attestation_type"`     // none, indirect, direct
	AAGUID          []byte    `json:"aaguid"`               // Authenticator AAGUID
	SignCount       uint32    `json:"sign_count"`           // Signature counter
	CloneWarning    bool      `json:"clone_warning"`        // True if cloning detected
	Transports      []string  `json:"transports,omitempty"` // usb, nfc, ble, internal
	BackupEligible  bool      `json:"backup_eligible"`      // Credential can be backed up
	BackupState     bool      `json:"backup_state"`         // Credential is currently backed up
	Name            string    `json:"name,omitempty"`       // User-friendly name
	CreatedAt       time.Time `json:"created_at"`
	LastUsedAt      time.Time `json:"last_used_at"`
}

// PasskeyRegistrationOptions contains options for beginning passkey registration
type PasskeyRegistrationOptions struct {
	Challenge              []byte                         `json:"challenge"`
	RelyingParty           PasskeyRelyingParty            `json:"rp"`
	User                   PasskeyUser                    `json:"user"`
	PubKeyCredParams       []PasskeyCredentialParam       `json:"pubKeyCredParams"`
	Timeout                int64                          `json:"timeout,omitempty"` // Milliseconds
	ExcludeCredentials     []PasskeyCredentialDescriptor  `json:"excludeCredentials,omitempty"`
	AuthenticatorSelection *PasskeyAuthenticatorSelection `json:"authenticatorSelection,omitempty"`
	Attestation            string                         `json:"attestation,omitempty"` // none, indirect, direct, enterprise
	Extensions             map[string]any                 `json:"extensions,omitempty"`
}

// PasskeyAuthenticationOptions contains options for beginning passkey authentication
type PasskeyAuthenticationOptions struct {
	Challenge        []byte                        `json:"challenge"`
	Timeout          int64                         `json:"timeout,omitempty"`
	RelyingPartyID   string                        `json:"rpId,omitempty"`
	AllowCredentials []PasskeyCredentialDescriptor `json:"allowCredentials,omitempty"`
	UserVerification string                        `json:"userVerification,omitempty"` // required, preferred, discouraged
	Extensions       map[string]any                `json:"extensions,omitempty"`
}

// PasskeyRelyingParty identifies the relying party
type PasskeyRelyingParty struct {
	ID   string `json:"id"`   // Domain (e.g., "example.com")
	Name string `json:"name"` // Display name
}

// PasskeyUser identifies the user
type PasskeyUser struct {
	ID          []byte `json:"id"`          // User handle (unique, persistent)
	Name        string `json:"name"`        // Username
	DisplayName string `json:"displayName"` // Display name
}

// PasskeyCredentialParam specifies a supported public key algorithm
type PasskeyCredentialParam struct {
	Type string `json:"type"` // "public-key"
	Alg  int    `json:"alg"`  // COSE algorithm identifier (e.g., -7 for ES256, -257 for RS256)
}

// PasskeyCredentialDescriptor describes a credential
type PasskeyCredentialDescriptor struct {
	Type       string   `json:"type"`                 // "public-key"
	ID         []byte   `json:"id"`                   // Credential ID
	Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
}

// PasskeyAuthenticatorSelection specifies authenticator requirements
type PasskeyAuthenticatorSelection struct {
	AuthenticatorAttachment string `json:"authenticatorAttachment,omitempty"` // platform, cross-platform
	RequireResidentKey      bool   `json:"requireResidentKey,omitempty"`
	ResidentKey             string `json:"residentKey,omitempty"`      // discouraged, preferred, required
	UserVerification        string `json:"userVerification,omitempty"` // required, preferred, discouraged
}

// PasskeyRegistrationResponse contains the client's registration response
type PasskeyRegistrationResponse struct {
	ID                     string                                  `json:"id"`    // Base64URL encoded credential ID
	RawID                  []byte                                  `json:"rawId"` // Raw credential ID
	Type                   string                                  `json:"type"`  // "public-key"
	Response               PasskeyAuthenticatorAttestationResponse `json:"response"`
	ClientExtensionResults map[string]any                          `json:"clientExtensionResults,omitempty"`
	Transports             []string                                `json:"transports,omitempty"`
}

// PasskeyAuthenticatorAttestationResponse contains attestation data
type PasskeyAuthenticatorAttestationResponse struct {
	ClientDataJSON    []byte   `json:"clientDataJSON"`
	AttestationObject []byte   `json:"attestationObject"`
	Transports        []string `json:"transports,omitempty"`
}

// PasskeyAuthenticationResponse contains the client's authentication response
type PasskeyAuthenticationResponse struct {
	ID                     string                                `json:"id"`    // Base64URL encoded credential ID
	RawID                  []byte                                `json:"rawId"` // Raw credential ID
	Type                   string                                `json:"type"`  // "public-key"
	Response               PasskeyAuthenticatorAssertionResponse `json:"response"`
	ClientExtensionResults map[string]any                        `json:"clientExtensionResults,omitempty"`
}

// PasskeyAuthenticatorAssertionResponse contains assertion data
type PasskeyAuthenticatorAssertionResponse struct {
	ClientDataJSON    []byte `json:"clientDataJSON"`
	AuthenticatorData []byte `json:"authenticatorData"`
	Signature         []byte `json:"signature"`
	UserHandle        []byte `json:"userHandle,omitempty"`
}

// PasskeyProvider handles passkey registration and authentication
type PasskeyProvider interface {
	// BeginRegistration creates registration options for a new passkey
	BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error)

	// CompleteRegistration verifies and stores a new passkey credential
	CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error)

	// BeginAuthentication creates authentication options for passkey login
	BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error)

	// CompleteAuthentication verifies a passkey assertion and returns the user ID
	CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error)

	// GetCredentials returns all passkey credentials for a user
	GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error)

	// DeleteCredential removes a passkey credential
	DeleteCredential(ctx context.Context, userID int, credentialID string) error

	// UpdateCredentialName updates the friendly name of a credential
	UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error
}

// PasskeyLoginRequest contains passkey authentication data
type PasskeyLoginRequest struct {
	Response          PasskeyAuthenticationResponse `json:"response"`
	ExpectedChallenge []byte                        `json:"expected_challenge"`
	Claims            map[string]any                `json:"claims"` // Additional login data
}

// PasskeyRegisterRequest contains passkey registration data
type PasskeyRegisterRequest struct {
	UserID            int                         `json:"user_id"`
	Response          PasskeyRegistrationResponse `json:"response"`
	ExpectedChallenge []byte                      `json:"expected_challenge"`
	CredentialName    string                      `json:"credential_name,omitempty"`
}

// PasskeyBeginRegistrationRequest contains options for starting passkey registration
type PasskeyBeginRegistrationRequest struct {
	UserID      int    `json:"user_id"`
	Username    string `json:"username"`
	DisplayName string `json:"display_name"`
}

// PasskeyBeginAuthenticationRequest contains options for starting passkey authentication
type PasskeyBeginAuthenticationRequest struct {
	Username string `json:"username,omitempty"` // Optional for resident key flow
}

// ParsePasskeyRegistrationResponse parses a JSON passkey registration response
func ParsePasskeyRegistrationResponse(data []byte) (*PasskeyRegistrationResponse, error) {
	var response PasskeyRegistrationResponse
	if err := json.Unmarshal(data, &response); err != nil {
		return nil, err
	}
	return &response, nil
}

// ParsePasskeyAuthenticationResponse parses a JSON passkey authentication response
func ParsePasskeyAuthenticationResponse(data []byte) (*PasskeyAuthenticationResponse, error) {
	var response PasskeyAuthenticationResponse
	if err := json.Unmarshal(data, &response); err != nil {
		return nil, err
	}
	return &response, nil
}
pkg/security/passkey_examples.go (Normal file, 432 lines)
@@ -0,0 +1,432 @@
package security

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"net/http"
)

// PasskeyAuthenticationExample demonstrates passkey (WebAuthn/FIDO2) authentication
func PasskeyAuthenticationExample() {
	// Setup database connection
	db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")

	// Create passkey provider
	passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:     "example.com",         // Your domain
		RPName:   "Example Application", // Display name
		RPOrigin: "https://example.com", // Expected origin
		Timeout:  60000,                 // 60 seconds
	})

	// Create authenticator with passkey support
	// Option 1: Pass during creation
	_ = NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
		PasskeyProvider: passkeyProvider,
	})

	// Option 2: Use WithPasskey method
	auth := NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)

	ctx := context.Background()

	// === REGISTRATION FLOW ===

	// Step 1: Begin registration
	regOptions, _ := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
		UserID:      1,
		Username:    "alice",
		DisplayName: "Alice Smith",
	})

	// Send regOptions to client as JSON
	// Client will call navigator.credentials.create() with these options
	_ = regOptions

	// Step 2: Complete registration (after client returns credential)
	// This would come from the client's navigator.credentials.create() response
	clientResponse := PasskeyRegistrationResponse{
		ID:    "base64-credential-id",
		RawID: []byte("raw-credential-id"),
		Type:  "public-key",
		Response: PasskeyAuthenticatorAttestationResponse{
			ClientDataJSON:    []byte("..."),
			AttestationObject: []byte("..."),
		},
		Transports: []string{"internal"},
	}

	credential, _ := auth.CompletePasskeyRegistration(ctx, PasskeyRegisterRequest{
		UserID:            1,
		Response:          clientResponse,
		ExpectedChallenge: regOptions.Challenge,
		CredentialName:    "My iPhone",
	})

	fmt.Printf("Registered credential: %s\n", credential.ID)

	// === AUTHENTICATION FLOW ===

	// Step 1: Begin authentication
	authOptions, _ := auth.BeginPasskeyAuthentication(ctx, PasskeyBeginAuthenticationRequest{
		Username: "alice", // Optional - omit for resident key flow
	})

	// Send authOptions to client as JSON
	// Client will call navigator.credentials.get() with these options
	_ = authOptions

	// Step 2: Complete authentication (after client returns assertion)
	// This would come from the client's navigator.credentials.get() response
	clientAssertion := PasskeyAuthenticationResponse{
		ID:    "base64-credential-id",
		RawID: []byte("raw-credential-id"),
		Type:  "public-key",
		Response: PasskeyAuthenticatorAssertionResponse{
			ClientDataJSON:    []byte("..."),
			AuthenticatorData: []byte("..."),
			Signature:         []byte("..."),
		},
	}

	loginResponse, _ := auth.LoginWithPasskey(ctx, PasskeyLoginRequest{
		Response:          clientAssertion,
		ExpectedChallenge: authOptions.Challenge,
		Claims: map[string]any{
			"ip_address": "192.168.1.1",
			"user_agent": "Mozilla/5.0...",
		},
	})

	fmt.Printf("Logged in user: %s with token: %s\n",
		loginResponse.User.UserName, loginResponse.Token)

	// === CREDENTIAL MANAGEMENT ===

	// Get all credentials for a user
	credentials, _ := auth.GetPasskeyCredentials(ctx, 1)
	for i := range credentials {
		fmt.Printf("Credential: %s (created: %s, last used: %s)\n",
			credentials[i].Name, credentials[i].CreatedAt, credentials[i].LastUsedAt)
	}

	// Update credential name
	_ = auth.UpdatePasskeyCredentialName(ctx, 1, credential.ID, "My New iPhone")

	// Delete credential
	_ = auth.DeletePasskeyCredential(ctx, 1, credential.ID)
}

// PasskeyHTTPHandlersExample shows HTTP handlers for passkey authentication
func PasskeyHTTPHandlersExample(auth *DatabaseAuthenticator) {
	// Store challenges in session/cache in production
	challenges := make(map[string][]byte)

	// Begin registration endpoint
	http.HandleFunc("/api/passkey/register/begin", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			UserID      int    `json:"user_id"`
			Username    string `json:"username"`
			DisplayName string `json:"display_name"`
		}
		_ = json.NewDecoder(r.Body).Decode(&req)

		options, err := auth.BeginPasskeyRegistration(r.Context(), PasskeyBeginRegistrationRequest{
			UserID:      req.UserID,
			Username:    req.Username,
			DisplayName: req.DisplayName,
		})
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Store challenge for verification (use session ID as key in production)
		sessionID := "session-123"
		challenges[sessionID] = options.Challenge

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(options)
	})

	// Complete registration endpoint
	http.HandleFunc("/api/passkey/register/complete", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			UserID         int                         `json:"user_id"`
			Response       PasskeyRegistrationResponse `json:"response"`
			CredentialName string                      `json:"credential_name"`
		}
		_ = json.NewDecoder(r.Body).Decode(&req)

		// Get stored challenge (from session in production)
		sessionID := "session-123"
		challenge := challenges[sessionID]
		delete(challenges, sessionID)

		credential, err := auth.CompletePasskeyRegistration(r.Context(), PasskeyRegisterRequest{
			UserID:            req.UserID,
			Response:          req.Response,
			ExpectedChallenge: challenge,
			CredentialName:    req.CredentialName,
		})
		if err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(credential)
	})

	// Begin authentication endpoint
	http.HandleFunc("/api/passkey/login/begin", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			Username string `json:"username"` // Optional
		}
		_ = json.NewDecoder(r.Body).Decode(&req)

		options, err := auth.BeginPasskeyAuthentication(r.Context(), PasskeyBeginAuthenticationRequest{
			Username: req.Username,
		})
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Store challenge for verification (use session ID as key in production)
		sessionID := "session-456"
		challenges[sessionID] = options.Challenge

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(options)
	})

	// Complete authentication endpoint
	http.HandleFunc("/api/passkey/login/complete", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			Response PasskeyAuthenticationResponse `json:"response"`
		}
		_ = json.NewDecoder(r.Body).Decode(&req)

		// Get stored challenge (from session in production)
		sessionID := "session-456"
		challenge := challenges[sessionID]
		delete(challenges, sessionID)

		loginResponse, err := auth.LoginWithPasskey(r.Context(), PasskeyLoginRequest{
			Response:          req.Response,
			ExpectedChallenge: challenge,
			Claims: map[string]any{
				"ip_address": r.RemoteAddr,
				"user_agent": r.UserAgent(),
			},
		})
		if err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}

		// Set session cookie
		http.SetCookie(w, &http.Cookie{
			Name:     "session_token",
			Value:    loginResponse.Token,
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteLaxMode,
		})

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(loginResponse)
	})

	// List credentials endpoint
	http.HandleFunc("/api/passkey/credentials", func(w http.ResponseWriter, r *http.Request) {
		// Get user from authenticated session
		userCtx, err := auth.Authenticate(r)
		if err != nil {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		credentials, err := auth.GetPasskeyCredentials(r.Context(), userCtx.UserID)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(credentials)
	})

	// Delete credential endpoint
	http.HandleFunc("/api/passkey/credentials/delete", func(w http.ResponseWriter, r *http.Request) {
		userCtx, err := auth.Authenticate(r)
		if err != nil {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		var req struct {
			CredentialID string `json:"credential_id"`
		}
		_ = json.NewDecoder(r.Body).Decode(&req)

		err = auth.DeletePasskeyCredential(r.Context(), userCtx.UserID, req.CredentialID)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		w.WriteHeader(http.StatusNoContent)
	})
}

// PasskeyClientSideExample shows the client-side JavaScript code needed
func PasskeyClientSideExample() string {
	return `
// === CLIENT-SIDE JAVASCRIPT FOR PASSKEY AUTHENTICATION ===

// Helper function to convert base64 to ArrayBuffer
function base64ToArrayBuffer(base64) {
	const binary = atob(base64);
	const bytes = new Uint8Array(binary.length);
	for (let i = 0; i < binary.length; i++) {
		bytes[i] = binary.charCodeAt(i);
	}
	return bytes.buffer;
}

// Helper function to convert ArrayBuffer to base64
function arrayBufferToBase64(buffer) {
	const bytes = new Uint8Array(buffer);
	let binary = '';
	for (let i = 0; i < bytes.length; i++) {
		binary += String.fromCharCode(bytes[i]);
	}
	return btoa(binary);
}

// === REGISTRATION ===

async function registerPasskey(userId, username, displayName) {
	// Step 1: Get registration options from server
	const optionsResponse = await fetch('/api/passkey/register/begin', {
		method: 'POST',
		headers: { 'Content-Type': 'application/json' },
		body: JSON.stringify({ user_id: userId, username, display_name: displayName })
	});
	const options = await optionsResponse.json();

	// Convert base64 strings to ArrayBuffers
	options.challenge = base64ToArrayBuffer(options.challenge);
	options.user.id = base64ToArrayBuffer(options.user.id);
	if (options.excludeCredentials) {
		options.excludeCredentials = options.excludeCredentials.map(cred => ({
			...cred,
			id: base64ToArrayBuffer(cred.id)
		}));
	}

	// Step 2: Create credential using WebAuthn API
	const credential = await navigator.credentials.create({
		publicKey: options
	});

	// Step 3: Send credential to server
	const credentialResponse = {
		id: credential.id,
		rawId: arrayBufferToBase64(credential.rawId),
		type: credential.type,
		response: {
			clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
			attestationObject: arrayBufferToBase64(credential.response.attestationObject)
		},
		transports: credential.response.getTransports ? credential.response.getTransports() : []
	};

	const completeResponse = await fetch('/api/passkey/register/complete', {
		method: 'POST',
		headers: { 'Content-Type': 'application/json' },
		body: JSON.stringify({
			user_id: userId,
			response: credentialResponse,
			credential_name: 'My Device'
		})
	});

	return await completeResponse.json();
}

// === AUTHENTICATION ===

async function loginWithPasskey(username) {
	// Step 1: Get authentication options from server
	const optionsResponse = await fetch('/api/passkey/login/begin', {
		method: 'POST',
		headers: { 'Content-Type': 'application/json' },
		body: JSON.stringify({ username })
	});
	const options = await optionsResponse.json();

	// Convert base64 strings to ArrayBuffers
	options.challenge = base64ToArrayBuffer(options.challenge);
	if (options.allowCredentials) {
		options.allowCredentials = options.allowCredentials.map(cred => ({
			...cred,
			id: base64ToArrayBuffer(cred.id)
		}));
	}

	// Step 2: Get credential using WebAuthn API
	const credential = await navigator.credentials.get({
		publicKey: options
	});

	// Step 3: Send assertion to server
	const assertionResponse = {
		id: credential.id,
		rawId: arrayBufferToBase64(credential.rawId),
		type: credential.type,
		response: {
			clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
			authenticatorData: arrayBufferToBase64(credential.response.authenticatorData),
			signature: arrayBufferToBase64(credential.response.signature),
			userHandle: credential.response.userHandle ? arrayBufferToBase64(credential.response.userHandle) : null
		}
	};

	const loginResponse = await fetch('/api/passkey/login/complete', {
		method: 'POST',
		headers: { 'Content-Type': 'application/json' },
		body: JSON.stringify({ response: assertionResponse })
	});

	return await loginResponse.json();
}

// === USAGE ===

// Register a new passkey
document.getElementById('register-btn').addEventListener('click', async () => {
	try {
		const result = await registerPasskey(1, 'alice', 'Alice Smith');
		console.log('Passkey registered:', result);
	} catch (error) {
		console.error('Registration failed:', error);
	}
});

// Login with passkey
document.getElementById('login-btn').addEventListener('click', async () => {
	try {
		const result = await loginWithPasskey('alice');
		console.log('Logged in:', result);
	} catch (error) {
		console.error('Login failed:', error);
	}
});
`
}
pkg/security/passkey_provider.go (Normal file, 405 lines)
@@ -0,0 +1,405 @@
package security

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"time"
)

// DatabasePasskeyProvider implements PasskeyProvider using database storage
type DatabasePasskeyProvider struct {
	db       *sql.DB
	rpID     string // Relying Party ID (domain)
	rpName   string // Relying Party display name
	rpOrigin string // Expected origin for WebAuthn
	timeout  int64  // Timeout in milliseconds (default: 60000)
}

// DatabasePasskeyProviderOptions configures the passkey provider
type DatabasePasskeyProviderOptions struct {
	// RPID is the Relying Party ID (typically your domain, e.g., "example.com")
	RPID string
	// RPName is the display name for your relying party
	RPName string
	// RPOrigin is the expected origin (e.g., "https://example.com")
	RPOrigin string
	// Timeout is the timeout for operations in milliseconds (default: 60000)
	Timeout int64
}

// NewDatabasePasskeyProvider creates a new database-backed passkey provider
func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions) *DatabasePasskeyProvider {
	if opts.Timeout == 0 {
		opts.Timeout = 60000 // 60 seconds default
	}

	return &DatabasePasskeyProvider{
		db:       db,
		rpID:     opts.RPID,
		rpName:   opts.RPName,
		rpOrigin: opts.RPOrigin,
		timeout:  opts.Timeout,
	}
}

// BeginRegistration creates registration options for a new passkey
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
	// Generate challenge
	challenge := make([]byte, 32)
	if _, err := rand.Read(challenge); err != nil {
		return nil, fmt.Errorf("failed to generate challenge: %w", err)
	}

	// Get existing credentials to exclude
	credentials, err := p.GetCredentials(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get existing credentials: %w", err)
	}

	excludeCredentials := make([]PasskeyCredentialDescriptor, 0, len(credentials))
	for i := range credentials {
		excludeCredentials = append(excludeCredentials, PasskeyCredentialDescriptor{
			Type:       "public-key",
			ID:         credentials[i].CredentialID,
			Transports: credentials[i].Transports,
		})
	}

	// Create user handle (persistent user ID)
	userHandle := []byte(fmt.Sprintf("user_%d", userID))

	return &PasskeyRegistrationOptions{
		Challenge: challenge,
		RelyingParty: PasskeyRelyingParty{
			ID:   p.rpID,
			Name: p.rpName,
		},
		User: PasskeyUser{
			ID:          userHandle,
			Name:        username,
			DisplayName: displayName,
		},
		PubKeyCredParams: []PasskeyCredentialParam{
			{Type: "public-key", Alg: -7},   // ES256 (ECDSA with SHA-256)
			{Type: "public-key", Alg: -257}, // RS256 (RSASSA-PKCS1-v1_5 with SHA-256)
		},
		Timeout:            p.timeout,
		ExcludeCredentials: excludeCredentials,
		AuthenticatorSelection: &PasskeyAuthenticatorSelection{
			RequireResidentKey: false,
			ResidentKey:        "preferred",
			UserVerification:   "preferred",
		},
		Attestation: "none",
	}, nil
}

// CompleteRegistration verifies and stores a new passkey credential
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
// like github.com/go-webauthn/webauthn to properly verify attestation and parse credentials.
func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error) {
	// TODO: Implement full WebAuthn verification
	// 1. Verify clientDataJSON contains correct challenge and origin
	// 2. Parse and verify attestationObject
	// 3. Extract public key and credential ID
	// 4. Verify attestation signature (if not "none")

	// For now, this is a placeholder that stores the credential data.
	// In production, you MUST use a proper WebAuthn library.

	credData := map[string]any{
		"user_id":          userID,
		"credential_id":    base64.StdEncoding.EncodeToString(response.RawID),
		"public_key":       base64.StdEncoding.EncodeToString(response.Response.AttestationObject),
		"attestation_type": "none",
		"sign_count":       0,
		"transports":       response.Transports,
		"backup_eligible":  false,
		"backup_state":     false,
		"name":             "Passkey",
	}

	credJSON, err := json.Marshal(credData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal credential data: %w", err)
	}

	var success bool
	var errorMsg sql.NullString
	var credentialID sql.NullInt64

	query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
	err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
	if err != nil {
		return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to store credential")
|
||||
}
|
||||
|
||||
return &PasskeyCredential{
|
||||
ID: fmt.Sprintf("%d", credentialID.Int64),
|
||||
UserID: userID,
|
||||
CredentialID: response.RawID,
|
||||
PublicKey: response.Response.AttestationObject,
|
||||
AttestationType: "none",
|
||||
Transports: response.Transports,
|
||||
CreatedAt: time.Now(),
|
||||
LastUsedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BeginAuthentication creates authentication options for passkey login
|
||||
func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error) {
|
||||
// Generate challenge
|
||||
challenge := make([]byte, 32)
|
||||
if _, err := rand.Read(challenge); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate challenge: %w", err)
|
||||
}
|
||||
|
||||
// If username is provided, get user's credentials
|
||||
var allowCredentials []PasskeyCredentialDescriptor
|
||||
if username != "" {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userID sql.NullInt64
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get credentials")
|
||||
}
|
||||
|
||||
// Parse credentials
|
||||
var creds []struct {
|
||||
ID string `json:"credential_id"`
|
||||
Transports []string `json:"transports"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(credentialsJSON.String), &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials: %w", err)
|
||||
}
|
||||
|
||||
allowCredentials = make([]PasskeyCredentialDescriptor, 0, len(creds))
|
||||
for _, cred := range creds {
|
||||
credID, err := base64.StdEncoding.DecodeString(cred.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
allowCredentials = append(allowCredentials, PasskeyCredentialDescriptor{
|
||||
Type: "public-key",
|
||||
ID: credID,
|
||||
Transports: cred.Transports,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &PasskeyAuthenticationOptions{
|
||||
Challenge: challenge,
|
||||
Timeout: p.timeout,
|
||||
RelyingPartyID: p.rpID,
|
||||
AllowCredentials: allowCredentials,
|
||||
UserVerification: "preferred",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CompleteAuthentication verifies a passkey assertion and returns the user ID
|
||||
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
|
||||
// like github.com/go-webauthn/webauthn to properly verify the assertion signature.
|
||||
func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error) {
|
||||
// TODO: Implement full WebAuthn verification
|
||||
// 1. Verify clientDataJSON contains correct challenge and origin
|
||||
// 2. Verify authenticatorData
|
||||
// 3. Verify signature using stored public key
|
||||
// 4. Update sign counter and check for cloning
|
||||
|
||||
// Get credential from database
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return 0, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return 0, fmt.Errorf("credential not found")
|
||||
}
|
||||
|
||||
// Parse credential
|
||||
var cred struct {
|
||||
UserID int `json:"user_id"`
|
||||
SignCount uint32 `json:"sign_count"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(credentialJSON.String), &cred); err != nil {
|
||||
return 0, fmt.Errorf("failed to parse credential: %w", err)
|
||||
}
|
||||
|
||||
// TODO: Verify signature here
|
||||
// For now, we'll just update the counter as a placeholder
|
||||
|
||||
// Update counter (in production, this should be done after successful verification)
|
||||
newCounter := cred.SignCount + 1
|
||||
var updateSuccess bool
|
||||
var updateError sql.NullString
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
}
|
||||
|
||||
if cloneWarning.Valid && cloneWarning.Bool {
|
||||
return 0, fmt.Errorf("credential cloning detected")
|
||||
}
|
||||
|
||||
return cred.UserID, nil
|
||||
}
|
||||
|
||||
// GetCredentials returns all passkey credentials for a user
|
||||
func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get credentials")
|
||||
}
|
||||
|
||||
// Parse credentials
|
||||
var rawCreds []struct {
|
||||
ID int `json:"id"`
|
||||
UserID int `json:"user_id"`
|
||||
CredentialID string `json:"credential_id"`
|
||||
PublicKey string `json:"public_key"`
|
||||
AttestationType string `json:"attestation_type"`
|
||||
AAGUID string `json:"aaguid"`
|
||||
SignCount uint32 `json:"sign_count"`
|
||||
CloneWarning bool `json:"clone_warning"`
|
||||
Transports []string `json:"transports"`
|
||||
BackupEligible bool `json:"backup_eligible"`
|
||||
BackupState bool `json:"backup_state"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsedAt time.Time `json:"last_used_at"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(credentialsJSON.String), &rawCreds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials: %w", err)
|
||||
}
|
||||
|
||||
credentials := make([]PasskeyCredential, 0, len(rawCreds))
|
||||
for i := range rawCreds {
|
||||
raw := rawCreds[i]
|
||||
credID, err := base64.StdEncoding.DecodeString(raw.CredentialID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pubKey, err := base64.StdEncoding.DecodeString(raw.PublicKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
aaguid, _ := base64.StdEncoding.DecodeString(raw.AAGUID)
|
||||
|
||||
credentials = append(credentials, PasskeyCredential{
|
||||
ID: fmt.Sprintf("%d", raw.ID),
|
||||
UserID: raw.UserID,
|
||||
CredentialID: credID,
|
||||
PublicKey: pubKey,
|
||||
AttestationType: raw.AttestationType,
|
||||
AAGUID: aaguid,
|
||||
SignCount: raw.SignCount,
|
||||
CloneWarning: raw.CloneWarning,
|
||||
Transports: raw.Transports,
|
||||
BackupEligible: raw.BackupEligible,
|
||||
BackupState: raw.BackupState,
|
||||
Name: raw.Name,
|
||||
CreatedAt: raw.CreatedAt,
|
||||
LastUsedAt: raw.LastUsedAt,
|
||||
})
|
||||
}
|
||||
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
// DeleteCredential removes a passkey credential
|
||||
func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID int, credentialID string) error {
|
||||
credID, err := base64.StdEncoding.DecodeString(credentialID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credential ID: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return fmt.Errorf("failed to delete credential")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateCredentialName updates the friendly name of a credential
|
||||
func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
|
||||
credID, err := base64.StdEncoding.DecodeString(credentialID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid credential ID: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return fmt.Errorf("failed to update credential name")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
330 pkg/security/passkey_test.go (new file)
@@ -0,0 +1,330 @@
package security

import (
	"context"
	"database/sql"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
)

func TestDatabasePasskeyProvider_BeginRegistration(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to create mock db: %v", err)
	}
	defer db.Close()

	provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:     "example.com",
		RPName:   "Example App",
		RPOrigin: "https://example.com",
	})

	ctx := context.Background()

	// Mock get credentials query
	rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
		AddRow(true, nil, "[]")
	mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
		WithArgs(1).
		WillReturnRows(rows)

	opts, err := provider.BeginRegistration(ctx, 1, "testuser", "Test User")
	if err != nil {
		t.Fatalf("BeginRegistration failed: %v", err)
	}

	if opts.RelyingParty.ID != "example.com" {
		t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingParty.ID)
	}

	if opts.User.Name != "testuser" {
		t.Errorf("expected username 'testuser', got '%s'", opts.User.Name)
	}

	if len(opts.Challenge) != 32 {
		t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
	}

	if len(opts.PubKeyCredParams) != 2 {
		t.Errorf("expected 2 credential params, got %d", len(opts.PubKeyCredParams))
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled expectations: %v", err)
	}
}

func TestDatabasePasskeyProvider_BeginAuthentication(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to create mock db: %v", err)
	}
	defer db.Close()

	provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:     "example.com",
		RPName:   "Example App",
		RPOrigin: "https://example.com",
	})

	ctx := context.Background()

	// Mock get credentials by username query
	rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user_id", "p_credentials"}).
		AddRow(true, nil, 1, `[{"credential_id":"YWJjZGVm","transports":["internal"]}]`)
	mock.ExpectQuery(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username`).
		WithArgs("testuser").
		WillReturnRows(rows)

	opts, err := provider.BeginAuthentication(ctx, "testuser")
	if err != nil {
		t.Fatalf("BeginAuthentication failed: %v", err)
	}

	if opts.RelyingPartyID != "example.com" {
		t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingPartyID)
	}

	if len(opts.Challenge) != 32 {
		t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
	}

	if len(opts.AllowCredentials) != 1 {
		t.Errorf("expected 1 allowed credential, got %d", len(opts.AllowCredentials))
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled expectations: %v", err)
	}
}

func TestDatabasePasskeyProvider_GetCredentials(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to create mock db: %v", err)
	}
	defer db.Close()

	provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:   "example.com",
		RPName: "Example App",
	})

	ctx := context.Background()

	credentialsJSON := `[{
		"id": 1,
		"user_id": 1,
		"credential_id": "YWJjZGVmMTIzNDU2",
		"public_key": "cHVibGlja2V5",
		"attestation_type": "none",
		"aaguid": "",
		"sign_count": 5,
		"clone_warning": false,
		"transports": ["internal"],
		"backup_eligible": true,
		"backup_state": false,
		"name": "My Phone",
		"created_at": "2026-01-01T00:00:00Z",
		"last_used_at": "2026-01-31T00:00:00Z"
	}]`

	rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
		AddRow(true, nil, credentialsJSON)
	mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
		WithArgs(1).
		WillReturnRows(rows)

	credentials, err := provider.GetCredentials(ctx, 1)
	if err != nil {
		t.Fatalf("GetCredentials failed: %v", err)
	}

	if len(credentials) != 1 {
		t.Fatalf("expected 1 credential, got %d", len(credentials))
	}

	cred := credentials[0]
	if cred.UserID != 1 {
		t.Errorf("expected user ID 1, got %d", cred.UserID)
	}
	if cred.Name != "My Phone" {
		t.Errorf("expected name 'My Phone', got '%s'", cred.Name)
	}
	if cred.SignCount != 5 {
		t.Errorf("expected sign count 5, got %d", cred.SignCount)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled expectations: %v", err)
	}
}

func TestDatabasePasskeyProvider_DeleteCredential(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to create mock db: %v", err)
	}
	defer db.Close()

	provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:   "example.com",
		RPName: "Example App",
	})

	ctx := context.Background()

	rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
		AddRow(true, nil)
	mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_delete_credential`).
		WithArgs(1, sqlmock.AnyArg()).
		WillReturnRows(rows)

	err = provider.DeleteCredential(ctx, 1, "YWJjZGVmMTIzNDU2")
	if err != nil {
		t.Errorf("DeleteCredential failed: %v", err)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled expectations: %v", err)
	}
}

func TestDatabasePasskeyProvider_UpdateCredentialName(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to create mock db: %v", err)
	}
	defer db.Close()

	provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:   "example.com",
		RPName: "Example App",
	})

	ctx := context.Background()

	rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
		AddRow(true, nil)
	mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_update_name`).
		WithArgs(1, sqlmock.AnyArg(), "New Name").
		WillReturnRows(rows)

	err = provider.UpdateCredentialName(ctx, 1, "YWJjZGVmMTIzNDU2", "New Name")
	if err != nil {
		t.Errorf("UpdateCredentialName failed: %v", err)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled expectations: %v", err)
	}
}

func TestDatabaseAuthenticator_PasskeyMethods(t *testing.T) {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to create mock db: %v", err)
	}
	defer db.Close()

	passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:   "example.com",
		RPName: "Example App",
	})

	auth := NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
		PasskeyProvider: passkeyProvider,
	})

	ctx := context.Background()

	t.Run("BeginPasskeyRegistration", func(t *testing.T) {
		rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
			AddRow(true, nil, "[]")
		mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
			WithArgs(1).
			WillReturnRows(rows)

		opts, err := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
			UserID:      1,
			Username:    "testuser",
			DisplayName: "Test User",
		})

		if err != nil {
			t.Errorf("BeginPasskeyRegistration failed: %v", err)
		}

		if opts == nil {
			t.Error("expected options, got nil")
		}
	})

	t.Run("GetPasskeyCredentials", func(t *testing.T) {
		rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
			AddRow(true, nil, "[]")
		mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
			WithArgs(1).
			WillReturnRows(rows)

		credentials, err := auth.GetPasskeyCredentials(ctx, 1)
		if err != nil {
			t.Errorf("GetPasskeyCredentials failed: %v", err)
		}

		if credentials == nil {
			t.Error("expected credentials slice, got nil")
		}
	})

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled expectations: %v", err)
	}
}

func TestDatabaseAuthenticator_WithoutPasskey(t *testing.T) {
	db, _, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to create mock db: %v", err)
	}
	defer db.Close()

	auth := NewDatabaseAuthenticator(db)
	ctx := context.Background()

	_, err = auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
		UserID:      1,
		Username:    "testuser",
		DisplayName: "Test User",
	})

	if err == nil {
		t.Error("expected error when passkey provider not configured, got nil")
	}

	expectedMsg := "passkey provider not configured"
	if err.Error() != expectedMsg {
		t.Errorf("expected error '%s', got '%s'", expectedMsg, err.Error())
	}
}

func TestPasskeyProvider_NilDB(t *testing.T) {
	// This test verifies that the provider can be created with nil DB
	// but operations will fail. In production, always provide a valid DB.
	var db *sql.DB
	provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
		RPID:   "example.com",
		RPName: "Example App",
	})

	if provider == nil {
		t.Error("expected provider to be created even with nil DB")
	}

	// Verify that the provider has the correct configuration
	if provider.rpID != "example.com" {
		t.Errorf("expected RP ID 'example.com', got '%s'", provider.rpID)
	}
}
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
@@ -60,10 +61,19 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
||||
// resolvespec_session_update, resolvespec_refresh_token
|
||||
// See database_schema.sql for procedure definitions
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
oauth2ProvidersMutex sync.RWMutex
|
||||
|
||||
// Passkey provider (optional)
|
||||
passkeyProvider PasskeyProvider
|
||||
}
|
||||
|
||||
// DatabaseAuthenticatorOptions configures the database authenticator
|
||||
@@ -73,6 +83,8 @@ type DatabaseAuthenticatorOptions struct {
|
||||
CacheTTL time.Duration
|
||||
// Cache is an optional cache instance. If nil, uses the default cache
|
||||
Cache *cache.Cache
|
||||
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||
PasskeyProvider PasskeyProvider
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -92,9 +104,10 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
}
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
db: db,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
passkeyProvider: opts.PasskeyProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,6 +145,41 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// Register implements Registrable interface
|
||||
func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error) {
|
||||
// Convert RegisterRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_register stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("registration failed")
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var response LoginResponse
|
||||
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse register response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// Convert LogoutRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -654,3 +702,135 @@ func generateRandomString(length int) string {
|
||||
// }
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// Passkey authentication methods
|
||||
// ==============================
|
||||
|
||||
// WithPasskey configures the DatabaseAuthenticator with a passkey provider
|
||||
func (a *DatabaseAuthenticator) WithPasskey(provider PasskeyProvider) *DatabaseAuthenticator {
|
||||
a.passkeyProvider = provider
|
||||
return a
|
||||
}
|
||||
|
||||
// BeginPasskeyRegistration initiates passkey registration for a user
|
||||
func (a *DatabaseAuthenticator) BeginPasskeyRegistration(ctx context.Context, req PasskeyBeginRegistrationRequest) (*PasskeyRegistrationOptions, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
return a.passkeyProvider.BeginRegistration(ctx, req.UserID, req.Username, req.DisplayName)
|
||||
}
|
||||
|
||||
// CompletePasskeyRegistration completes passkey registration
|
||||
func (a *DatabaseAuthenticator) CompletePasskeyRegistration(ctx context.Context, req PasskeyRegisterRequest) (*PasskeyCredential, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
|
||||
cred, err := a.passkeyProvider.CompleteRegistration(ctx, req.UserID, req.Response, req.ExpectedChallenge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update credential name if provided
|
||||
if req.CredentialName != "" && cred.ID != "" {
|
||||
_ = a.passkeyProvider.UpdateCredentialName(ctx, req.UserID, cred.ID, req.CredentialName)
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// BeginPasskeyAuthentication initiates passkey authentication
|
||||
func (a *DatabaseAuthenticator) BeginPasskeyAuthentication(ctx context.Context, req PasskeyBeginAuthenticationRequest) (*PasskeyAuthenticationOptions, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
return a.passkeyProvider.BeginAuthentication(ctx, req.Username)
|
||||
}
|
||||
|
||||
// LoginWithPasskey authenticates a user using a passkey and creates a session
|
||||
func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req PasskeyLoginRequest) (*LoginResponse, error) {
|
||||
if a.passkeyProvider == nil {
|
||||
return nil, fmt.Errorf("passkey provider not configured")
|
||||
}
|
||||
|
||||
// Verify passkey assertion
|
||||
userID, err := a.passkeyProvider.CompleteAuthentication(ctx, req.Response, req.ExpectedChallenge)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||
}
|
||||
|
||||
// Get user data from database
|
||||
var username, email, roles string
|
||||
var userLevel int
|
||||
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
||||
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
}
|
||||
|
||||
// Generate session token
|
||||
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// Extract IP and user agent from claims
|
||||
ipAddress := ""
|
||||
userAgent := ""
|
||||
if req.Claims != nil {
		if ip, ok := req.Claims["ip_address"].(string); ok {
			ipAddress = ip
		}
		if ua, ok := req.Claims["user_agent"].(string); ok {
			userAgent = ua
		}
	}

	// Create session
	insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
		VALUES ($1, $2, $3, $4, $5, now())`
	_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}

	// Update last login
	updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
	_, _ = a.db.ExecContext(ctx, updateQuery, userID)

	// Return login response
	return &LoginResponse{
		Token: sessionToken,
		User: &UserContext{
			UserID:    userID,
			UserName:  username,
			Email:     email,
			UserLevel: userLevel,
			SessionID: sessionToken,
			Roles:     parseRoles(roles),
		},
		ExpiresIn: int64(24 * time.Hour.Seconds()),
	}, nil
}

// GetPasskeyCredentials returns all passkey credentials for a user
func (a *DatabaseAuthenticator) GetPasskeyCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
	if a.passkeyProvider == nil {
		return nil, fmt.Errorf("passkey provider not configured")
	}
	return a.passkeyProvider.GetCredentials(ctx, userID)
}

// DeletePasskeyCredential removes a passkey credential
func (a *DatabaseAuthenticator) DeletePasskeyCredential(ctx context.Context, userID int, credentialID string) error {
	if a.passkeyProvider == nil {
		return fmt.Errorf("passkey provider not configured")
	}
	return a.passkeyProvider.DeleteCredential(ctx, userID, credentialID)
}

// UpdatePasskeyCredentialName updates the friendly name of a credential
func (a *DatabaseAuthenticator) UpdatePasskeyCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
	if a.passkeyProvider == nil {
		return fmt.Errorf("passkey provider not configured")
	}
	return a.passkeyProvider.UpdateCredentialName(ctx, userID, credentialID, name)
}
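The login response above builds its Roles field with parseRoles(roles), a helper defined elsewhere in the package and not shown in this diff. A plausible standalone sketch, assuming roles come back from the query as a comma-separated string — the name and signature are taken from the call site, but the split format is an assumption:

```go
package main

import (
	"fmt"
	"strings"
)

// parseRoles splits a comma-separated roles column into a slice,
// trimming whitespace and dropping empty entries.
// (Hypothetical reimplementation; the package's own version may differ.)
func parseRoles(roles string) []string {
	out := make([]string, 0)
	for _, r := range strings.Split(roles, ",") {
		if r = strings.TrimSpace(r); r != "" {
			out = append(out, r)
		}
	}
	return out
}

func main() {
	fmt.Printf("%q\n", parseRoles("admin, user, "))
}
```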
@@ -635,6 +635,94 @@ func TestDatabaseAuthenticator(t *testing.T) {
			t.Errorf("unfulfilled expectations: %v", err)
		}
	})

	t.Run("successful registration", func(t *testing.T) {
		ctx := context.Background()
		req := RegisterRequest{
			Username:  "newuser",
			Password:  "password123",
			Email:     "newuser@example.com",
			UserLevel: 1,
			Roles:     []string{"user"},
		}

		rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
			AddRow(true, nil, `{"token":"abc123","user":{"user_id":1,"user_name":"newuser","email":"newuser@example.com"},"expires_in":86400}`)

		mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
			WithArgs(sqlmock.AnyArg()).
			WillReturnRows(rows)

		resp, err := auth.Register(ctx, req)
		if err != nil {
			t.Fatalf("expected no error, got %v", err)
		}

		if resp.Token != "abc123" {
			t.Errorf("expected token abc123, got %s", resp.Token)
		}
		if resp.User.UserName != "newuser" {
			t.Errorf("expected username newuser, got %s", resp.User.UserName)
		}

		if err := mock.ExpectationsWereMet(); err != nil {
			t.Errorf("unfulfilled expectations: %v", err)
		}
	})

	t.Run("registration with duplicate username", func(t *testing.T) {
		ctx := context.Background()
		req := RegisterRequest{
			Username:  "existinguser",
			Password:  "password123",
			Email:     "new@example.com",
			UserLevel: 1,
			Roles:     []string{"user"},
		}

		rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
			AddRow(false, "Username already exists", nil)

		mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
			WithArgs(sqlmock.AnyArg()).
			WillReturnRows(rows)

		_, err := auth.Register(ctx, req)
		if err == nil {
			t.Fatal("expected error for duplicate username")
		}

		if err := mock.ExpectationsWereMet(); err != nil {
			t.Errorf("unfulfilled expectations: %v", err)
		}
	})

	t.Run("registration with duplicate email", func(t *testing.T) {
		ctx := context.Background()
		req := RegisterRequest{
			Username:  "newuser2",
			Password:  "password123",
			Email:     "existing@example.com",
			UserLevel: 1,
			Roles:     []string{"user"},
		}

		rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
			AddRow(false, "Email already exists", nil)

		mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
			WithArgs(sqlmock.AnyArg()).
			WillReturnRows(rows)

		_, err := auth.Register(ctx, req)
		if err == nil {
			t.Fatal("expected error for duplicate email")
		}

		if err := mock.ExpectationsWereMet(); err != nil {
			t.Errorf("unfulfilled expectations: %v", err)
		}
	})
}

// Test DatabaseAuthenticator RefreshToken
pkg/security/totp.go (new file, 188 lines)
@@ -0,0 +1,188 @@
package security

import (
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/sha512"
	"encoding/base32"
	"encoding/binary"
	"fmt"
	"hash"
	"math"
	"net/url"
	"strings"
	"time"
)

// TwoFactorAuthProvider defines interface for 2FA operations
type TwoFactorAuthProvider interface {
	// Generate2FASecret creates a new secret for a user
	Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error)

	// Validate2FACode verifies a TOTP code
	Validate2FACode(secret string, code string) (bool, error)

	// Enable2FA activates 2FA for a user (store secret in your database)
	Enable2FA(userID int, secret string, backupCodes []string) error

	// Disable2FA deactivates 2FA for a user
	Disable2FA(userID int) error

	// Get2FAStatus checks if user has 2FA enabled
	Get2FAStatus(userID int) (bool, error)

	// Get2FASecret retrieves the user's 2FA secret
	Get2FASecret(userID int) (string, error)

	// GenerateBackupCodes creates backup codes for 2FA
	GenerateBackupCodes(userID int, count int) ([]string, error)

	// ValidateBackupCode checks and consumes a backup code
	ValidateBackupCode(userID int, code string) (bool, error)
}

// TwoFactorSecret contains 2FA setup information
type TwoFactorSecret struct {
	Secret      string   `json:"secret"`       // Base32 encoded secret
	QRCodeURL   string   `json:"qr_code_url"`  // URL for QR code generation
	BackupCodes []string `json:"backup_codes"` // One-time backup codes
	Issuer      string   `json:"issuer"`       // Application name
	AccountName string   `json:"account_name"` // User identifier (email/username)
}

// TwoFactorConfig holds TOTP configuration
type TwoFactorConfig struct {
	Algorithm  string // SHA1, SHA256, SHA512
	Digits     int    // Number of digits in code (6 or 8)
	Period     int    // Time step in seconds (default 30)
	SkewWindow int    // Number of time steps to check before/after (default 1)
}

// DefaultTwoFactorConfig returns standard TOTP configuration
func DefaultTwoFactorConfig() *TwoFactorConfig {
	return &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     6,
		Period:     30,
		SkewWindow: 1,
	}
}

// TOTPGenerator handles TOTP code generation and validation
type TOTPGenerator struct {
	config *TwoFactorConfig
}

// NewTOTPGenerator creates a new TOTP generator with config
func NewTOTPGenerator(config *TwoFactorConfig) *TOTPGenerator {
	if config == nil {
		config = DefaultTwoFactorConfig()
	}
	return &TOTPGenerator{
		config: config,
	}
}

// GenerateSecret creates a random base32-encoded secret
func (t *TOTPGenerator) GenerateSecret() (string, error) {
	secret := make([]byte, 20)
	_, err := rand.Read(secret)
	if err != nil {
		return "", fmt.Errorf("failed to generate random secret: %w", err)
	}
	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secret), nil
}

// GenerateQRCodeURL creates a URL for QR code generation
func (t *TOTPGenerator) GenerateQRCodeURL(secret, issuer, accountName string) string {
	params := url.Values{}
	params.Set("secret", secret)
	params.Set("issuer", issuer)
	params.Set("algorithm", t.config.Algorithm)
	params.Set("digits", fmt.Sprintf("%d", t.config.Digits))
	params.Set("period", fmt.Sprintf("%d", t.config.Period))

	label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, accountName))
	return fmt.Sprintf("otpauth://totp/%s?%s", label, params.Encode())
}

// GenerateCode creates a TOTP code for a given time
func (t *TOTPGenerator) GenerateCode(secret string, timestamp time.Time) (string, error) {
	// Decode secret
	key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
	if err != nil {
		return "", fmt.Errorf("invalid secret: %w", err)
	}

	// Calculate counter (time steps since Unix epoch)
	counter := uint64(timestamp.Unix()) / uint64(t.config.Period)

	// Generate HMAC
	h := t.getHashFunc()
	mac := hmac.New(h, key)

	// Convert counter to 8-byte array
	buf := make([]byte, 8)
	binary.BigEndian.PutUint64(buf, counter)
	mac.Write(buf)

	sum := mac.Sum(nil)

	// Dynamic truncation
	offset := sum[len(sum)-1] & 0x0f
	truncated := binary.BigEndian.Uint32(sum[offset:]) & 0x7fffffff

	// Generate code with specified digits
	code := truncated % uint32(math.Pow10(t.config.Digits))

	format := fmt.Sprintf("%%0%dd", t.config.Digits)
	return fmt.Sprintf(format, code), nil
}

// ValidateCode checks if a code is valid for the secret
func (t *TOTPGenerator) ValidateCode(secret, code string) (bool, error) {
	now := time.Now()

	// Check current time and skew window
	for i := -t.config.SkewWindow; i <= t.config.SkewWindow; i++ {
		timestamp := now.Add(time.Duration(i*t.config.Period) * time.Second)
		expected, err := t.GenerateCode(secret, timestamp)
		if err != nil {
			return false, err
		}

		if code == expected {
			return true, nil
		}
	}

	return false, nil
}

// getHashFunc returns the hash function based on algorithm
func (t *TOTPGenerator) getHashFunc() func() hash.Hash {
	switch strings.ToUpper(t.config.Algorithm) {
	case "SHA256":
		return sha256.New
	case "SHA512":
		return sha512.New
	default:
		return sha1.New
	}
}

// GenerateBackupCodes creates random backup codes
func GenerateBackupCodes(count int) ([]string, error) {
	codes := make([]string, count)
	for i := 0; i < count; i++ {
		code := make([]byte, 4)
		_, err := rand.Read(code)
		if err != nil {
			return nil, fmt.Errorf("failed to generate backup code: %w", err)
		}
		codes[i] = fmt.Sprintf("%08X", binary.BigEndian.Uint32(code))
	}
	return codes, nil
}
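The dynamic-truncation step in GenerateCode is the HOTP computation of RFC 4226, keyed by the time-step counter as in RFC 6238. A minimal standalone sketch of the same logic, checked against the published RFC 4226 test vector (ASCII key "12345678901234567890", counter 1):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/binary"
	"fmt"
	"math"
)

// hotp computes an RFC 4226 one-time code for a raw key and counter,
// mirroring the HMAC + dynamic-truncation steps in GenerateCode above.
func hotp(key []byte, counter uint64, digits int) string {
	mac := hmac.New(sha1.New, key)
	buf := make([]byte, 8)
	binary.BigEndian.PutUint64(buf, counter)
	mac.Write(buf)
	sum := mac.Sum(nil)

	// Dynamic truncation: the low nibble of the last byte picks a 4-byte window.
	offset := sum[len(sum)-1] & 0x0f
	code := binary.BigEndian.Uint32(sum[offset:]) & 0x7fffffff

	format := fmt.Sprintf("%%0%dd", digits)
	return fmt.Sprintf(format, code%uint32(math.Pow10(digits)))
}

func main() {
	key := []byte("12345678901234567890") // RFC 4226 test key
	// A TOTP with period 30 at Unix time 59 uses counter 59/30 = 1.
	fmt.Println(hotp(key, 1, 6)) // → 287082 per the RFC 4226 test vectors
}
```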
pkg/security/totp_integration_test.go (new file, 399 lines)
@@ -0,0 +1,399 @@
package security_test

import (
	"context"
	"errors"
	"net/http"
	"testing"
	"time"

	"github.com/bitechdev/ResolveSpec/pkg/security"
)

var ErrInvalidCredentials = errors.New("invalid credentials")

// MockAuthenticator is a simple authenticator for testing 2FA
type MockAuthenticator struct {
	users map[string]*security.UserContext
}

func NewMockAuthenticator() *MockAuthenticator {
	return &MockAuthenticator{
		users: map[string]*security.UserContext{
			"testuser": {
				UserID:   1,
				UserName: "testuser",
				Email:    "test@example.com",
			},
		},
	}
}

func (m *MockAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
	user, exists := m.users[req.Username]
	if !exists || req.Password != "password" {
		return nil, ErrInvalidCredentials
	}

	return &security.LoginResponse{
		Token:        "mock-token",
		RefreshToken: "mock-refresh-token",
		User:         user,
		ExpiresIn:    3600,
	}, nil
}

func (m *MockAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
	return nil
}

func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
	return m.users["testuser"], nil
}

func TestTwoFactorAuthenticator_Setup(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup 2FA
	secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	if err != nil {
		t.Fatalf("Setup2FA() error = %v", err)
	}

	if secret.Secret == "" {
		t.Error("Setup2FA() returned empty secret")
	}

	if secret.QRCodeURL == "" {
		t.Error("Setup2FA() returned empty QR code URL")
	}

	if len(secret.BackupCodes) == 0 {
		t.Error("Setup2FA() returned no backup codes")
	}

	if secret.Issuer != "TestApp" {
		t.Errorf("Setup2FA() Issuer = %s, want TestApp", secret.Issuer)
	}

	if secret.AccountName != "test@example.com" {
		t.Errorf("Setup2FA() AccountName = %s, want test@example.com", secret.AccountName)
	}
}

func TestTwoFactorAuthenticator_Enable2FA(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup 2FA
	secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	if err != nil {
		t.Fatalf("Setup2FA() error = %v", err)
	}

	// Generate valid code
	totp := security.NewTOTPGenerator(nil)
	code, err := totp.GenerateCode(secret.Secret, time.Now())
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	// Enable 2FA with valid code
	err = tfaAuth.Enable2FA(1, secret.Secret, code)
	if err != nil {
		t.Errorf("Enable2FA() error = %v", err)
	}

	// Verify 2FA is enabled
	status, err := provider.Get2FAStatus(1)
	if err != nil {
		t.Fatalf("Get2FAStatus() error = %v", err)
	}

	if !status {
		t.Error("Enable2FA() did not enable 2FA")
	}
}

func TestTwoFactorAuthenticator_Enable2FA_InvalidCode(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup 2FA
	secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	if err != nil {
		t.Fatalf("Setup2FA() error = %v", err)
	}

	// Try to enable with invalid code
	err = tfaAuth.Enable2FA(1, secret.Secret, "000000")
	if err == nil {
		t.Error("Enable2FA() should fail with invalid code")
	}

	// Verify 2FA is not enabled
	status, _ := provider.Get2FAStatus(1)
	if status {
		t.Error("Enable2FA() should not enable 2FA with invalid code")
	}
}

func TestTwoFactorAuthenticator_Login_Without2FA(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	req := security.LoginRequest{
		Username: "testuser",
		Password: "password",
	}

	resp, err := tfaAuth.Login(context.Background(), req)
	if err != nil {
		t.Fatalf("Login() error = %v", err)
	}

	if resp.Requires2FA {
		t.Error("Login() should not require 2FA when not enabled")
	}

	if resp.Token == "" {
		t.Error("Login() should return token when 2FA not required")
	}
}

func TestTwoFactorAuthenticator_Login_With2FA_NoCode(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup and enable 2FA
	secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	totp := security.NewTOTPGenerator(nil)
	code, _ := totp.GenerateCode(secret.Secret, time.Now())
	tfaAuth.Enable2FA(1, secret.Secret, code)

	// Try to login without 2FA code
	req := security.LoginRequest{
		Username: "testuser",
		Password: "password",
	}

	resp, err := tfaAuth.Login(context.Background(), req)
	if err != nil {
		t.Fatalf("Login() error = %v", err)
	}

	if !resp.Requires2FA {
		t.Error("Login() should require 2FA when enabled")
	}

	if resp.Token != "" {
		t.Error("Login() should not return token when 2FA required but not provided")
	}
}

func TestTwoFactorAuthenticator_Login_With2FA_ValidCode(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup and enable 2FA
	secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	totp := security.NewTOTPGenerator(nil)
	code, _ := totp.GenerateCode(secret.Secret, time.Now())
	tfaAuth.Enable2FA(1, secret.Secret, code)

	// Generate new valid code for login
	newCode, _ := totp.GenerateCode(secret.Secret, time.Now())

	// Login with 2FA code
	req := security.LoginRequest{
		Username:      "testuser",
		Password:      "password",
		TwoFactorCode: newCode,
	}

	resp, err := tfaAuth.Login(context.Background(), req)
	if err != nil {
		t.Fatalf("Login() error = %v", err)
	}

	if resp.Requires2FA {
		t.Error("Login() should not require 2FA when valid code provided")
	}

	if resp.Token == "" {
		t.Error("Login() should return token when 2FA validated")
	}

	if !resp.User.TwoFactorEnabled {
		t.Error("Login() should set TwoFactorEnabled on user")
	}
}

func TestTwoFactorAuthenticator_Login_With2FA_InvalidCode(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup and enable 2FA
	secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	totp := security.NewTOTPGenerator(nil)
	code, _ := totp.GenerateCode(secret.Secret, time.Now())
	tfaAuth.Enable2FA(1, secret.Secret, code)

	// Try to login with invalid code
	req := security.LoginRequest{
		Username:      "testuser",
		Password:      "password",
		TwoFactorCode: "000000",
	}

	_, err := tfaAuth.Login(context.Background(), req)
	if err == nil {
		t.Error("Login() should fail with invalid 2FA code")
	}
}

func TestTwoFactorAuthenticator_Login_WithBackupCode(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup and enable 2FA
	secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	totp := security.NewTOTPGenerator(nil)
	code, _ := totp.GenerateCode(secret.Secret, time.Now())
	tfaAuth.Enable2FA(1, secret.Secret, code)

	// Get backup codes
	backupCodes, _ := tfaAuth.RegenerateBackupCodes(1, 10)

	// Login with backup code
	req := security.LoginRequest{
		Username:      "testuser",
		Password:      "password",
		TwoFactorCode: backupCodes[0],
	}

	resp, err := tfaAuth.Login(context.Background(), req)
	if err != nil {
		t.Fatalf("Login() with backup code error = %v", err)
	}

	if resp.Token == "" {
		t.Error("Login() should return token when backup code validated")
	}

	// Try to use same backup code again
	req2 := security.LoginRequest{
		Username:      "testuser",
		Password:      "password",
		TwoFactorCode: backupCodes[0],
	}

	_, err = tfaAuth.Login(context.Background(), req2)
	if err == nil {
		t.Error("Login() should fail when reusing backup code")
	}
}

func TestTwoFactorAuthenticator_Disable2FA(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup and enable 2FA
	secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	totp := security.NewTOTPGenerator(nil)
	code, _ := totp.GenerateCode(secret.Secret, time.Now())
	tfaAuth.Enable2FA(1, secret.Secret, code)

	// Disable 2FA
	err := tfaAuth.Disable2FA(1)
	if err != nil {
		t.Errorf("Disable2FA() error = %v", err)
	}

	// Verify 2FA is disabled
	status, _ := provider.Get2FAStatus(1)
	if status {
		t.Error("Disable2FA() did not disable 2FA")
	}

	// Login should not require 2FA
	req := security.LoginRequest{
		Username: "testuser",
		Password: "password",
	}

	resp, err := tfaAuth.Login(context.Background(), req)
	if err != nil {
		t.Fatalf("Login() error = %v", err)
	}

	if resp.Requires2FA {
		t.Error("Login() should not require 2FA after disabling")
	}
}

func TestTwoFactorAuthenticator_RegenerateBackupCodes(t *testing.T) {
	baseAuth := NewMockAuthenticator()
	provider := security.NewMemoryTwoFactorProvider(nil)
	tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)

	// Setup and enable 2FA
	secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
	totp := security.NewTOTPGenerator(nil)
	code, _ := totp.GenerateCode(secret.Secret, time.Now())
	tfaAuth.Enable2FA(1, secret.Secret, code)

	// Get initial backup codes
	codes1, err := tfaAuth.RegenerateBackupCodes(1, 10)
	if err != nil {
		t.Fatalf("RegenerateBackupCodes() error = %v", err)
	}

	if len(codes1) != 10 {
		t.Errorf("RegenerateBackupCodes() returned %d codes, want 10", len(codes1))
	}

	// Regenerate backup codes
	codes2, err := tfaAuth.RegenerateBackupCodes(1, 10)
	if err != nil {
		t.Fatalf("RegenerateBackupCodes() error = %v", err)
	}

	// Old codes should not work
	req := security.LoginRequest{
		Username:      "testuser",
		Password:      "password",
		TwoFactorCode: codes1[0],
	}

	_, err = tfaAuth.Login(context.Background(), req)
	if err == nil {
		t.Error("Login() should fail with old backup code after regeneration")
	}

	// New codes should work
	req2 := security.LoginRequest{
		Username:      "testuser",
		Password:      "password",
		TwoFactorCode: codes2[0],
	}

	resp, err := tfaAuth.Login(context.Background(), req2)
	if err != nil {
		t.Fatalf("Login() with new backup code error = %v", err)
	}

	if resp.Token == "" {
		t.Error("Login() should return token with new backup code")
	}
}
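The tests above only assert that Setup2FA's QRCodeURL is non-empty. The URL itself follows the otpauth URI scheme built by GenerateQRCodeURL (label path-escaped, query parameters sorted alphabetically by url.Values.Encode). A self-contained sketch of the same construction, using a placeholder secret:

```go
package main

import (
	"fmt"
	"net/url"
)

// otpauthURL builds an otpauth://totp/ URI in the shape authenticator apps
// scan from a QR code, mirroring GenerateQRCodeURL with default settings.
func otpauthURL(secret, issuer, account string) string {
	params := url.Values{}
	params.Set("secret", secret)
	params.Set("issuer", issuer)
	params.Set("algorithm", "SHA1")
	params.Set("digits", "6")
	params.Set("period", "30")
	label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, account))
	return fmt.Sprintf("otpauth://totp/%s?%s", label, params.Encode())
}

func main() {
	// JBSWY3DPEHPK3PXP is a throwaway base32 string, not a real credential.
	fmt.Println(otpauthURL("JBSWY3DPEHPK3PXP", "TestApp", "test@example.com"))
}
```

Note that url.Values.Encode sorts keys, so the query always comes out in the order algorithm, digits, issuer, period, secret regardless of Set order.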
pkg/security/totp_middleware.go (new file, 134 lines)
@@ -0,0 +1,134 @@
package security

import (
	"context"
	"fmt"
	"net/http"
)

// TwoFactorAuthenticator wraps an Authenticator and adds 2FA support
type TwoFactorAuthenticator struct {
	baseAuth Authenticator
	totp     *TOTPGenerator
	provider TwoFactorAuthProvider
}

// NewTwoFactorAuthenticator creates a new 2FA-enabled authenticator
func NewTwoFactorAuthenticator(baseAuth Authenticator, provider TwoFactorAuthProvider, config *TwoFactorConfig) *TwoFactorAuthenticator {
	if config == nil {
		config = DefaultTwoFactorConfig()
	}
	return &TwoFactorAuthenticator{
		baseAuth: baseAuth,
		totp:     NewTOTPGenerator(config),
		provider: provider,
	}
}

// Login authenticates with 2FA support
func (t *TwoFactorAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
	// First, perform standard authentication
	resp, err := t.baseAuth.Login(ctx, req)
	if err != nil {
		return nil, err
	}

	// Check if user has 2FA enabled
	if resp.User == nil {
		return resp, nil
	}

	has2FA, err := t.provider.Get2FAStatus(resp.User.UserID)
	if err != nil {
		return nil, fmt.Errorf("failed to check 2FA status: %w", err)
	}

	if !has2FA {
		// User doesn't have 2FA enabled, return normal response
		return resp, nil
	}

	// User has 2FA enabled
	if req.TwoFactorCode == "" {
		// No 2FA code provided, require it
		resp.Requires2FA = true
		resp.Token = "" // Don't return token until 2FA is verified
		resp.RefreshToken = ""
		return resp, nil
	}

	// Validate 2FA code
	secret, err := t.provider.Get2FASecret(resp.User.UserID)
	if err != nil {
		return nil, fmt.Errorf("failed to get 2FA secret: %w", err)
	}

	// Try TOTP code first
	valid, err := t.totp.ValidateCode(secret, req.TwoFactorCode)
	if err != nil {
		return nil, fmt.Errorf("failed to validate 2FA code: %w", err)
	}

	if !valid {
		// Try backup code
		valid, err = t.provider.ValidateBackupCode(resp.User.UserID, req.TwoFactorCode)
		if err != nil {
			return nil, fmt.Errorf("failed to validate backup code: %w", err)
		}
	}

	if !valid {
		return nil, fmt.Errorf("invalid 2FA code")
	}

	// 2FA verified, return full response with token
	resp.User.TwoFactorEnabled = true
	return resp, nil
}

// Logout delegates to base authenticator
func (t *TwoFactorAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
	return t.baseAuth.Logout(ctx, req)
}

// Authenticate delegates to base authenticator
func (t *TwoFactorAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
	return t.baseAuth.Authenticate(r)
}

// Setup2FA initiates 2FA setup for a user
func (t *TwoFactorAuthenticator) Setup2FA(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
	return t.provider.Generate2FASecret(userID, issuer, accountName)
}

// Enable2FA completes 2FA setup after user confirms with a valid code
func (t *TwoFactorAuthenticator) Enable2FA(userID int, secret, verificationCode string) error {
	// Verify the code before enabling
	valid, err := t.totp.ValidateCode(secret, verificationCode)
	if err != nil {
		return fmt.Errorf("failed to validate code: %w", err)
	}

	if !valid {
		return fmt.Errorf("invalid verification code")
	}

	// Generate backup codes
	backupCodes, err := t.provider.GenerateBackupCodes(userID, 10)
	if err != nil {
		return fmt.Errorf("failed to generate backup codes: %w", err)
	}

	// Enable 2FA
	return t.provider.Enable2FA(userID, secret, backupCodes)
}

// Disable2FA removes 2FA from a user account
func (t *TwoFactorAuthenticator) Disable2FA(userID int) error {
	return t.provider.Disable2FA(userID)
}

// RegenerateBackupCodes creates new backup codes for a user
func (t *TwoFactorAuthenticator) RegenerateBackupCodes(userID int, count int) ([]string, error) {
	return t.provider.GenerateBackupCodes(userID, count)
}
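The Login wrapper above has three outcomes: no 2FA enabled → pass the base response through; 2FA enabled but no code supplied → set Requires2FA and clear the tokens; code supplied → release the token only if TOTP-or-backup validation succeeds. A self-contained sketch of that branching, with hypothetical local types standing in for the package's LoginRequest/LoginResponse:

```go
package main

import (
	"errors"
	"fmt"
)

type loginResult struct {
	Token       string
	Requires2FA bool
}

// decide2FA mirrors TwoFactorAuthenticator.Login's branching; codeValid
// stands in for the TOTP-then-backup-code validation result.
func decide2FA(has2FA bool, code string, codeValid bool, token string) (loginResult, error) {
	if !has2FA {
		return loginResult{Token: token}, nil // no 2FA: normal response
	}
	if code == "" {
		// 2FA enabled but no code supplied: withhold the token, ask for a code.
		return loginResult{Requires2FA: true}, nil
	}
	if !codeValid {
		return loginResult{}, errors.New("invalid 2FA code")
	}
	return loginResult{Token: token}, nil // code verified: release the token
}

func main() {
	r, _ := decide2FA(true, "", false, "tok")
	fmt.Println(r.Requires2FA, r.Token == "")
}
```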
pkg/security/totp_provider_database.go (new file, 229 lines)
@@ -0,0 +1,229 @@
package security

import (
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"fmt"
)

// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures.
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code.
// See totp_database_schema.sql for procedure definitions.
type DatabaseTwoFactorProvider struct {
	db      *sql.DB
	totpGen *TOTPGenerator
}

// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
	if config == nil {
		config = DefaultTwoFactorConfig()
	}
	return &DatabaseTwoFactorProvider{
		db:      db,
		totpGen: NewTOTPGenerator(config),
	}
}

// Generate2FASecret creates a new secret for a user
func (p *DatabaseTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
	secret, err := p.totpGen.GenerateSecret()
	if err != nil {
		return nil, fmt.Errorf("failed to generate secret: %w", err)
	}

	qrURL := p.totpGen.GenerateQRCodeURL(secret, issuer, accountName)

	backupCodes, err := GenerateBackupCodes(10)
	if err != nil {
		return nil, fmt.Errorf("failed to generate backup codes: %w", err)
	}

	return &TwoFactorSecret{
		Secret:      secret,
		QRCodeURL:   qrURL,
		BackupCodes: backupCodes,
		Issuer:      issuer,
		AccountName: accountName,
	}, nil
}

// Validate2FACode verifies a TOTP code
func (p *DatabaseTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
	return p.totpGen.ValidateCode(secret, code)
}

// Enable2FA activates 2FA for a user
func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
	// Hash backup codes for secure storage
	hashedCodes := make([]string, len(backupCodes))
	for i, code := range backupCodes {
		hash := sha256.Sum256([]byte(code))
		hashedCodes[i] = hex.EncodeToString(hash[:])
	}

	// Convert to JSON array
	codesJSON, err := json.Marshal(hashedCodes)
	if err != nil {
		return fmt.Errorf("failed to marshal backup codes: %w", err)
	}

	// Call stored procedure
	var success bool
	var errorMsg sql.NullString

	query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
	err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
	if err != nil {
		return fmt.Errorf("enable 2FA query failed: %w", err)
	}

	if !success {
		if errorMsg.Valid {
			return fmt.Errorf("%s", errorMsg.String)
		}
		return fmt.Errorf("failed to enable 2FA")
	}

	return nil
}

// Disable2FA deactivates 2FA for a user
func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
	var success bool
	var errorMsg sql.NullString

	query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
	err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
	if err != nil {
		return fmt.Errorf("disable 2FA query failed: %w", err)
	}

	if !success {
		if errorMsg.Valid {
			return fmt.Errorf("%s", errorMsg.String)
		}
		return fmt.Errorf("failed to disable 2FA")
	}

	return nil
}

// Get2FAStatus checks if user has 2FA enabled
func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
	var success bool
	var errorMsg sql.NullString
	var enabled bool

	query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
	err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
	if err != nil {
		return false, fmt.Errorf("get 2FA status query failed: %w", err)
	}

	if !success {
		if errorMsg.Valid {
			return false, fmt.Errorf("%s", errorMsg.String)
		}
		return false, fmt.Errorf("failed to get 2FA status")
	}

	return enabled, nil
}

// Get2FASecret retrieves the user's 2FA secret
func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
	var success bool
	var errorMsg sql.NullString
	var secret sql.NullString

	query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
	err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
	if err != nil {
		return "", fmt.Errorf("get 2FA secret query failed: %w", err)
	}

	if !success {
		if errorMsg.Valid {
			return "", fmt.Errorf("%s", errorMsg.String)
		}
		return "", fmt.Errorf("failed to get 2FA secret")
	}

	if !secret.Valid {
		return "", fmt.Errorf("2FA secret not found")
	}

	return secret.String, nil
}

// GenerateBackupCodes creates backup codes for 2FA
func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
	codes, err := GenerateBackupCodes(count)
	if err != nil {
		return nil, fmt.Errorf("failed to generate backup codes: %w", err)
	}

	// Hash backup codes for storage
|
||||
hashedCodes := make([]string, len(codes))
|
||||
for i, code := range codes {
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
hashedCodes[i] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Convert to JSON array
|
||||
codesJSON, err := json.Marshal(hashedCodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal backup codes: %w", err)
|
||||
}
|
||||
|
||||
// Call stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
||||
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to regenerate backup codes")
|
||||
}
|
||||
|
||||
// Return unhashed codes to user (only time they see them)
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// ValidateBackupCode checks and consumes a backup code
|
||||
func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
|
||||
// Hash the code
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
codeHash := hex.EncodeToString(hash[:])
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var valid bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
|
||||
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return false, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return valid, nil
|
||||
}
|
||||
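The provider above never stores plaintext backup codes: each code is hashed with SHA-256 and hex-encoded before it reaches the stored procedure, and validation hashes the submitted code and looks up the hash. A minimal standalone sketch of that hash-then-consume flow (plain map instead of a database, example code strings are made up):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// hashCode mirrors the provider's hashing: SHA-256, hex-encoded.
func hashCode(code string) string {
	sum := sha256.Sum256([]byte(code))
	return hex.EncodeToString(sum[:])
}

func main() {
	// Plaintext codes are shown to the user once; only hashes are stored.
	stored := map[string]bool{} // hash -> consumed
	for _, c := range []string{"a1b2c3d4", "e5f6a7b8"} {
		stored[hashCode(c)] = false
	}

	// Validation hashes the submitted code and looks up the hash.
	submitted := "a1b2c3d4"
	h := hashCode(submitted)
	used, ok := stored[h]
	fmt.Println(ok && !used) // true: known and unused
	stored[h] = true         // consume it
	used, ok = stored[h]
	fmt.Println(ok && !used) // false: already consumed
}
```

Storing only the digest means a database leak does not expose usable recovery codes, at the cost that codes can never be re-displayed, which is why GenerateBackupCodes returns the plaintext exactly once.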
pkg/security/totp_provider_database_test.go (new file, 218 lines)
@@ -0,0 +1,218 @@
package security_test

import (
	"database/sql"
	"testing"

	"github.com/bitechdev/ResolveSpec/pkg/security"
)

// Note: These tests require a PostgreSQL database with the schema from totp_database_schema.sql
// Set TEST_DATABASE_URL environment variable or skip tests

func setupTestDB(t *testing.T) *sql.DB {
	// Skip if no test database configured
	t.Skip("Database tests require TEST_DATABASE_URL environment variable")
	return nil
}

func TestDatabaseTwoFactorProvider_Enable2FA(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Generate secret and backup codes
	secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	if err != nil {
		t.Fatalf("Generate2FASecret() error = %v", err)
	}

	// Enable 2FA
	err = provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
	if err != nil {
		t.Errorf("Enable2FA() error = %v", err)
	}

	// Verify enabled
	enabled, err := provider.Get2FAStatus(1)
	if err != nil {
		t.Fatalf("Get2FAStatus() error = %v", err)
	}

	if !enabled {
		t.Error("Get2FAStatus() = false, want true")
	}
}

func TestDatabaseTwoFactorProvider_Disable2FA(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable first
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Disable
	err := provider.Disable2FA(1)
	if err != nil {
		t.Errorf("Disable2FA() error = %v", err)
	}

	// Verify disabled
	enabled, err := provider.Get2FAStatus(1)
	if err != nil {
		t.Fatalf("Get2FAStatus() error = %v", err)
	}

	if enabled {
		t.Error("Get2FAStatus() = true, want false")
	}
}

func TestDatabaseTwoFactorProvider_GetSecret(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable 2FA
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Retrieve secret
	retrieved, err := provider.Get2FASecret(1)
	if err != nil {
		t.Errorf("Get2FASecret() error = %v", err)
	}

	if retrieved != secret.Secret {
		t.Errorf("Get2FASecret() = %v, want %v", retrieved, secret.Secret)
	}
}

func TestDatabaseTwoFactorProvider_ValidateBackupCode(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable 2FA
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Validate backup code
	valid, err := provider.ValidateBackupCode(1, secret.BackupCodes[0])
	if err != nil {
		t.Errorf("ValidateBackupCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateBackupCode() = false, want true")
	}

	// Try to use same code again
	valid, err = provider.ValidateBackupCode(1, secret.BackupCodes[0])
	if err == nil {
		t.Error("ValidateBackupCode() should error on reuse")
	}

	// Try invalid code
	valid, err = provider.ValidateBackupCode(1, "INVALID")
	if err != nil {
		t.Errorf("ValidateBackupCode() error = %v", err)
	}

	if valid {
		t.Error("ValidateBackupCode() = true for invalid code")
	}
}

func TestDatabaseTwoFactorProvider_RegenerateBackupCodes(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable 2FA
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Regenerate codes
	newCodes, err := provider.GenerateBackupCodes(1, 10)
	if err != nil {
		t.Errorf("GenerateBackupCodes() error = %v", err)
	}

	if len(newCodes) != 10 {
		t.Errorf("GenerateBackupCodes() returned %d codes, want 10", len(newCodes))
	}

	// Old codes should not work
	valid, _ := provider.ValidateBackupCode(1, secret.BackupCodes[0])
	if valid {
		t.Error("Old backup code should not work after regeneration")
	}

	// New codes should work
	valid, err = provider.ValidateBackupCode(1, newCodes[0])
	if err != nil {
		t.Errorf("ValidateBackupCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateBackupCode() = false for new code")
	}
}

func TestDatabaseTwoFactorProvider_Generate2FASecret(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	if err != nil {
		t.Fatalf("Generate2FASecret() error = %v", err)
	}

	if secret.Secret == "" {
		t.Error("Generate2FASecret() returned empty secret")
	}

	if secret.QRCodeURL == "" {
		t.Error("Generate2FASecret() returned empty QR code URL")
	}

	if len(secret.BackupCodes) != 10 {
		t.Errorf("Generate2FASecret() returned %d backup codes, want 10", len(secret.BackupCodes))
	}

	if secret.Issuer != "TestApp" {
		t.Errorf("Generate2FASecret() Issuer = %v, want TestApp", secret.Issuer)
	}

	if secret.AccountName != "test@example.com" {
		t.Errorf("Generate2FASecret() AccountName = %v, want test@example.com", secret.AccountName)
	}
}
pkg/security/totp_provider_memory.go (new file, 156 lines)
@@ -0,0 +1,156 @@
package security

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"sync"
)

// MemoryTwoFactorProvider is an in-memory implementation of TwoFactorAuthProvider for testing/examples
type MemoryTwoFactorProvider struct {
	mu          sync.RWMutex
	secrets     map[int]string          // userID -> secret
	backupCodes map[int]map[string]bool // userID -> backup codes (code -> used)
	totpGen     *TOTPGenerator
}

// NewMemoryTwoFactorProvider creates a new in-memory 2FA provider
func NewMemoryTwoFactorProvider(config *TwoFactorConfig) *MemoryTwoFactorProvider {
	if config == nil {
		config = DefaultTwoFactorConfig()
	}
	return &MemoryTwoFactorProvider{
		secrets:     make(map[int]string),
		backupCodes: make(map[int]map[string]bool),
		totpGen:     NewTOTPGenerator(config),
	}
}

// Generate2FASecret creates a new secret for a user
func (m *MemoryTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
	secret, err := m.totpGen.GenerateSecret()
	if err != nil {
		return nil, err
	}

	qrURL := m.totpGen.GenerateQRCodeURL(secret, issuer, accountName)

	backupCodes, err := GenerateBackupCodes(10)
	if err != nil {
		return nil, err
	}

	return &TwoFactorSecret{
		Secret:      secret,
		QRCodeURL:   qrURL,
		BackupCodes: backupCodes,
		Issuer:      issuer,
		AccountName: accountName,
	}, nil
}

// Validate2FACode verifies a TOTP code
func (m *MemoryTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
	return m.totpGen.ValidateCode(secret, code)
}

// Enable2FA activates 2FA for a user
func (m *MemoryTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.secrets[userID] = secret

	// Store backup codes
	if m.backupCodes[userID] == nil {
		m.backupCodes[userID] = make(map[string]bool)
	}

	for _, code := range backupCodes {
		// Hash backup codes for security
		hash := sha256.Sum256([]byte(code))
		m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
	}

	return nil
}

// Disable2FA deactivates 2FA for a user
func (m *MemoryTwoFactorProvider) Disable2FA(userID int) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	delete(m.secrets, userID)
	delete(m.backupCodes, userID)
	return nil
}

// Get2FAStatus checks if user has 2FA enabled
func (m *MemoryTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	_, exists := m.secrets[userID]
	return exists, nil
}

// Get2FASecret retrieves the user's 2FA secret
func (m *MemoryTwoFactorProvider) Get2FASecret(userID int) (string, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	secret, exists := m.secrets[userID]
	if !exists {
		return "", fmt.Errorf("user does not have 2FA enabled")
	}
	return secret, nil
}

// GenerateBackupCodes creates backup codes for 2FA
func (m *MemoryTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
	codes, err := GenerateBackupCodes(count)
	if err != nil {
		return nil, err
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	// Clear old backup codes and store new ones
	m.backupCodes[userID] = make(map[string]bool)
	for _, code := range codes {
		hash := sha256.Sum256([]byte(code))
		m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
	}

	return codes, nil
}

// ValidateBackupCode checks and consumes a backup code
func (m *MemoryTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	userCodes, exists := m.backupCodes[userID]
	if !exists {
		return false, nil
	}

	// Hash the provided code
	hash := sha256.Sum256([]byte(code))
	hashStr := hex.EncodeToString(hash[:])

	used, exists := userCodes[hashStr]
	if !exists {
		return false, nil
	}

	if used {
		return false, fmt.Errorf("backup code already used")
	}

	// Mark as used
	userCodes[hashStr] = true
	return true, nil
}
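MemoryTwoFactorProvider guards its two maps with a sync.RWMutex: writes (Enable2FA, Disable2FA, ValidateBackupCode) take the exclusive lock, while pure reads (Get2FAStatus, Get2FASecret) share the read lock. A minimal standalone sketch of that concurrency pattern, using a hypothetical secretStore rather than the package's actual types:

```go
package main

import (
	"fmt"
	"sync"
)

// secretStore is a tiny stand-in for the provider's userID -> secret map.
type secretStore struct {
	mu      sync.RWMutex
	secrets map[int]string
}

func (s *secretStore) enable(userID int, secret string) {
	s.mu.Lock() // exclusive: mutates the map
	defer s.mu.Unlock()
	s.secrets[userID] = secret
}

func (s *secretStore) enabled(userID int) bool {
	s.mu.RLock() // shared: read-only lookup
	defer s.mu.RUnlock()
	_, ok := s.secrets[userID]
	return ok
}

func (s *secretStore) disable(userID int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.secrets, userID)
}

func main() {
	store := &secretStore{secrets: make(map[int]string)}

	// Concurrent writers are serialized by the exclusive lock.
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			store.enable(id, "JBSWY3DPEHPK3PXP")
		}(i)
	}
	wg.Wait()

	fmt.Println(store.enabled(3)) // true
	store.disable(3)
	fmt.Println(store.enabled(3)) // false
}
```

Note that ValidateBackupCode takes the write lock even though it starts as a lookup, because consuming a code mutates the map; upgrading an RLock to a Lock mid-call is not possible with sync.RWMutex.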
pkg/security/totp_test.go (new file, 292 lines)
@@ -0,0 +1,292 @@
package security

import (
	"strings"
	"testing"
	"time"
)

func TestTOTPGenerator_GenerateSecret(t *testing.T) {
	totp := NewTOTPGenerator(nil)

	secret, err := totp.GenerateSecret()
	if err != nil {
		t.Fatalf("GenerateSecret() error = %v", err)
	}

	if secret == "" {
		t.Error("GenerateSecret() returned empty secret")
	}

	// Secret should be base32 encoded
	if len(secret) < 16 {
		t.Error("GenerateSecret() returned secret that is too short")
	}
}

func TestTOTPGenerator_GenerateQRCodeURL(t *testing.T) {
	totp := NewTOTPGenerator(nil)

	secret := "JBSWY3DPEHPK3PXP"
	issuer := "TestApp"
	accountName := "user@example.com"

	url := totp.GenerateQRCodeURL(secret, issuer, accountName)

	if !strings.HasPrefix(url, "otpauth://totp/") {
		t.Errorf("GenerateQRCodeURL() = %v, want otpauth://totp/ prefix", url)
	}

	if !strings.Contains(url, "secret="+secret) {
		t.Errorf("GenerateQRCodeURL() missing secret parameter")
	}

	if !strings.Contains(url, "issuer="+issuer) {
		t.Errorf("GenerateQRCodeURL() missing issuer parameter")
	}
}

func TestTOTPGenerator_GenerateCode(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     6,
		Period:     30,
		SkewWindow: 1,
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	// Test with known time
	timestamp := time.Unix(1234567890, 0)
	code, err := totp.GenerateCode(secret, timestamp)
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	if len(code) != 6 {
		t.Errorf("GenerateCode() returned code with length %d, want 6", len(code))
	}

	// Code should be numeric
	for _, c := range code {
		if c < '0' || c > '9' {
			t.Errorf("GenerateCode() returned non-numeric code: %s", code)
			break
		}
	}
}

func TestTOTPGenerator_ValidateCode(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     6,
		Period:     30,
		SkewWindow: 1,
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	// Generate a code for current time
	now := time.Now()
	code, err := totp.GenerateCode(secret, now)
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	// Validate the code
	valid, err := totp.ValidateCode(secret, code)
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateCode() = false, want true for current code")
	}

	// Test with invalid code
	valid, err = totp.ValidateCode(secret, "000000")
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	// This might occasionally pass if 000000 is the correct code, but very unlikely
	if valid && code != "000000" {
		t.Error("ValidateCode() = true for invalid code")
	}
}

func TestTOTPGenerator_ValidateCode_WithSkew(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     6,
		Period:     30,
		SkewWindow: 2, // Allow 2 periods before/after
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	// Generate code for 1 period ago
	past := time.Now().Add(-30 * time.Second)
	code, err := totp.GenerateCode(secret, past)
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	// Should still validate with skew window
	valid, err := totp.ValidateCode(secret, code)
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateCode() = false, want true for code within skew window")
	}
}

func TestTOTPGenerator_DifferentAlgorithms(t *testing.T) {
	algorithms := []string{"SHA1", "SHA256", "SHA512"}
	secret := "JBSWY3DPEHPK3PXP"

	for _, algo := range algorithms {
		t.Run(algo, func(t *testing.T) {
			config := &TwoFactorConfig{
				Algorithm:  algo,
				Digits:     6,
				Period:     30,
				SkewWindow: 1,
			}
			totp := NewTOTPGenerator(config)

			code, err := totp.GenerateCode(secret, time.Now())
			if err != nil {
				t.Fatalf("GenerateCode() with %s error = %v", algo, err)
			}

			valid, err := totp.ValidateCode(secret, code)
			if err != nil {
				t.Fatalf("ValidateCode() with %s error = %v", algo, err)
			}

			if !valid {
				t.Errorf("ValidateCode() with %s = false, want true", algo)
			}
		})
	}
}

func TestTOTPGenerator_8Digits(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     8,
		Period:     30,
		SkewWindow: 1,
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	code, err := totp.GenerateCode(secret, time.Now())
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	if len(code) != 8 {
		t.Errorf("GenerateCode() returned code with length %d, want 8", len(code))
	}

	valid, err := totp.ValidateCode(secret, code)
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateCode() = false, want true for 8-digit code")
	}
}

func TestGenerateBackupCodes(t *testing.T) {
	count := 10
	codes, err := GenerateBackupCodes(count)
	if err != nil {
		t.Fatalf("GenerateBackupCodes() error = %v", err)
	}

	if len(codes) != count {
		t.Errorf("GenerateBackupCodes() returned %d codes, want %d", len(codes), count)
	}

	// Check uniqueness
	seen := make(map[string]bool)
	for _, code := range codes {
		if seen[code] {
			t.Errorf("GenerateBackupCodes() generated duplicate code: %s", code)
		}
		seen[code] = true

		// Check format (8 hex characters)
		if len(code) != 8 {
			t.Errorf("GenerateBackupCodes() code length = %d, want 8", len(code))
		}
	}
}

func TestDefaultTwoFactorConfig(t *testing.T) {
	config := DefaultTwoFactorConfig()

	if config.Algorithm != "SHA1" {
		t.Errorf("DefaultTwoFactorConfig() Algorithm = %s, want SHA1", config.Algorithm)
	}

	if config.Digits != 6 {
		t.Errorf("DefaultTwoFactorConfig() Digits = %d, want 6", config.Digits)
	}

	if config.Period != 30 {
		t.Errorf("DefaultTwoFactorConfig() Period = %d, want 30", config.Period)
	}

	if config.SkewWindow != 1 {
		t.Errorf("DefaultTwoFactorConfig() SkewWindow = %d, want 1", config.SkewWindow)
	}
}

func TestTOTPGenerator_InvalidSecret(t *testing.T) {
	totp := NewTOTPGenerator(nil)

	// Test with invalid base32 secret
	_, err := totp.GenerateCode("INVALID!!!", time.Now())
	if err == nil {
		t.Error("GenerateCode() with invalid secret should return error")
	}

	_, err = totp.ValidateCode("INVALID!!!", "123456")
	if err == nil {
		t.Error("ValidateCode() with invalid secret should return error")
	}
}

// Benchmark tests
func BenchmarkTOTPGenerator_GenerateCode(b *testing.B) {
	totp := NewTOTPGenerator(nil)
	secret := "JBSWY3DPEHPK3PXP"
	now := time.Now()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = totp.GenerateCode(secret, now)
	}
}

func BenchmarkTOTPGenerator_ValidateCode(b *testing.B) {
	totp := NewTOTPGenerator(nil)
	secret := "JBSWY3DPEHPK3PXP"
	code, _ := totp.GenerateCode(secret, time.Now())

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = totp.ValidateCode(secret, code)
	}
}
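The tests above pin code length, numeric format, and round-tripping, but not the underlying algorithm. For reference, the scheme these tests exercise is RFC 6238 TOTP, which is RFC 4226 HOTP with the counter set to unixTime / period. A sketch of that core using only the standard library; this is an illustrative re-derivation against the RFC 4226 test key, not the package's actual TOTPGenerator:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/binary"
	"fmt"
)

// hotp computes an RFC 4226 HOTP value for the given key, counter and digit count.
func hotp(key []byte, counter uint64, digits int) string {
	mac := hmac.New(sha1.New, key)
	var buf [8]byte
	binary.BigEndian.PutUint64(buf[:], counter)
	mac.Write(buf[:])
	sum := mac.Sum(nil)

	// Dynamic truncation (RFC 4226 section 5.3): low nibble of the last
	// byte picks a 4-byte window; mask the sign bit.
	off := sum[len(sum)-1] & 0x0f
	code := binary.BigEndian.Uint32(sum[off:off+4]) & 0x7fffffff

	mod := uint32(1)
	for i := 0; i < digits; i++ {
		mod *= 10
	}
	return fmt.Sprintf("%0*d", digits, code%mod)
}

func main() {
	key := []byte("12345678901234567890") // RFC 4226 Appendix D test key
	fmt.Println(hotp(key, 0, 6)) // 755224
	fmt.Println(hotp(key, 1, 6)) // 287082
	// TOTP is just HOTP with counter = unixTime / period:
	fmt.Println(hotp(key, uint64(59)/30, 6)) // time 59s, period 30 -> counter 1
}
```

The SkewWindow the tests exercise corresponds to also accepting counters T-1 .. T+1 (or wider), which is why a code generated one period in the past still validates.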
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -209,10 +210,14 @@ func (h *Handler) handleRead(conn *Connection, msg *Message, hookCtx *HookContex
|
||||
var metadata map[string]interface{}
|
||||
var err error
|
||||
|
||||
if hookCtx.ID != "" {
|
||||
// Read single record by ID
|
||||
// Check if FetchRowNumber is specified (treat as single record read)
|
||||
isFetchRowNumber := hookCtx.Options != nil && hookCtx.Options.FetchRowNumber != nil && *hookCtx.Options.FetchRowNumber != ""
|
||||
|
||||
if hookCtx.ID != "" || isFetchRowNumber {
|
||||
// Read single record by ID or FetchRowNumber
|
||||
data, err = h.readByID(hookCtx)
|
||||
metadata = map[string]interface{}{"total": 1}
|
||||
// The row number is already set on the record itself via setRowNumbersOnRecords
|
||||
} else {
|
||||
// Read multiple records
|
||||
data, metadata, err = h.readMultiple(hookCtx)
|
||||
@@ -509,10 +514,29 @@ func (h *Handler) notifySubscribers(schema, entity string, operation OperationTy
|
||||
// CRUD operation implementations
|
||||
|
||||
func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
|
||||
// Handle FetchRowNumber before building query
|
||||
var fetchedRowNumber *int64
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
|
||||
if hookCtx.Options != nil && hookCtx.Options.FetchRowNumber != nil && *hookCtx.Options.FetchRowNumber != "" {
|
||||
fetchRowNumberPKValue := *hookCtx.Options.FetchRowNumber
|
||||
logger.Debug("[WebSocketSpec] FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
|
||||
|
||||
rowNum, err := h.FetchRowNumber(hookCtx.Context, hookCtx.TableName, pkName, fetchRowNumberPKValue, hookCtx.Options, hookCtx.Model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch row number: %w", err)
|
||||
}
|
||||
|
||||
fetchedRowNumber = &rowNum
|
||||
logger.Debug("[WebSocketSpec] FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
|
||||
|
||||
// Override ID with FetchRowNumber value
|
||||
hookCtx.ID = fetchRowNumberPKValue
|
||||
}
|
||||
|
||||
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
// Apply columns
|
||||
@@ -532,6 +556,12 @@ func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to read record: %w", err)
|
||||
}
|
||||
|
||||
// Set the fetched row number on the record if FetchRowNumber was used
|
||||
if fetchedRowNumber != nil {
|
||||
logger.Debug("[WebSocketSpec] FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
|
||||
h.setRowNumbersOnRecords(hookCtx.ModelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
@@ -540,10 +570,8 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
|
||||
// Apply options (simplified implementation)
|
||||
if hookCtx.Options != nil {
|
||||
// Apply filters
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
// Apply filters with OR grouping support
|
||||
query = h.applyFilters(query, hookCtx.Options.Filters)
|
||||
|
||||
// Apply sorting
|
||||
for _, sort := range hookCtx.Options.Sort {
|
||||
@@ -578,6 +606,13 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
return nil, nil, fmt.Errorf("failed to read records: %w", err)
|
||||
}
|
||||
|
||||
// Set row numbers on records if RowNumber field exists
|
||||
offset := 0
|
||||
if hookCtx.Options != nil && hookCtx.Options.Offset != nil {
|
||||
offset = *hookCtx.Options.Offset
|
||||
}
|
||||
h.setRowNumbersOnRecords(hookCtx.ModelPtr, offset)
|
||||
|
||||
// Get count
|
||||
metadata = make(map[string]interface{})
|
||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
@@ -656,11 +691,14 @@ func (h *Handler) delete(hookCtx *HookContext) error {
|
||||
// Helper methods
|
||||
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
// Use entity as table name
|
||||
tableName := entity
|
||||
|
||||
if schema != "" {
|
||||
tableName = schema + "." + tableName
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
tableName = schema + "_" + tableName
|
||||
} else {
|
||||
tableName = schema + "." + tableName
|
||||
}
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
@@ -680,6 +718,133 @@ func (h *Handler) getMetadata(schema, entity string, model interface{}) map[stri
|
||||
}
|
||||
|
||||
// getOperatorSQL converts filter operator to SQL operator
|
||||
// applyFilters applies all filters with proper grouping for OR logic
|
||||
// Groups consecutive OR filters together to ensure proper query precedence
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
// Check if this starts an OR group (next filter has OR logic)
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
// Collect all consecutive filters that are OR'd together
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
|
||||
// Apply the OR group as a single grouped WHERE clause
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
// Single filter with AND logic (or first filter)
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}

// applyFilterGroup applies a group of filters that should be OR'd together.
// It always wraps them in parentheses and applies them as a single WHERE clause.
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
	if len(filters) == 0 {
		return query
	}

	// Build all conditions and collect their args
	var conditions []string
	var args []interface{}

	for _, filter := range filters {
		condition, filterArgs := h.buildFilterCondition(filter)
		if condition != "" {
			conditions = append(conditions, condition)
			args = append(args, filterArgs...)
		}
	}

	if len(conditions) == 0 {
		return query
	}

	// Single condition - no need for grouping
	if len(conditions) == 1 {
		return query.Where(conditions[0], args...)
	}

	// Multiple conditions - group with parentheses and OR
	groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
	return query.Where(groupedCondition, args...)
}

// buildFilterCondition builds a single filter condition and returns it with its args.
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
	operatorSQL := h.getOperatorSQL(filter.Operator)
	condition := fmt.Sprintf("%s %s ?", filter.Column, operatorSQL)
	args := []interface{}{filter.Value}

	return condition, args
}
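As a standalone illustration of the grouping rule above (a single condition is applied as-is, while multiple conditions are wrapped in one parenthesised OR clause), here is a minimal sketch; `groupOr` and its inputs are illustrative names, not the handler's real types:

```go
package main

import (
	"fmt"
	"strings"
)

// groupOr mirrors the grouping behaviour of applyFilterGroup: one condition
// passes through unchanged, several are joined with OR inside parentheses.
func groupOr(conditions []string) string {
	if len(conditions) == 1 {
		return conditions[0]
	}
	return "(" + strings.Join(conditions, " OR ") + ")"
}

func main() {
	fmt.Println(groupOr([]string{"status = ?"}))              // status = ?
	fmt.Println(groupOr([]string{"status = ?", "role = ?"})) // (status = ? OR role = ?)
}
```

The parentheses matter: without them, `a = ? OR b = ?` combined with a later AND filter would change operator precedence in the final WHERE clause.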

// setRowNumbersOnRecords sets the RowNumber field on each record if it exists.
// The row number is calculated as offset + index + 1 (1-based).
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
	// Get the reflect value of the records
	recordsValue := reflect.ValueOf(records)
	if recordsValue.Kind() == reflect.Ptr {
		recordsValue = recordsValue.Elem()
	}

	// Ensure it's a slice
	if recordsValue.Kind() != reflect.Slice {
		logger.Debug("[WebSocketSpec] setRowNumbersOnRecords: records is not a slice, skipping")
		return
	}

	// Iterate through each record
	for i := 0; i < recordsValue.Len(); i++ {
		record := recordsValue.Index(i)

		// Dereference if it's a pointer
		if record.Kind() == reflect.Ptr {
			if record.IsNil() {
				continue
			}
			record = record.Elem()
		}

		// Ensure it's a struct
		if record.Kind() != reflect.Struct {
			continue
		}

		// Try to find and set the RowNumber field
		rowNumberField := record.FieldByName("RowNumber")
		if rowNumberField.IsValid() && rowNumberField.CanSet() {
			// Only set the field if it is of type int64
			if rowNumberField.Kind() == reflect.Int64 {
				rowNum := int64(offset + i + 1)
				rowNumberField.SetInt(rowNum)
				logger.Debug("[WebSocketSpec] Set RowNumber=%d for record index %d", rowNum, i)
			}
		}
	}
}
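The reflection pattern above can be exercised in isolation. A self-contained sketch (the `row` struct is a hypothetical stand-in for a real model): slice elements obtained through `reflect` are addressable via the slice's backing array, so the field can be set even though the slice header itself is passed by value:

```go
package main

import (
	"fmt"
	"reflect"
)

// row is an illustrative record type with the int64 RowNumber field
// the handler looks for by name.
type row struct {
	Name      string
	RowNumber int64
}

// setRowNumbers walks a slice of structs and sets RowNumber = offset + i + 1.
func setRowNumbers(records interface{}, offset int) {
	v := reflect.ValueOf(records)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Slice {
		return
	}
	for i := 0; i < v.Len(); i++ {
		f := v.Index(i).FieldByName("RowNumber")
		if f.IsValid() && f.CanSet() && f.Kind() == reflect.Int64 {
			f.SetInt(int64(offset + i + 1))
		}
	}
}

func main() {
	rows := []row{{Name: "a"}, {Name: "b"}}
	setRowNumbers(rows, 10)
	fmt.Println(rows[0].RowNumber, rows[1].RowNumber) // 11 12
}
```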

func (h *Handler) getOperatorSQL(operator string) string {
	switch operator {
	case "eq":
@@ -705,6 +870,92 @@ func (h *Handler) getOperatorSQL(operator string) string {
	}
}
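The body of that switch is elided by the diff hunk above. As a rough sketch of the kind of mapping it performs, here is an illustrative subset only (the real handler supports more operators such as `ilike`, `in`, and `between`, and its fallback behaviour may differ):

```go
package main

import "fmt"

// operatorSQL maps a small, illustrative subset of filter operator names
// to SQL comparison operators. This is a sketch, not the handler's code.
func operatorSQL(op string) string {
	switch op {
	case "eq":
		return "="
	case "neq":
		return "!="
	case "gt":
		return ">"
	case "gte":
		return ">="
	case "lt":
		return "<"
	case "lte":
		return "<="
	case "like":
		return "LIKE"
	default:
		return "=" // assumed fallback for the sketch
	}
}

func main() {
	fmt.Println(operatorSQL("gte")) // >=
}
```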

// FetchRowNumber calculates the row number of a specific record based on sorting and filtering.
// It returns the 1-based row number of the record with the given primary key value.
func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName string, pkValue string, options *common.RequestOptions, model interface{}) (int64, error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error("[WebSocketSpec] Panic during FetchRowNumber: %v", r)
		}
	}()

	// Build the sort order SQL
	sortSQL := ""
	if options != nil && len(options.Sort) > 0 {
		sortParts := make([]string, 0, len(options.Sort))
		for _, sort := range options.Sort {
			if sort.Column == "" {
				continue
			}
			direction := "ASC"
			if strings.EqualFold(sort.Direction, "desc") {
				direction = "DESC"
			}
			sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
		}
		sortSQL = strings.Join(sortParts, ", ")
	} else {
		// Default sort by primary key
		sortSQL = fmt.Sprintf("%s ASC", pkName)
	}

	// Build WHERE clause from filters
	whereSQL := ""
	var whereArgs []interface{}
	if options != nil && len(options.Filters) > 0 {
		var conditions []string
		for _, filter := range options.Filters {
			operatorSQL := h.getOperatorSQL(filter.Operator)
			conditions = append(conditions, fmt.Sprintf("%s.%s %s ?", tableName, filter.Column, operatorSQL))
			whereArgs = append(whereArgs, filter.Value)
		}
		if len(conditions) > 0 {
			whereSQL = "WHERE " + strings.Join(conditions, " AND ")
		}
	}

	// Build the final query with a parameterized PK value
	queryStr := fmt.Sprintf(`
		SELECT search.rn
		FROM (
			SELECT %[1]s.%[2]s,
			       ROW_NUMBER() OVER(ORDER BY %[3]s) AS rn
			FROM %[1]s
			%[4]s
		) search
		WHERE search.%[2]s = ?
	`,
		tableName, // [1] - table name
		pkName,    // [2] - primary key column name
		sortSQL,   // [3] - sort order SQL
		whereSQL,  // [4] - WHERE clause
	)

	logger.Debug("[WebSocketSpec] FetchRowNumber query: %s, pkValue: %s", queryStr, pkValue)

	// Append the PK value to whereArgs
	whereArgs = append(whereArgs, pkValue)

	// Execute the raw query with the parameterized PK value
	var result []struct {
		RN int64 `bun:"rn"`
	}
	err := h.db.Query(ctx, &result, queryStr, whereArgs...)
	if err != nil {
		return 0, fmt.Errorf("failed to fetch row number: %w", err)
	}

	if len(result) == 0 {
		whereInfo := "none"
		if whereSQL != "" {
			whereInfo = whereSQL
		}
		return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
	}

	return result[0].RN, nil
}

// Shutdown gracefully shuts down the handler.
func (h *Handler) Shutdown() {
	h.connManager.Shutdown()

@@ -82,6 +82,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
	return args.Get(0)
}

func (m *MockDatabase) DriverName() string {
	return "postgres"
}

// MockSelectQuery is a mock implementation of common.SelectQuery
type MockSelectQuery struct {
	mock.Mock
8	resolvespec-js/.changeset/README.md	Normal file
@@ -0,0 +1,8 @@
# Changesets

Hello and welcome! This folder has been automatically generated by `@changesets/cli`, a build tool that works
with multi-package repos, or single-package repos, to help you version and publish your code. You can
find the full documentation for it [in our repository](https://github.com/changesets/changesets).

We have a quick list of common questions to get you started engaging with this project in
[our documentation](https://github.com/changesets/changesets/blob/main/docs/common-questions.md).
11	resolvespec-js/.changeset/config.json	Normal file
@@ -0,0 +1,11 @@
{
  "$schema": "https://unpkg.com/@changesets/config@3.1.2/schema.json",
  "changelog": "@changesets/cli/changelog",
  "commit": false,
  "fixed": [],
  "linked": [],
  "access": "restricted",
  "baseBranch": "main",
  "updateInternalDependencies": "patch",
  "ignore": []
}
7	resolvespec-js/CHANGELOG.md	Normal file
@@ -0,0 +1,7 @@
# @warkypublic/resolvespec-js

## 1.0.1

### Patch Changes

- Fixed HeaderSpec
132	resolvespec-js/PLAN.md	Normal file
@@ -0,0 +1,132 @@
# ResolveSpec JS - Implementation Plan

TypeScript client library for the ResolveSpec, RestHeaderSpec, WebSocket, and MQTT APIs.

---

## Status

| Phase | Description | Status |
|-------|-------------|--------|
| 0 | Restructure into folders | Done |
| 1 | Fix types (align with Go) | Done |
| 2 | Fix REST client | Done |
| 3 | Build config | Done |
| 4 | Tests | Done |
| 5 | HeaderSpec client | Done |
| 6 | MQTT client | Planned |
| 6.5 | Unified class pattern + singleton factories | Done |
| 7 | Response cache (TTL) | Planned |
| 8 | TanStack Query integration | Planned |
| 9 | React Hooks | Planned |

**Build:** `dist/index.js` (ES) + `dist/index.cjs` (CJS) + `.d.ts` declarations
**Tests:** 65 passing (common: 10, resolvespec: 13, websocketspec: 15, headerspec: 27)

---

## Folder Structure

```
src/
├── common/
│   ├── types.ts          # Core types aligned with Go pkg/common/types.go
│   └── index.ts
├── resolvespec/
│   ├── client.ts         # ResolveSpecClient class + createResolveSpecClient singleton
│   └── index.ts
├── headerspec/
│   ├── client.ts         # HeaderSpecClient class + createHeaderSpecClient singleton + buildHeaders utility
│   └── index.ts
├── websocketspec/
│   ├── types.ts          # WS-specific types (WSMessage, WSOptions, etc.)
│   ├── client.ts         # WebSocketClient class + createWebSocketClient singleton
│   └── index.ts
├── mqttspec/             # Future
│   ├── types.ts
│   ├── client.ts
│   └── index.ts
├── __tests__/
│   ├── common.test.ts
│   ├── resolvespec.test.ts
│   ├── headerspec.test.ts
│   └── websocketspec.test.ts
└── index.ts              # Root barrel export
```

---

## Type Alignment with Go

Types in `src/common/types.ts` match `pkg/common/types.go`:

- **Operator**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, `contains`, `startswith`, `endswith`, `between`, `between_inclusive`, `is_null`, `is_not_null`
- **FilterOption**: `column`, `operator`, `value`, `logic_operator` (AND/OR)
- **Options**: `columns`, `omit_columns`, `filters`, `sort`, `limit`, `offset`, `preload`, `customOperators`, `computedColumns`, `parameters`, `cursor_forward`, `cursor_backward`, `fetch_row_number`
- **PreloadOption**: `relation`, `table_name`, `columns`, `omit_columns`, `sort`, `filters`, `where`, `limit`, `offset`, `updatable`, `recursive`, `computed_ql`, `primary_key`, `related_key`, `foreign_key`, `recursive_child_key`, `sql_joins`, `join_aliases`
- **Parameter**: `name`, `value`, `sequence?`
- **Metadata**: `total`, `count`, `filtered`, `limit`, `offset`, `row_number?`
- **APIError**: `code`, `message`, `details?`, `detail?`

---

## HeaderSpec Header Mapping

Maps Options to HTTP headers per Go `restheadspec/headers.go`:

| Header | Options field | Format |
|--------|--------------|--------|
| `X-Select-Fields` | `columns` | comma-separated |
| `X-Not-Select-Fields` | `omit_columns` | comma-separated |
| `X-FieldFilter-{col}` | `filters` (eq, AND) | value |
| `X-SearchOp-{op}-{col}` | `filters` (AND) | value |
| `X-SearchOr-{op}-{col}` | `filters` (OR) | value |
| `X-Sort` | `sort` | `+col` (asc), `-col` (desc) |
| `X-Limit` | `limit` | number |
| `X-Offset` | `offset` | number |
| `X-Cursor-Forward` | `cursor_forward` | string |
| `X-Cursor-Backward` | `cursor_backward` | string |
| `X-Preload` | `preload` | `Rel:col1,col2` pipe-separated |
| `X-Fetch-RowNumber` | `fetch_row_number` | string |
| `X-CQL-SEL-{col}` | `computedColumns` | expression |
| `X-Custom-SQL-W` | `customOperators` | SQL AND-joined |

Complex values use `ZIP_` + base64 encoding.
HTTP methods: GET=read, POST=create, PUT=update, DELETE=delete.

---

## Build & Test

```bash
pnpm install
pnpm run build   # vite library mode → dist/
pnpm run test    # vitest
pnpm run lint    # eslint
```

**Config files:** `tsconfig.json` (ES2020, strict, bundler), `vite.config.ts` (lib mode, dts via vite-plugin-dts)
**Externals:** `uuid`, `semver`

---

## Remaining Work

- **Phase 6 — MQTT Client**: Topic-based CRUD over MQTT (optional/future)
- **Phase 7 — Cache**: In-memory response cache with TTL, key = URL + options hash, auto-invalidation on CUD, `skipCache` flag
- **Phase 8 — TanStack Query Integration**: Query/mutation hooks wrapping each client, query key factories, automatic cache invalidation
- **Phase 9 — React Hooks**: `useResolveSpec`, `useHeaderSpec`, `useWebSocket` hooks with provider context, loading/error states
- ESLint config may need updating for the new folder structure

---

## Reference Files

| Purpose | Path |
|---------|------|
| Go types (source of truth) | `pkg/common/types.go` |
| Go REST handler | `pkg/resolvespec/handler.go` |
| Go HeaderSpec handler | `pkg/restheadspec/handler.go` |
| Go HeaderSpec header parsing | `pkg/restheadspec/headers.go` |
| Go test models | `pkg/testmodels/business.go` |
| Go tests | `tests/crud_test.go` |

213	resolvespec-js/README.md	Normal file
@@ -0,0 +1,213 @@
# ResolveSpec JS

TypeScript client library for ResolveSpec APIs. Supports body-based REST, header-based REST, and WebSocket protocols.

## Install

```bash
pnpm add @warkypublic/resolvespec-js
```

## Clients

| Client | Protocol | Singleton Factory |
| --- | --- | --- |
| `ResolveSpecClient` | REST (body-based) | `getResolveSpecClient(config)` |
| `HeaderSpecClient` | REST (header-based) | `getHeaderSpecClient(config)` |
| `WebSocketClient` | WebSocket | `getWebSocketClient(config)` |

All clients use the class pattern. Singleton factories return cached instances keyed by URL.

## REST Client (Body-Based)

Options are sent in the JSON request body. Maps to Go `pkg/resolvespec`.

```typescript
import { ResolveSpecClient, getResolveSpecClient } from '@warkypublic/resolvespec-js';

// Class instantiation
const client = new ResolveSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });

// Or singleton factory (returns cached instance per baseUrl)
const client = getResolveSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });

// Read with filters, sort, pagination
const result = await client.read('public', 'users', undefined, {
  columns: ['id', 'name', 'email'],
  filters: [{ column: 'status', operator: 'eq', value: 'active' }],
  sort: [{ column: 'name', direction: 'asc' }],
  limit: 10,
  offset: 0,
  preload: [{ relation: 'Posts', columns: ['id', 'title'] }],
});

// Read by ID
const user = await client.read('public', 'users', 42);

// Create
const created = await client.create('public', 'users', { name: 'New User' });

// Update
await client.update('public', 'users', { name: 'Updated' }, 42);

// Delete
await client.delete('public', 'users', 42);

// Metadata
const meta = await client.getMetadata('public', 'users');
```

## HeaderSpec Client (Header-Based)

Options are sent via HTTP headers. Maps to Go `pkg/restheadspec`.

```typescript
import { HeaderSpecClient, getHeaderSpecClient } from '@warkypublic/resolvespec-js';

const client = new HeaderSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
// Or: const client = getHeaderSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });

// GET with options as headers
const result = await client.read('public', 'users', undefined, {
  columns: ['id', 'name'],
  filters: [
    { column: 'status', operator: 'eq', value: 'active' },
    { column: 'age', operator: 'gte', value: 18, logic_operator: 'AND' },
  ],
  sort: [{ column: 'name', direction: 'asc' }],
  limit: 50,
  preload: [{ relation: 'Department', columns: ['id', 'name'] }],
});

// POST create
await client.create('public', 'users', { name: 'New User' });

// PUT update
await client.update('public', 'users', '42', { name: 'Updated' });

// DELETE
await client.delete('public', 'users', '42');
```

### Header Mapping

| Header | Options Field | Format |
| --- | --- | --- |
| `X-Select-Fields` | `columns` | comma-separated |
| `X-Not-Select-Fields` | `omit_columns` | comma-separated |
| `X-FieldFilter-{col}` | `filters` (eq, AND) | value |
| `X-SearchOp-{op}-{col}` | `filters` (AND) | value |
| `X-SearchOr-{op}-{col}` | `filters` (OR) | value |
| `X-Sort` | `sort` | `+col` asc, `-col` desc |
| `X-Limit` / `X-Offset` | `limit` / `offset` | number |
| `X-Cursor-Forward` | `cursor_forward` | string |
| `X-Cursor-Backward` | `cursor_backward` | string |
| `X-Preload` | `preload` | `Rel:col1,col2` pipe-separated |
| `X-Fetch-RowNumber` | `fetch_row_number` | string |
| `X-CQL-SEL-{col}` | `computedColumns` | expression |
| `X-Custom-SQL-W` | `customOperators` | SQL AND-joined |

### Utility Functions

```typescript
import { buildHeaders, encodeHeaderValue, decodeHeaderValue } from '@warkypublic/resolvespec-js';

const headers = buildHeaders({ columns: ['id', 'name'], limit: 10 });
// => { 'X-Select-Fields': 'id,name', 'X-Limit': '10' }

const encoded = encodeHeaderValue('complex value'); // 'ZIP_...'
const decoded = decodeHeaderValue(encoded); // 'complex value'
```

## WebSocket Client

Real-time CRUD with subscriptions. Maps to Go `pkg/websocketspec`.

```typescript
import { WebSocketClient, getWebSocketClient } from '@warkypublic/resolvespec-js';

const ws = new WebSocketClient({
  url: 'ws://localhost:8080/ws',
  reconnect: true,
  heartbeatInterval: 30000,
});
// Or: const ws = getWebSocketClient({ url: 'ws://localhost:8080/ws' });

await ws.connect();

// CRUD
const users = await ws.read('users', { schema: 'public', limit: 10 });
const created = await ws.create('users', { name: 'New' }, { schema: 'public' });
await ws.update('users', '1', { name: 'Updated' });
await ws.delete('users', '1');

// Subscribe to changes
const subId = await ws.subscribe('users', (notification) => {
  console.log(notification.operation, notification.data);
});

// Unsubscribe
await ws.unsubscribe(subId);

// Events
ws.on('connect', () => console.log('connected'));
ws.on('disconnect', () => console.log('disconnected'));
ws.on('error', (err) => console.error(err));

ws.disconnect();
```

## Types

All types align with Go `pkg/common/types.go`.

### Key Types

```typescript
interface Options {
  columns?: string[];
  omit_columns?: string[];
  filters?: FilterOption[];
  sort?: SortOption[];
  limit?: number;
  offset?: number;
  preload?: PreloadOption[];
  customOperators?: CustomOperator[];
  computedColumns?: ComputedColumn[];
  parameters?: Parameter[];
  cursor_forward?: string;
  cursor_backward?: string;
  fetch_row_number?: string;
}

interface FilterOption {
  column: string;
  operator: Operator | string;
  value: any;
  logic_operator?: 'AND' | 'OR';
}

// Operators: eq, neq, gt, gte, lt, lte, like, ilike, in,
//            contains, startswith, endswith, between,
//            between_inclusive, is_null, is_not_null

interface APIResponse<T> {
  success: boolean;
  data: T;
  metadata?: Metadata;
  error?: APIError;
}
```

## Build

```bash
pnpm install
pnpm run build   # dist/index.js (ES) + dist/index.cjs (CJS) + .d.ts
pnpm run test    # vitest
pnpm run lint    # eslint
```

## License

MIT
@@ -1,530 +0,0 @@
|
||||
# WebSocketSpec JavaScript Client
|
||||
|
||||
A TypeScript/JavaScript client for connecting to WebSocketSpec servers with full support for real-time subscriptions, CRUD operations, and automatic reconnection.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install @warkypublic/resolvespec-js
|
||||
# or
|
||||
yarn add @warkypublic/resolvespec-js
|
||||
# or
|
||||
pnpm add @warkypublic/resolvespec-js
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import { WebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
// Create client
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
debug: true
|
||||
});
|
||||
|
||||
// Connect
|
||||
await client.connect();
|
||||
|
||||
// Read records
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
limit: 10
|
||||
});
|
||||
|
||||
// Subscribe to changes
|
||||
const subscriptionId = await client.subscribe('users', (notification) => {
|
||||
console.log('User changed:', notification.operation, notification.data);
|
||||
}, { schema: 'public' });
|
||||
|
||||
// Clean up
|
||||
await client.unsubscribe(subscriptionId);
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **Real-Time Updates**: Subscribe to entity changes and receive instant notifications
|
||||
- **Full CRUD Support**: Create, read, update, and delete operations
|
||||
- **TypeScript Support**: Full type definitions included
|
||||
- **Auto Reconnection**: Automatic reconnection with configurable retry logic
|
||||
- **Heartbeat**: Built-in keepalive mechanism
|
||||
- **Event System**: Listen to connection, error, and message events
|
||||
- **Promise-based API**: All async operations return promises
|
||||
- **Filter & Sort**: Advanced querying with filters, sorting, and pagination
|
||||
- **Preloading**: Load related entities in a single query
|
||||
|
||||
## Configuration
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws', // WebSocket server URL
|
||||
reconnect: true, // Enable auto-reconnection
|
||||
reconnectInterval: 3000, // Reconnection delay (ms)
|
||||
maxReconnectAttempts: 10, // Max reconnection attempts
|
||||
heartbeatInterval: 30000, // Heartbeat interval (ms)
|
||||
debug: false // Enable debug logging
|
||||
});
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Connection Management
|
||||
|
||||
#### `connect(): Promise<void>`
|
||||
Connect to the WebSocket server.
|
||||
|
||||
```typescript
|
||||
await client.connect();
|
||||
```
|
||||
|
||||
#### `disconnect(): void`
|
||||
Disconnect from the server.
|
||||
|
||||
```typescript
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
#### `isConnected(): boolean`
|
||||
Check if currently connected.
|
||||
|
||||
```typescript
|
||||
if (client.isConnected()) {
|
||||
console.log('Connected!');
|
||||
}
|
||||
```
|
||||
|
||||
#### `getState(): ConnectionState`
|
||||
Get current connection state: `'connecting'`, `'connected'`, `'disconnecting'`, `'disconnected'`, or `'reconnecting'`.
|
||||
|
||||
```typescript
|
||||
const state = client.getState();
|
||||
console.log('State:', state);
|
||||
```
|
||||
|
||||
### CRUD Operations
|
||||
|
||||
#### `read<T>(entity: string, options?): Promise<T>`
|
||||
Read records from an entity.
|
||||
|
||||
```typescript
|
||||
// Read all active users
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
columns: ['id', 'name', 'email'],
|
||||
sort: [
|
||||
{ column: 'name', direction: 'asc' }
|
||||
],
|
||||
limit: 10,
|
||||
offset: 0
|
||||
});
|
||||
|
||||
// Read single record by ID
|
||||
const user = await client.read('users', {
|
||||
schema: 'public',
|
||||
record_id: '123'
|
||||
});
|
||||
|
||||
// Read with preloading
|
||||
const posts = await client.read('posts', {
|
||||
schema: 'public',
|
||||
preload: [
|
||||
{
|
||||
relation: 'user',
|
||||
columns: ['id', 'name', 'email']
|
||||
},
|
||||
{
|
||||
relation: 'comments',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'approved' }
|
||||
]
|
||||
}
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
#### `create<T>(entity: string, data: any, options?): Promise<T>`
|
||||
Create a new record.
|
||||
|
||||
```typescript
|
||||
const newUser = await client.create('users', {
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
status: 'active'
|
||||
}, {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `update<T>(entity: string, id: string, data: any, options?): Promise<T>`
|
||||
Update an existing record.
|
||||
|
||||
```typescript
|
||||
const updatedUser = await client.update('users', '123', {
|
||||
name: 'John Updated',
|
||||
email: 'john.new@example.com'
|
||||
}, {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `delete(entity: string, id: string, options?): Promise<void>`
|
||||
Delete a record.
|
||||
|
||||
```typescript
|
||||
await client.delete('users', '123', {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `meta<T>(entity: string, options?): Promise<T>`
|
||||
Get metadata for an entity.
|
||||
|
||||
```typescript
|
||||
const metadata = await client.meta('users', {
|
||||
schema: 'public'
|
||||
});
|
||||
console.log('Columns:', metadata.columns);
|
||||
console.log('Primary key:', metadata.primary_key);
|
||||
```
|
||||
|
||||
### Subscriptions
|
||||
|
||||
#### `subscribe(entity: string, callback: Function, options?): Promise<string>`
|
||||
Subscribe to entity changes.
|
||||
|
||||
```typescript
|
||||
const subscriptionId = await client.subscribe(
|
||||
'users',
|
||||
(notification) => {
|
||||
console.log('Operation:', notification.operation); // 'create', 'update', or 'delete'
|
||||
console.log('Data:', notification.data);
|
||||
console.log('Timestamp:', notification.timestamp);
|
||||
},
|
||||
{
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
]
|
||||
}
|
||||
);
|
||||
```
|
||||
|
||||
#### `unsubscribe(subscriptionId: string): Promise<void>`
|
||||
Unsubscribe from entity changes.
|
||||
|
||||
```typescript
|
||||
await client.unsubscribe(subscriptionId);
|
||||
```
|
||||
|
||||
#### `getSubscriptions(): Subscription[]`
|
||||
Get list of active subscriptions.
|
||||
|
||||
```typescript
|
||||
const subscriptions = client.getSubscriptions();
|
||||
console.log('Active subscriptions:', subscriptions.length);
|
||||
```
|
||||
|
||||
### Event Handling
|
||||
|
||||
#### `on(event: string, callback: Function): void`
|
||||
Add event listener.
|
||||
|
||||
```typescript
|
||||
// Connection events
|
||||
client.on('connect', () => {
|
||||
console.log('Connected!');
|
||||
});
|
||||
|
||||
client.on('disconnect', (event) => {
|
||||
console.log('Disconnected:', event.code, event.reason);
|
||||
});
|
||||
|
||||
client.on('error', (error) => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
|
||||
// State changes
|
||||
client.on('stateChange', (state) => {
|
||||
console.log('State:', state);
|
||||
});
|
||||
|
||||
// All messages
|
||||
client.on('message', (message) => {
|
||||
console.log('Message:', message);
|
||||
});
|
||||
```
|
||||
|
||||
#### `off(event: string): void`
|
||||
Remove event listener.
|
||||
|
||||
```typescript
|
||||
client.off('connect');
|
||||
```
|
||||
|
||||
## Filter Operators
|
||||
|
||||
- `eq` - Equal (=)
|
||||
- `neq` - Not Equal (!=)
|
||||
- `gt` - Greater Than (>)
|
||||
- `gte` - Greater Than or Equal (>=)
|
||||
- `lt` - Less Than (<)
|
||||
- `lte` - Less Than or Equal (<=)
|
||||
- `like` - LIKE (case-sensitive)
|
||||
- `ilike` - ILIKE (case-insensitive)
|
||||
- `in` - IN (array of values)
|
||||
|
||||
## Examples
|
||||
|
||||
### Basic CRUD
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Create
|
||||
const user = await client.create('users', {
|
||||
name: 'Alice',
|
||||
email: 'alice@example.com'
|
||||
});
|
||||
|
||||
// Read
|
||||
const users = await client.read('users', {
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
|
||||
});
|
||||
|
||||
// Update
|
||||
await client.update('users', user.id, { name: 'Alice Updated' });
|
||||
|
||||
// Delete
|
||||
await client.delete('users', user.id);
|
||||
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
### Real-Time Subscriptions
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Subscribe to all user changes
|
||||
const subId = await client.subscribe('users', (notification) => {
|
||||
switch (notification.operation) {
|
||||
case 'create':
|
||||
console.log('New user:', notification.data);
|
||||
break;
|
||||
case 'update':
|
||||
console.log('User updated:', notification.data);
|
||||
break;
|
||||
case 'delete':
|
||||
console.log('User deleted:', notification.data);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
// Later: unsubscribe
|
||||
await client.unsubscribe(subId);
|
||||
```
|
||||
|
||||
### React Integration
|
||||
|
||||
```typescript
|
||||
import { useEffect, useState } from 'react';
|
||||
import { WebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
function useWebSocket(url: string) {
|
||||
const [client] = useState(() => new WebSocketClient({ url }));
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
client.on('connect', () => setIsConnected(true));
|
||||
client.on('disconnect', () => setIsConnected(false));
|
||||
client.connect();
|
||||
|
||||
return () => client.disconnect();
|
||||
}, [client]);
|
||||
|
||||
return { client, isConnected };
|
||||
}
|
||||
|
||||
function UsersComponent() {
|
||||
const { client, isConnected } = useWebSocket('ws://localhost:8080/ws');
|
||||
const [users, setUsers] = useState([]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isConnected) return;
|
||||
|
||||
const loadUsers = async () => {
|
||||
// Subscribe to changes
|
||||
await client.subscribe('users', (notification) => {
|
||||
if (notification.operation === 'create') {
|
||||
setUsers(prev => [...prev, notification.data]);
|
||||
} else if (notification.operation === 'update') {
|
||||
setUsers(prev => prev.map(u =>
|
||||
u.id === notification.data.id ? notification.data : u
|
||||
));
|
||||
} else if (notification.operation === 'delete') {
|
||||
setUsers(prev => prev.filter(u => u.id !== notification.data.id));
|
||||
}
|
||||
});
|
||||
|
||||
// Load initial data
|
||||
const data = await client.read('users');
|
||||
setUsers(data);
|
||||
};
|
||||
|
||||
loadUsers();
|
||||
}, [client, isConnected]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h2>Users {isConnected ? '🟢' : '🔴'}</h2>
|
||||
{users.map(user => (
|
||||
<div key={user.id}>{user.name}</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### TypeScript with Typed Models

```typescript
interface User {
  id: number;
  name: string;
  email: string;
  status: 'active' | 'inactive';
}

interface Post {
  id: number;
  title: string;
  content: string;
  user_id: number;
  user?: User;
}

const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();

// Type-safe operations
const users = await client.read<User[]>('users', {
  filters: [{ column: 'status', operator: 'eq', value: 'active' }]
});

const newUser = await client.create<User>('users', {
  name: 'Bob',
  email: 'bob@example.com',
  status: 'active'
});

// Type-safe subscriptions
await client.subscribe(
  'posts',
  (notification) => {
    const post = notification.data as Post;
    console.log('Post:', post.title);
  }
);
```

### Error Handling

```typescript
const client = new WebSocketClient({
  url: 'ws://localhost:8080/ws',
  reconnect: true,
  maxReconnectAttempts: 5
});

client.on('error', (error) => {
  console.error('Connection error:', error);
});

client.on('stateChange', (state) => {
  console.log('State:', state);
  if (state === 'reconnecting') {
    console.log('Attempting to reconnect...');
  }
});

try {
  await client.connect();

  try {
    const user = await client.read('users', { record_id: '999' });
  } catch (error) {
    console.error('Record not found:', error);
  }

  try {
    await client.create('users', { /* invalid data */ });
  } catch (error) {
    console.error('Validation failed:', error);
  }

} catch (error) {
  console.error('Connection failed:', error);
}
```

### Multiple Subscriptions

```typescript
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();

// Subscribe to multiple entities
const userSub = await client.subscribe('users', (n) => {
  console.log('[Users]', n.operation, n.data);
});

const postSub = await client.subscribe('posts', (n) => {
  console.log('[Posts]', n.operation, n.data);
}, {
  filters: [{ column: 'status', operator: 'eq', value: 'published' }]
});

const commentSub = await client.subscribe('comments', (n) => {
  console.log('[Comments]', n.operation, n.data);
});

// Check active subscriptions
console.log('Active:', client.getSubscriptions().length);

// Clean up
await client.unsubscribe(userSub);
await client.unsubscribe(postSub);
await client.unsubscribe(commentSub);
```

## Best Practices

1. **Always Clean Up**: Call `disconnect()` when done to close the connection properly
2. **Use TypeScript**: Leverage type definitions for better type safety
3. **Handle Errors**: Always wrap operations in try-catch blocks
4. **Limit Subscriptions**: Don't create too many subscriptions per connection
5. **Use Filters**: Apply filters to subscriptions to reduce unnecessary notifications
6. **Connection State**: Check `isConnected()` before operations
7. **Event Listeners**: Remove event listeners when no longer needed with `off()`
8. **Reconnection**: Enable auto-reconnection for production apps

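Practices 1, 3, and 6 can be combined into one small wrapper. The sketch below is hypothetical (a `withConnection` helper and `MinimalClient` interface are not part of the published API); it assumes only the `connect()`, `isConnected()`, and `disconnect()` methods documented above:

```typescript
// Hypothetical helper: any object with these three methods will do,
// including the WebSocketClient described in this README.
interface MinimalClient {
  connect(): Promise<void>;
  isConnected(): boolean;
  disconnect(): void;
}

// Connect if needed, run the work, and always disconnect afterwards —
// even when `work` throws, so the connection never leaks.
async function withConnection<T>(
  client: MinimalClient,
  work: (client: MinimalClient) => Promise<T>
): Promise<T> {
  if (!client.isConnected()) {
    await client.connect();
  }
  try {
    return await work(client);
  } finally {
    client.disconnect();
  }
}
```

Errors raised inside `work` still propagate to the caller; the `finally` block only guarantees cleanup, mirroring practice 1.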
## Browser Support

- Chrome/Edge 88+
- Firefox 85+
- Safari 14+
- Node.js 18+

## License

MIT

resolvespec-js/dist/index.cjs (vendored, new file, 1 line)
File diff suppressed because one or more lines are too long

resolvespec-js/dist/index.d.ts (vendored, new file, 366 lines)
@@ -0,0 +1,366 @@
export declare interface APIError {
    code: string;
    message: string;
    details?: any;
    detail?: string;
}

export declare interface APIResponse<T = any> {
    success: boolean;
    data: T;
    metadata?: Metadata;
    error?: APIError;
}

/**
 * Build HTTP headers from Options, matching Go's restheadspec handler conventions.
 *
 * Header mapping:
 * - X-Select-Fields: comma-separated columns
 * - X-Not-Select-Fields: comma-separated omit_columns
 * - X-FieldFilter-{col}: exact match (eq)
 * - X-SearchOp-{operator}-{col}: AND filter
 * - X-SearchOr-{operator}-{col}: OR filter
 * - X-Sort: +col (asc), -col (desc)
 * - X-Limit, X-Offset: pagination
 * - X-Cursor-Forward, X-Cursor-Backward: cursor pagination
 * - X-Preload: RelationName:field1,field2 pipe-separated
 * - X-Fetch-RowNumber: row number fetch
 * - X-CQL-SEL-{col}: computed columns
 * - X-Custom-SQL-W: custom operators (AND)
 */
export declare function buildHeaders(options: Options): Record<string, string>;

export declare interface ClientConfig {
    baseUrl: string;
    token?: string;
}

export declare interface Column {
    name: string;
    type: string;
    is_nullable: boolean;
    is_primary: boolean;
    is_unique: boolean;
    has_index: boolean;
}

export declare interface ComputedColumn {
    name: string;
    expression: string;
}

export declare type ConnectionState = 'connecting' | 'connected' | 'disconnecting' | 'disconnected' | 'reconnecting';

export declare interface CustomOperator {
    name: string;
    sql: string;
}

/**
 * Decode a header value that may be base64 encoded with ZIP_ or __ prefix.
 */
export declare function decodeHeaderValue(value: string): string;

/**
 * Encode a value with base64 and ZIP_ prefix for complex header values.
 */
export declare function encodeHeaderValue(value: string): string;

export declare interface FilterOption {
    column: string;
    operator: Operator | string;
    value: any;
    logic_operator?: 'AND' | 'OR';
}

export declare function getHeaderSpecClient(config: ClientConfig): HeaderSpecClient;

export declare function getResolveSpecClient(config: ClientConfig): ResolveSpecClient;

export declare function getWebSocketClient(config: WebSocketClientConfig): WebSocketClient;

/**
 * HeaderSpec REST client.
 * Sends query options via HTTP headers instead of request body, matching the Go restheadspec handler.
 *
 * HTTP methods: GET=read, POST=create, PUT=update, DELETE=delete
 */
export declare class HeaderSpecClient {
    private config;
    constructor(config: ClientConfig);
    private buildUrl;
    private baseHeaders;
    private fetchWithError;
    read<T = any>(schema: string, entity: string, id?: string, options?: Options): Promise<APIResponse<T>>;
    create<T = any>(schema: string, entity: string, data: any, options?: Options): Promise<APIResponse<T>>;
    update<T = any>(schema: string, entity: string, id: string, data: any, options?: Options): Promise<APIResponse<T>>;
    delete(schema: string, entity: string, id: string): Promise<APIResponse<void>>;
}

export declare type MessageType = 'request' | 'response' | 'notification' | 'subscription' | 'error' | 'ping' | 'pong';

export declare interface Metadata {
    total: number;
    count: number;
    filtered: number;
    limit: number;
    offset: number;
    row_number?: number;
}

export declare type Operation = 'read' | 'create' | 'update' | 'delete';

export declare type Operator = 'eq' | 'neq' | 'gt' | 'gte' | 'lt' | 'lte' | 'like' | 'ilike' | 'in' | 'contains' | 'startswith' | 'endswith' | 'between' | 'between_inclusive' | 'is_null' | 'is_not_null';

export declare interface Options {
    preload?: PreloadOption[];
    columns?: string[];
    omit_columns?: string[];
    filters?: FilterOption[];
    sort?: SortOption[];
    limit?: number;
    offset?: number;
    customOperators?: CustomOperator[];
    computedColumns?: ComputedColumn[];
    parameters?: Parameter[];
    cursor_forward?: string;
    cursor_backward?: string;
    fetch_row_number?: string;
}

export declare interface Parameter {
    name: string;
    value: string;
    sequence?: number;
}

export declare interface PreloadOption {
    relation: string;
    table_name?: string;
    columns?: string[];
    omit_columns?: string[];
    sort?: SortOption[];
    filters?: FilterOption[];
    where?: string;
    limit?: number;
    offset?: number;
    updatable?: boolean;
    computed_ql?: Record<string, string>;
    recursive?: boolean;
    primary_key?: string;
    related_key?: string;
    foreign_key?: string;
    recursive_child_key?: string;
    sql_joins?: string[];
    join_aliases?: string[];
}

export declare interface RequestBody {
    operation: Operation;
    id?: number | string | string[];
    data?: any | any[];
    options?: Options;
}

export declare class ResolveSpecClient {
    private config;
    constructor(config: ClientConfig);
    private buildUrl;
    private baseHeaders;
    private fetchWithError;
    getMetadata(schema: string, entity: string): Promise<APIResponse<TableMetadata>>;
    read<T = any>(schema: string, entity: string, id?: number | string | string[], options?: Options): Promise<APIResponse<T>>;
    create<T = any>(schema: string, entity: string, data: any | any[], options?: Options): Promise<APIResponse<T>>;
    update<T = any>(schema: string, entity: string, data: any | any[], id?: number | string | string[], options?: Options): Promise<APIResponse<T>>;
    delete(schema: string, entity: string, id: number | string): Promise<APIResponse<void>>;
}

export declare type SortDirection = 'asc' | 'desc' | 'ASC' | 'DESC';

export declare interface SortOption {
    column: string;
    direction: SortDirection;
}

export declare interface Subscription {
    id: string;
    entity: string;
    schema?: string;
    options?: WSOptions;
    callback?: (notification: WSNotificationMessage) => void;
}

export declare interface SubscriptionOptions {
    filters?: FilterOption[];
    onNotification?: (notification: WSNotificationMessage) => void;
}

export declare interface TableMetadata {
    schema: string;
    table: string;
    columns: Column[];
    relations: string[];
}

export declare class WebSocketClient {
    private ws;
    private config;
    private messageHandlers;
    private subscriptions;
    private eventListeners;
    private state;
    private reconnectAttempts;
    private reconnectTimer;
    private heartbeatTimer;
    private isManualClose;
    constructor(config: WebSocketClientConfig);
    connect(): Promise<void>;
    disconnect(): void;
    request<T = any>(operation: WSOperation, entity: string, options?: {
        schema?: string;
        record_id?: string;
        data?: any;
        options?: WSOptions;
    }): Promise<T>;
    read<T = any>(entity: string, options?: {
        schema?: string;
        record_id?: string;
        filters?: FilterOption[];
        columns?: string[];
        sort?: SortOption[];
        preload?: PreloadOption[];
        limit?: number;
        offset?: number;
    }): Promise<T>;
    create<T = any>(entity: string, data: any, options?: {
        schema?: string;
    }): Promise<T>;
    update<T = any>(entity: string, id: string, data: any, options?: {
        schema?: string;
    }): Promise<T>;
    delete(entity: string, id: string, options?: {
        schema?: string;
    }): Promise<void>;
    meta<T = any>(entity: string, options?: {
        schema?: string;
    }): Promise<T>;
    subscribe(entity: string, callback: (notification: WSNotificationMessage) => void, options?: {
        schema?: string;
        filters?: FilterOption[];
    }): Promise<string>;
    unsubscribe(subscriptionId: string): Promise<void>;
    getSubscriptions(): Subscription[];
    getState(): ConnectionState;
    isConnected(): boolean;
    on<K extends keyof WebSocketClientEvents>(event: K, callback: WebSocketClientEvents[K]): void;
    off<K extends keyof WebSocketClientEvents>(event: K): void;
    private handleMessage;
    private handleResponse;
    private handleNotification;
    private send;
    private startHeartbeat;
    private stopHeartbeat;
    private setState;
    private ensureConnected;
    private emit;
    private log;
}

export declare interface WebSocketClientConfig {
    url: string;
    reconnect?: boolean;
    reconnectInterval?: number;
    maxReconnectAttempts?: number;
    heartbeatInterval?: number;
    debug?: boolean;
}

export declare interface WebSocketClientEvents {
    connect: () => void;
    disconnect: (event: CloseEvent) => void;
    error: (error: Error) => void;
    message: (message: WSMessage) => void;
    stateChange: (state: ConnectionState) => void;
}

export declare interface WSErrorInfo {
    code: string;
    message: string;
    details?: Record<string, any>;
}

export declare interface WSMessage {
    id?: string;
    type: MessageType;
    operation?: WSOperation;
    schema?: string;
    entity?: string;
    record_id?: string;
    data?: any;
    options?: WSOptions;
    subscription_id?: string;
    success?: boolean;
    error?: WSErrorInfo;
    metadata?: Record<string, any>;
    timestamp?: string;
}

export declare interface WSNotificationMessage {
    type: 'notification';
    operation: WSOperation;
    subscription_id: string;
    schema?: string;
    entity: string;
    data: any;
    timestamp: string;
}

export declare type WSOperation = 'read' | 'create' | 'update' | 'delete' | 'subscribe' | 'unsubscribe' | 'meta';

export declare interface WSOptions {
    filters?: FilterOption[];
    columns?: string[];
    omit_columns?: string[];
    preload?: PreloadOption[];
    sort?: SortOption[];
    limit?: number;
    offset?: number;
    parameters?: Parameter[];
    cursor_forward?: string;
    cursor_backward?: string;
    fetch_row_number?: string;
}

export declare interface WSRequestMessage {
    id: string;
    type: 'request';
    operation: WSOperation;
    schema?: string;
    entity: string;
    record_id?: string;
    data?: any;
    options?: WSOptions;
}

export declare interface WSResponseMessage {
    id: string;
    type: 'response';
    success: boolean;
    data?: any;
    error?: WSErrorInfo;
    metadata?: Record<string, any>;
    timestamp: string;
}

export declare interface WSSubscriptionMessage {
    id: string;
    type: 'subscription';
    operation: 'subscribe' | 'unsubscribe';
    schema?: string;
    entity: string;
    options?: WSOptions;
    subscription_id?: string;
}

export { }
resolvespec-js/dist/index.js (vendored, new file, 469 lines)
@@ -0,0 +1,469 @@
import { v4 as l } from "uuid";
const d = /* @__PURE__ */ new Map();
function E(n) {
  const e = n.baseUrl;
  let t = d.get(e);
  return t || (t = new g(n), d.set(e, t)), t;
}
class g {
  constructor(e) {
    this.config = e;
  }
  buildUrl(e, t, s) {
    let r = `${this.config.baseUrl}/${e}/${t}`;
    return s && (r += `/${s}`), r;
  }
  baseHeaders() {
    const e = {
      "Content-Type": "application/json"
    };
    return this.config.token && (e.Authorization = `Bearer ${this.config.token}`), e;
  }
  async fetchWithError(e, t) {
    const s = await fetch(e, t), r = await s.json();
    if (!s.ok)
      throw new Error(r.error?.message || "An error occurred");
    return r;
  }
  async getMetadata(e, t) {
    const s = this.buildUrl(e, t);
    return this.fetchWithError(s, {
      method: "GET",
      headers: this.baseHeaders()
    });
  }
  async read(e, t, s, r) {
    const i = typeof s == "number" || typeof s == "string" ? String(s) : void 0, a = this.buildUrl(e, t, i), c = {
      operation: "read",
      id: Array.isArray(s) ? s : void 0,
      options: r
    };
    return this.fetchWithError(a, {
      method: "POST",
      headers: this.baseHeaders(),
      body: JSON.stringify(c)
    });
  }
  async create(e, t, s, r) {
    const i = this.buildUrl(e, t), a = {
      operation: "create",
      data: s,
      options: r
    };
    return this.fetchWithError(i, {
      method: "POST",
      headers: this.baseHeaders(),
      body: JSON.stringify(a)
    });
  }
  async update(e, t, s, r, i) {
    const a = typeof r == "number" || typeof r == "string" ? String(r) : void 0, c = this.buildUrl(e, t, a), o = {
      operation: "update",
      id: Array.isArray(r) ? r : void 0,
      data: s,
      options: i
    };
    return this.fetchWithError(c, {
      method: "POST",
      headers: this.baseHeaders(),
      body: JSON.stringify(o)
    });
  }
  async delete(e, t, s) {
    const r = this.buildUrl(e, t, String(s)), i = {
      operation: "delete"
    };
    return this.fetchWithError(r, {
      method: "POST",
      headers: this.baseHeaders(),
      body: JSON.stringify(i)
    });
  }
}
const f = /* @__PURE__ */ new Map();
function _(n) {
  const e = n.url;
  let t = f.get(e);
  return t || (t = new p(n), f.set(e, t)), t;
}
class p {
  constructor(e) {
    this.ws = null, this.messageHandlers = /* @__PURE__ */ new Map(), this.subscriptions = /* @__PURE__ */ new Map(), this.eventListeners = {}, this.state = "disconnected", this.reconnectAttempts = 0, this.reconnectTimer = null, this.heartbeatTimer = null, this.isManualClose = !1, this.config = {
      url: e.url,
      reconnect: e.reconnect ?? !0,
      reconnectInterval: e.reconnectInterval ?? 3e3,
      maxReconnectAttempts: e.maxReconnectAttempts ?? 10,
      heartbeatInterval: e.heartbeatInterval ?? 3e4,
      debug: e.debug ?? !1
    };
  }
  async connect() {
    if (this.ws?.readyState === WebSocket.OPEN) {
      this.log("Already connected");
      return;
    }
    return this.isManualClose = !1, this.setState("connecting"), new Promise((e, t) => {
      try {
        this.ws = new WebSocket(this.config.url), this.ws.onopen = () => {
          this.log("Connected to WebSocket server"), this.setState("connected"), this.reconnectAttempts = 0, this.startHeartbeat(), this.emit("connect"), e();
        }, this.ws.onmessage = (s) => {
          this.handleMessage(s.data);
        }, this.ws.onerror = (s) => {
          this.log("WebSocket error:", s);
          const r = new Error("WebSocket connection error");
          this.emit("error", r), t(r);
        }, this.ws.onclose = (s) => {
          this.log("WebSocket closed:", s.code, s.reason), this.stopHeartbeat(), this.setState("disconnected"), this.emit("disconnect", s), this.config.reconnect && !this.isManualClose && this.reconnectAttempts < this.config.maxReconnectAttempts && (this.reconnectAttempts++, this.log(`Reconnection attempt ${this.reconnectAttempts}/${this.config.maxReconnectAttempts}`), this.setState("reconnecting"), this.reconnectTimer = setTimeout(() => {
            this.connect().catch((r) => {
              this.log("Reconnection failed:", r);
            });
          }, this.config.reconnectInterval));
        };
      } catch (s) {
        t(s);
      }
    });
  }
  disconnect() {
    this.isManualClose = !0, this.reconnectTimer && (clearTimeout(this.reconnectTimer), this.reconnectTimer = null), this.stopHeartbeat(), this.ws && (this.setState("disconnecting"), this.ws.close(), this.ws = null), this.setState("disconnected"), this.messageHandlers.clear();
  }
  async request(e, t, s) {
    this.ensureConnected();
    const r = l(), i = {
      id: r,
      type: "request",
      operation: e,
      entity: t,
      schema: s?.schema,
      record_id: s?.record_id,
      data: s?.data,
      options: s?.options
    };
    return new Promise((a, c) => {
      this.messageHandlers.set(r, (o) => {
        o.success ? a(o.data) : c(new Error(o.error?.message || "Request failed"));
      }), this.send(i), setTimeout(() => {
        this.messageHandlers.has(r) && (this.messageHandlers.delete(r), c(new Error("Request timeout")));
      }, 3e4);
    });
  }
  async read(e, t) {
    return this.request("read", e, {
      schema: t?.schema,
      record_id: t?.record_id,
      options: {
        filters: t?.filters,
        columns: t?.columns,
        sort: t?.sort,
        preload: t?.preload,
        limit: t?.limit,
        offset: t?.offset
      }
    });
  }
  async create(e, t, s) {
    return this.request("create", e, {
      schema: s?.schema,
      data: t
    });
  }
  async update(e, t, s, r) {
    return this.request("update", e, {
      schema: r?.schema,
      record_id: t,
      data: s
    });
  }
  async delete(e, t, s) {
    await this.request("delete", e, {
      schema: s?.schema,
      record_id: t
    });
  }
  async meta(e, t) {
    return this.request("meta", e, {
      schema: t?.schema
    });
  }
  async subscribe(e, t, s) {
    this.ensureConnected();
    const r = l(), i = {
      id: r,
      type: "subscription",
      operation: "subscribe",
      entity: e,
      schema: s?.schema,
      options: {
        filters: s?.filters
      }
    };
    return new Promise((a, c) => {
      this.messageHandlers.set(r, (o) => {
        if (o.success && o.data?.subscription_id) {
          const h = o.data.subscription_id;
          this.subscriptions.set(h, {
            id: h,
            entity: e,
            schema: s?.schema,
            options: { filters: s?.filters },
            callback: t
          }), this.log(`Subscribed to ${e} with ID: ${h}`), a(h);
        } else
          c(new Error(o.error?.message || "Subscription failed"));
      }), this.send(i), setTimeout(() => {
        this.messageHandlers.has(r) && (this.messageHandlers.delete(r), c(new Error("Subscription timeout")));
      }, 1e4);
    });
  }
  async unsubscribe(e) {
    this.ensureConnected();
    const t = l(), s = {
      id: t,
      type: "subscription",
      operation: "unsubscribe",
      subscription_id: e
    };
    return new Promise((r, i) => {
      this.messageHandlers.set(t, (a) => {
        a.success ? (this.subscriptions.delete(e), this.log(`Unsubscribed from ${e}`), r()) : i(new Error(a.error?.message || "Unsubscribe failed"));
      }), this.send(s), setTimeout(() => {
        this.messageHandlers.has(t) && (this.messageHandlers.delete(t), i(new Error("Unsubscribe timeout")));
      }, 1e4);
    });
  }
  getSubscriptions() {
    return Array.from(this.subscriptions.values());
  }
  getState() {
    return this.state;
  }
  isConnected() {
    return this.ws?.readyState === WebSocket.OPEN;
  }
  on(e, t) {
    this.eventListeners[e] = t;
  }
  off(e) {
    delete this.eventListeners[e];
  }
  // Private methods
  handleMessage(e) {
    try {
      const t = JSON.parse(e);
      switch (this.log("Received message:", t), this.emit("message", t), t.type) {
        case "response":
          this.handleResponse(t);
          break;
        case "notification":
          this.handleNotification(t);
          break;
        case "pong":
          break;
        default:
          this.log("Unknown message type:", t.type);
      }
    } catch (t) {
      this.log("Error parsing message:", t);
    }
  }
  handleResponse(e) {
    const t = this.messageHandlers.get(e.id);
    t && (t(e), this.messageHandlers.delete(e.id));
  }
  handleNotification(e) {
    const t = this.subscriptions.get(e.subscription_id);
    t?.callback && t.callback(e);
  }
  send(e) {
    if (!this.ws || this.ws.readyState !== WebSocket.OPEN)
      throw new Error("WebSocket is not connected");
    const t = JSON.stringify(e);
    this.log("Sending message:", e), this.ws.send(t);
  }
  startHeartbeat() {
    this.heartbeatTimer || (this.heartbeatTimer = setInterval(() => {
      if (this.isConnected()) {
        const e = {
          id: l(),
          type: "ping"
        };
        this.send(e);
      }
    }, this.config.heartbeatInterval));
  }
  stopHeartbeat() {
    this.heartbeatTimer && (clearInterval(this.heartbeatTimer), this.heartbeatTimer = null);
  }
  setState(e) {
    this.state !== e && (this.state = e, this.emit("stateChange", e));
  }
  ensureConnected() {
    if (!this.isConnected())
      throw new Error("WebSocket is not connected. Call connect() first.");
  }
  emit(e, ...t) {
    const s = this.eventListeners[e];
    s && s(...t);
  }
  log(...e) {
    this.config.debug && console.log("[WebSocketClient]", ...e);
  }
}
function v(n) {
  return typeof btoa == "function" ? "ZIP_" + btoa(n) : "ZIP_" + Buffer.from(n, "utf-8").toString("base64");
}
function w(n) {
  let e = n;
  return e.startsWith("ZIP_") ? (e = e.slice(4).replace(/[\n\r ]/g, ""), e = m(e)) : e.startsWith("__") && (e = e.slice(2).replace(/[\n\r ]/g, ""), e = m(e)), (e.startsWith("ZIP_") || e.startsWith("__")) && (e = w(e)), e;
}
function m(n) {
  return typeof atob == "function" ? atob(n) : Buffer.from(n, "base64").toString("utf-8");
}
function u(n) {
  const e = {};
  if (n.columns?.length && (e["X-Select-Fields"] = n.columns.join(",")), n.omit_columns?.length && (e["X-Not-Select-Fields"] = n.omit_columns.join(",")), n.filters?.length)
    for (const t of n.filters) {
      const s = t.logic_operator ?? "AND", r = y(t.operator), i = S(t);
      t.operator === "eq" && s === "AND" ? e[`X-FieldFilter-${t.column}`] = i : s === "OR" ? e[`X-SearchOr-${r}-${t.column}`] = i : e[`X-SearchOp-${r}-${t.column}`] = i;
    }
  if (n.sort?.length) {
    const t = n.sort.map((s) => s.direction.toUpperCase() === "DESC" ? `-${s.column}` : `+${s.column}`);
    e["X-Sort"] = t.join(",");
  }
  if (n.limit !== void 0 && (e["X-Limit"] = String(n.limit)), n.offset !== void 0 && (e["X-Offset"] = String(n.offset)), n.cursor_forward && (e["X-Cursor-Forward"] = n.cursor_forward), n.cursor_backward && (e["X-Cursor-Backward"] = n.cursor_backward), n.preload?.length) {
    const t = n.preload.map((s) => s.columns?.length ? `${s.relation}:${s.columns.join(",")}` : s.relation);
    e["X-Preload"] = t.join("|");
  }
  if (n.fetch_row_number && (e["X-Fetch-RowNumber"] = n.fetch_row_number), n.computedColumns?.length)
    for (const t of n.computedColumns)
      e[`X-CQL-SEL-${t.name}`] = t.expression;
  if (n.customOperators?.length) {
    const t = n.customOperators.map(
      (s) => s.sql
    );
    e["X-Custom-SQL-W"] = t.join(" AND ");
  }
  return e;
}
function y(n) {
  switch (n) {
    case "eq":
      return "equals";
    case "neq":
      return "notequals";
    case "gt":
      return "greaterthan";
    case "gte":
      return "greaterthanorequal";
    case "lt":
      return "lessthan";
    case "lte":
      return "lessthanorequal";
    case "like":
    case "ilike":
    case "contains":
      return "contains";
    case "startswith":
      return "beginswith";
    case "endswith":
      return "endswith";
    case "in":
      return "in";
    case "between":
      return "between";
    case "between_inclusive":
      return "betweeninclusive";
    case "is_null":
      return "empty";
    case "is_not_null":
      return "notempty";
    default:
      return n;
  }
}
function S(n) {
  return n.value === null || n.value === void 0 ? "" : Array.isArray(n.value) ? n.value.join(",") : String(n.value);
}
const b = /* @__PURE__ */ new Map();
function C(n) {
  const e = n.baseUrl;
  let t = b.get(e);
  return t || (t = new H(n), b.set(e, t)), t;
}
class H {
  constructor(e) {
    this.config = e;
  }
  buildUrl(e, t, s) {
    let r = `${this.config.baseUrl}/${e}/${t}`;
    return s && (r += `/${s}`), r;
  }
  baseHeaders() {
    const e = {
      "Content-Type": "application/json"
    };
    return this.config.token && (e.Authorization = `Bearer ${this.config.token}`), e;
  }
  async fetchWithError(e, t) {
    const s = await fetch(e, t), r = await s.json();
    if (!s.ok)
      throw new Error(
        r.error?.message || `${s.statusText} (${s.status})`
      );
    return {
      data: r,
      success: !0,
      error: r.error ? r.error : void 0,
      metadata: {
        count: s.headers.get("content-range") ? Number(s.headers.get("content-range")?.split("/")[1]) : 0,
        total: s.headers.get("content-range") ? Number(s.headers.get("content-range")?.split("/")[1]) : 0,
        filtered: s.headers.get("content-range") ? Number(s.headers.get("content-range")?.split("/")[1]) : 0,
        offset: s.headers.get("content-range") ? Number(
          s.headers.get("content-range")?.split("/")[0].split("-")[0]
        ) : 0,
        limit: s.headers.get("x-limit") ? Number(s.headers.get("x-limit")) : 0
      }
    };
  }
  async read(e, t, s, r) {
    const i = this.buildUrl(e, t, s), a = r ? u(r) : {};
    return this.fetchWithError(i, {
      method: "GET",
      headers: { ...this.baseHeaders(), ...a }
    });
  }
  async create(e, t, s, r) {
    const i = this.buildUrl(e, t), a = r ? u(r) : {};
    return this.fetchWithError(i, {
      method: "POST",
      headers: { ...this.baseHeaders(), ...a },
      body: JSON.stringify(s)
    });
  }
  async update(e, t, s, r, i) {
    const a = this.buildUrl(e, t, s), c = i ? u(i) : {};
    return this.fetchWithError(a, {
      method: "PUT",
      headers: { ...this.baseHeaders(), ...c },
      body: JSON.stringify(r)
    });
  }
  async delete(e, t, s) {
    const r = this.buildUrl(e, t, s);
    return this.fetchWithError(r, {
      method: "DELETE",
      headers: this.baseHeaders()
    });
  }
}
export {
  H as HeaderSpecClient,
  g as ResolveSpecClient,
  p as WebSocketClient,
  u as buildHeaders,
  w as decodeHeaderValue,
  v as encodeHeaderValue,
  C as getHeaderSpecClient,
  E as getResolveSpecClient,
  _ as getWebSocketClient
};
@@ -1,20 +1,23 @@
 {
   "name": "@warkypublic/resolvespec-js",
-  "version": "1.0.0",
-  "description": "Client side library for the ResolveSpec API",
+  "version": "1.0.1",
+  "description": "TypeScript client library for ResolveSpec REST, HeaderSpec, and WebSocket APIs",
   "type": "module",
-  "main": "./src/index.ts",
-  "module": "./src/index.ts",
-  "types": "./src/index.ts",
+  "main": "./dist/index.cjs",
+  "module": "./dist/index.js",
+  "types": "./dist/index.d.ts",
+  "exports": {
+    ".": {
+      "types": "./dist/index.d.ts",
+      "import": "./dist/index.js",
+      "require": "./dist/index.cjs"
+    }
+  },
   "publishConfig": {
-    "access": "public",
-    "main": "./dist/index.js",
-    "module": "./dist/index.js",
-    "types": "./dist/index.d.ts"
+    "access": "public"
   },
   "files": [
     "dist",
     "bin",
     "README.md"
   ],
   "scripts": {
@@ -25,38 +28,33 @@
     "lint": "eslint src"
   },
   "keywords": [
-    "string",
-    "blob",
-    "dependencies",
-    "workspace",
-    "package",
-    "cli",
-    "tools",
-    "npm",
-    "yarn",
-    "pnpm"
+    "resolvespec",
+    "headerspec",
+    "websocket",
+    "rest-client",
+    "typescript",
+    "api-client"
   ],
   "author": "Hein (Warkanum) Puth",
   "license": "MIT",
   "dependencies": {
     "semver": "^7.6.3",
-    "uuid": "^11.0.3"
+    "uuid": "^13.0.0"
   },
   "devDependencies": {
-    "@changesets/cli": "^2.27.10",
-    "@eslint/js": "^9.16.0",
-    "@types/jsdom": "^21.1.7",
-    "eslint": "^9.16.0",
-    "globals": "^15.13.0",
-    "jsdom": "^25.0.1",
-    "typescript": "^5.7.2",
-    "typescript-eslint": "^8.17.0",
-    "vite": "^6.0.2",
-    "vite-plugin-dts": "^4.3.0",
-    "vitest": "^2.1.8"
+    "@changesets/cli": "^2.29.8",
+    "@eslint/js": "^10.0.1",
+    "@types/jsdom": "^27.0.0",
+    "eslint": "^10.0.0",
+    "globals": "^17.3.0",
+    "jsdom": "^28.1.0",
+    "typescript": "^5.9.3",
+    "typescript-eslint": "^8.55.0",
+    "vite": "^7.3.1",
+    "vite-plugin-dts": "^4.5.4",
+    "vitest": "^4.0.18"
   },
   "engines": {
-    "node": ">=14.16"
+    "node": ">=18"
   },
   "repository": {
     "type": "git",
 3376  resolvespec-js/pnpm-lock.yaml  (generated, new file)
File diff suppressed because it is too large

 143  resolvespec-js/src/__tests__/common.test.ts  (new file)
@@ -0,0 +1,143 @@
import { describe, it, expect } from 'vitest';
import type {
  Options,
  FilterOption,
  SortOption,
  PreloadOption,
  RequestBody,
  APIResponse,
  Metadata,
  APIError,
  Parameter,
  ComputedColumn,
  CustomOperator,
} from '../common/types';

describe('Common Types', () => {
  it('should construct a valid FilterOption with logic_operator', () => {
    const filter: FilterOption = {
      column: 'name',
      operator: 'eq',
      value: 'test',
      logic_operator: 'OR',
    };
    expect(filter.logic_operator).toBe('OR');
    expect(filter.operator).toBe('eq');
  });

  it('should construct Options with all new fields', () => {
    const opts: Options = {
      columns: ['id', 'name'],
      omit_columns: ['secret'],
      filters: [{ column: 'age', operator: 'gte', value: 18 }],
      sort: [{ column: 'name', direction: 'asc' }],
      limit: 10,
      offset: 0,
      cursor_forward: 'abc123',
      cursor_backward: 'xyz789',
      fetch_row_number: '42',
      parameters: [{ name: 'param1', value: 'val1', sequence: 1 }],
      computedColumns: [{ name: 'full_name', expression: "first || ' ' || last" }],
      customOperators: [{ name: 'custom', sql: "status = 'active'" }],
      preload: [{
        relation: 'Items',
        columns: ['id', 'title'],
        omit_columns: ['internal'],
        sort: [{ column: 'id', direction: 'ASC' }],
        recursive: true,
        primary_key: 'id',
        related_key: 'parent_id',
        sql_joins: ['LEFT JOIN other ON other.id = items.other_id'],
        join_aliases: ['other'],
      }],
    };
    expect(opts.omit_columns).toEqual(['secret']);
    expect(opts.cursor_forward).toBe('abc123');
    expect(opts.fetch_row_number).toBe('42');
    expect(opts.parameters![0].sequence).toBe(1);
    expect(opts.preload![0].recursive).toBe(true);
  });

  it('should construct a RequestBody with numeric id', () => {
    const body: RequestBody = {
      operation: 'read',
      id: 42,
      options: { limit: 10 },
    };
    expect(body.id).toBe(42);
  });

  it('should construct a RequestBody with string array id', () => {
    const body: RequestBody = {
      operation: 'delete',
      id: ['1', '2', '3'],
    };
    expect(Array.isArray(body.id)).toBe(true);
  });

  it('should construct Metadata with count and row_number', () => {
    const meta: Metadata = {
      total: 100,
      count: 10,
      filtered: 50,
      limit: 10,
      offset: 0,
      row_number: 5,
    };
    expect(meta.count).toBe(10);
    expect(meta.row_number).toBe(5);
  });

  it('should construct APIError with detail field', () => {
    const err: APIError = {
      code: 'not_found',
      message: 'Record not found',
      detail: 'The record with id 42 does not exist',
    };
    expect(err.detail).toBeDefined();
  });

  it('should construct APIResponse with metadata', () => {
    const resp: APIResponse<string[]> = {
      success: true,
      data: ['a', 'b'],
      metadata: { total: 2, count: 2, filtered: 2, limit: 10, offset: 0 },
    };
    expect(resp.metadata?.count).toBe(2);
  });

  it('should support all operator types', () => {
    const operators: FilterOption['operator'][] = [
      'eq', 'neq', 'gt', 'gte', 'lt', 'lte',
      'like', 'ilike', 'in',
      'contains', 'startswith', 'endswith',
      'between', 'between_inclusive',
      'is_null', 'is_not_null',
    ];
    for (const op of operators) {
      const f: FilterOption = { column: 'x', operator: op, value: 'v' };
      expect(f.operator).toBe(op);
    }
  });

  it('should support PreloadOption with computed_ql and where', () => {
    const preload: PreloadOption = {
      relation: 'Details',
      where: "status = 'active'",
      computed_ql: { cql1: 'SUM(amount)' },
      table_name: 'detail_table',
      updatable: true,
      foreign_key: 'detail_id',
      recursive_child_key: 'parent_detail_id',
    };
    expect(preload.computed_ql?.cql1).toBe('SUM(amount)');
    expect(preload.updatable).toBe(true);
  });

  it('should support Parameter interface', () => {
    const p: Parameter = { name: 'key', value: 'val' };
    expect(p.name).toBe('key');
    const p2: Parameter = { name: 'key2', value: 'val2', sequence: 5 };
    expect(p2.sequence).toBe(5);
  });
});
 239  resolvespec-js/src/__tests__/headerspec.test.ts  (new file)
@@ -0,0 +1,239 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { buildHeaders, encodeHeaderValue, decodeHeaderValue, HeaderSpecClient, getHeaderSpecClient } from '../headerspec/client';
import type { Options, ClientConfig, APIResponse } from '../common/types';

describe('buildHeaders', () => {
  it('should set X-Select-Fields for columns', () => {
    const h = buildHeaders({ columns: ['id', 'name', 'email'] });
    expect(h['X-Select-Fields']).toBe('id,name,email');
  });

  it('should set X-Not-Select-Fields for omit_columns', () => {
    const h = buildHeaders({ omit_columns: ['secret', 'internal'] });
    expect(h['X-Not-Select-Fields']).toBe('secret,internal');
  });

  it('should set X-FieldFilter for eq AND filters', () => {
    const h = buildHeaders({
      filters: [{ column: 'status', operator: 'eq', value: 'active' }],
    });
    expect(h['X-FieldFilter-status']).toBe('active');
  });

  it('should set X-SearchOp for non-eq AND filters', () => {
    const h = buildHeaders({
      filters: [{ column: 'age', operator: 'gte', value: 18 }],
    });
    expect(h['X-SearchOp-greaterthanorequal-age']).toBe('18');
  });

  it('should set X-SearchOr for OR filters', () => {
    const h = buildHeaders({
      filters: [{ column: 'name', operator: 'contains', value: 'test', logic_operator: 'OR' }],
    });
    expect(h['X-SearchOr-contains-name']).toBe('test');
  });

  it('should set X-Sort with direction prefixes', () => {
    const h = buildHeaders({
      sort: [
        { column: 'name', direction: 'asc' },
        { column: 'created_at', direction: 'DESC' },
      ],
    });
    expect(h['X-Sort']).toBe('+name,-created_at');
  });

  it('should set X-Limit and X-Offset', () => {
    const h = buildHeaders({ limit: 25, offset: 50 });
    expect(h['X-Limit']).toBe('25');
    expect(h['X-Offset']).toBe('50');
  });

  it('should set cursor pagination headers', () => {
    const h = buildHeaders({ cursor_forward: 'abc', cursor_backward: 'xyz' });
    expect(h['X-Cursor-Forward']).toBe('abc');
    expect(h['X-Cursor-Backward']).toBe('xyz');
  });

  it('should set X-Preload with pipe-separated relations', () => {
    const h = buildHeaders({
      preload: [
        { relation: 'Items', columns: ['id', 'name'] },
        { relation: 'Category' },
      ],
    });
    expect(h['X-Preload']).toBe('Items:id,name|Category');
  });

  it('should set X-Fetch-RowNumber', () => {
    const h = buildHeaders({ fetch_row_number: '42' });
    expect(h['X-Fetch-RowNumber']).toBe('42');
  });

  it('should set X-CQL-SEL for computed columns', () => {
    const h = buildHeaders({
      computedColumns: [
        { name: 'total', expression: 'price * qty' },
      ],
    });
    expect(h['X-CQL-SEL-total']).toBe('price * qty');
  });

  it('should set X-Custom-SQL-W for custom operators', () => {
    const h = buildHeaders({
      customOperators: [
        { name: 'active', sql: "status = 'active'" },
        { name: 'verified', sql: "verified = true" },
      ],
    });
    expect(h['X-Custom-SQL-W']).toBe("status = 'active' AND verified = true");
  });

  it('should return empty object for empty options', () => {
    const h = buildHeaders({});
    expect(Object.keys(h)).toHaveLength(0);
  });

  it('should handle between filter with array value', () => {
    const h = buildHeaders({
      filters: [{ column: 'price', operator: 'between', value: [10, 100] }],
    });
    expect(h['X-SearchOp-between-price']).toBe('10,100');
  });

  it('should handle is_null filter with null value', () => {
    const h = buildHeaders({
      filters: [{ column: 'deleted_at', operator: 'is_null', value: null }],
    });
    expect(h['X-SearchOp-empty-deleted_at']).toBe('');
  });

  it('should handle in filter with array value', () => {
    const h = buildHeaders({
      filters: [{ column: 'id', operator: 'in', value: [1, 2, 3] }],
    });
    expect(h['X-SearchOp-in-id']).toBe('1,2,3');
  });
});

describe('encodeHeaderValue / decodeHeaderValue', () => {
  it('should round-trip encode/decode', () => {
    const original = 'some complex value with spaces & symbols!';
    const encoded = encodeHeaderValue(original);
    expect(encoded.startsWith('ZIP_')).toBe(true);
    const decoded = decodeHeaderValue(encoded);
    expect(decoded).toBe(original);
  });

  it('should decode __ prefixed values', () => {
    const encoded = '__' + btoa('hello');
    expect(decodeHeaderValue(encoded)).toBe('hello');
  });

  it('should return plain values as-is', () => {
    expect(decodeHeaderValue('plain')).toBe('plain');
  });
});

describe('HeaderSpecClient', () => {
  const config: ClientConfig = { baseUrl: 'http://localhost:3000', token: 'tok' };

  function mockFetch<T>(data: APIResponse<T>, ok = true) {
    return vi.fn().mockResolvedValue({
      ok,
      json: () => Promise.resolve(data),
    });
  }

  beforeEach(() => {
    vi.restoreAllMocks();
  });

  it('read() sends GET with headers from options', async () => {
    globalThis.fetch = mockFetch({ success: true, data: [{ id: 1 }] });
    const client = new HeaderSpecClient(config);

    await client.read('public', 'users', undefined, {
      columns: ['id', 'name'],
      limit: 10,
    });

    const [url, opts] = (globalThis.fetch as any).mock.calls[0];
    expect(url).toBe('http://localhost:3000/public/users');
    expect(opts.method).toBe('GET');
    expect(opts.headers['X-Select-Fields']).toBe('id,name');
    expect(opts.headers['X-Limit']).toBe('10');
    expect(opts.headers['Authorization']).toBe('Bearer tok');
  });

  it('read() with id appends to URL', async () => {
    globalThis.fetch = mockFetch({ success: true, data: {} });
    const client = new HeaderSpecClient(config);

    await client.read('public', 'users', '42');

    const [url] = (globalThis.fetch as any).mock.calls[0];
    expect(url).toBe('http://localhost:3000/public/users/42');
  });

  it('create() sends POST with body and headers', async () => {
    globalThis.fetch = mockFetch({ success: true, data: { id: 1 } });
    const client = new HeaderSpecClient(config);

    await client.create('public', 'users', { name: 'Test' });

    const [url, opts] = (globalThis.fetch as any).mock.calls[0];
    expect(opts.method).toBe('POST');
    expect(JSON.parse(opts.body)).toEqual({ name: 'Test' });
  });

  it('update() sends PUT with id in URL', async () => {
    globalThis.fetch = mockFetch({ success: true, data: {} });
    const client = new HeaderSpecClient(config);

    await client.update('public', 'users', '1', { name: 'Updated' }, {
      filters: [{ column: 'active', operator: 'eq', value: true }],
    });

    const [url, opts] = (globalThis.fetch as any).mock.calls[0];
    expect(url).toBe('http://localhost:3000/public/users/1');
    expect(opts.method).toBe('PUT');
    expect(opts.headers['X-FieldFilter-active']).toBe('true');
  });

  it('delete() sends DELETE', async () => {
    globalThis.fetch = mockFetch({ success: true, data: undefined as any });
    const client = new HeaderSpecClient(config);

    await client.delete('public', 'users', '1');

    const [url, opts] = (globalThis.fetch as any).mock.calls[0];
    expect(url).toBe('http://localhost:3000/public/users/1');
    expect(opts.method).toBe('DELETE');
  });

  it('throws on non-ok response', async () => {
    globalThis.fetch = mockFetch(
      { success: false, data: null as any, error: { code: 'err', message: 'fail' } },
      false
    );
    const client = new HeaderSpecClient(config);

    await expect(client.read('public', 'users')).rejects.toThrow('fail');
  });
});

describe('getHeaderSpecClient singleton', () => {
  it('returns same instance for same baseUrl', () => {
    const a = getHeaderSpecClient({ baseUrl: 'http://hs-singleton:3000' });
    const b = getHeaderSpecClient({ baseUrl: 'http://hs-singleton:3000' });
    expect(a).toBe(b);
  });

  it('returns different instances for different baseUrls', () => {
    const a = getHeaderSpecClient({ baseUrl: 'http://hs-singleton-a:3000' });
    const b = getHeaderSpecClient({ baseUrl: 'http://hs-singleton-b:3000' });
    expect(a).not.toBe(b);
  });
});
 178  resolvespec-js/src/__tests__/resolvespec.test.ts  (new file)
@@ -0,0 +1,178 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ResolveSpecClient, getResolveSpecClient } from '../resolvespec/client';
import type { ClientConfig, APIResponse } from '../common/types';

const config: ClientConfig = { baseUrl: 'http://localhost:3000', token: 'test-token' };

function mockFetchResponse<T>(data: APIResponse<T>, ok = true, status = 200) {
  return vi.fn().mockResolvedValue({
    ok,
    status,
    json: () => Promise.resolve(data),
  });
}

beforeEach(() => {
  vi.restoreAllMocks();
});

describe('ResolveSpecClient', () => {
  it('read() sends POST with operation read', async () => {
    const response: APIResponse = { success: true, data: [{ id: 1 }] };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    const result = await client.read('public', 'users', 1);
    expect(result.success).toBe(true);

    const [url, opts] = (globalThis.fetch as any).mock.calls[0];
    expect(url).toBe('http://localhost:3000/public/users/1');
    expect(opts.method).toBe('POST');
    expect(opts.headers['Authorization']).toBe('Bearer test-token');

    const body = JSON.parse(opts.body);
    expect(body.operation).toBe('read');
  });

  it('read() with string array id puts id in body', async () => {
    const response: APIResponse = { success: true, data: [] };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    await client.read('public', 'users', ['1', '2']);
    const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
    expect(body.id).toEqual(['1', '2']);
  });

  it('read() passes options through', async () => {
    const response: APIResponse = { success: true, data: [] };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    await client.read('public', 'users', undefined, {
      columns: ['id', 'name'],
      omit_columns: ['secret'],
      filters: [{ column: 'active', operator: 'eq', value: true }],
      sort: [{ column: 'name', direction: 'asc' }],
      limit: 10,
      offset: 0,
      cursor_forward: 'cursor1',
      fetch_row_number: '5',
    });

    const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
    expect(body.options.columns).toEqual(['id', 'name']);
    expect(body.options.omit_columns).toEqual(['secret']);
    expect(body.options.cursor_forward).toBe('cursor1');
    expect(body.options.fetch_row_number).toBe('5');
  });

  it('create() sends POST with operation create and data', async () => {
    const response: APIResponse = { success: true, data: { id: 1, name: 'Test' } };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    const result = await client.create('public', 'users', { name: 'Test' });
    expect(result.data.name).toBe('Test');

    const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
    expect(body.operation).toBe('create');
    expect(body.data.name).toBe('Test');
  });

  it('update() with single id puts id in URL', async () => {
    const response: APIResponse = { success: true, data: { id: 1 } };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    await client.update('public', 'users', { name: 'Updated' }, 1);
    const [url] = (globalThis.fetch as any).mock.calls[0];
    expect(url).toBe('http://localhost:3000/public/users/1');
  });

  it('update() with string array id puts id in body', async () => {
    const response: APIResponse = { success: true, data: {} };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    await client.update('public', 'users', { active: false }, ['1', '2']);
    const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
    expect(body.id).toEqual(['1', '2']);
  });

  it('delete() sends POST with operation delete', async () => {
    const response: APIResponse<void> = { success: true, data: undefined as any };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    await client.delete('public', 'users', 1);
    const [url, opts] = (globalThis.fetch as any).mock.calls[0];
    expect(url).toBe('http://localhost:3000/public/users/1');

    const body = JSON.parse(opts.body);
    expect(body.operation).toBe('delete');
  });

  it('getMetadata() sends GET request', async () => {
    const response: APIResponse = {
      success: true,
      data: { schema: 'public', table: 'users', columns: [], relations: [] },
    };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(config);
    const result = await client.getMetadata('public', 'users');
    expect(result.data.table).toBe('users');

    const opts = (globalThis.fetch as any).mock.calls[0][1];
    expect(opts.method).toBe('GET');
  });

  it('throws on non-ok response', async () => {
    const errorResp = {
      success: false,
      data: null,
      error: { code: 'not_found', message: 'Not found' },
    };
    globalThis.fetch = mockFetchResponse(errorResp as any, false, 404);

    const client = new ResolveSpecClient(config);
    await expect(client.read('public', 'users', 999)).rejects.toThrow('Not found');
  });

  it('throws generic error when no error message', async () => {
    globalThis.fetch = vi.fn().mockResolvedValue({
      ok: false,
      status: 500,
      json: () => Promise.resolve({ success: false, data: null }),
    });

    const client = new ResolveSpecClient(config);
    await expect(client.read('public', 'users')).rejects.toThrow('An error occurred');
  });

  it('config without token omits Authorization header', async () => {
    const noAuthConfig: ClientConfig = { baseUrl: 'http://localhost:3000' };
    const response: APIResponse = { success: true, data: [] };
    globalThis.fetch = mockFetchResponse(response);

    const client = new ResolveSpecClient(noAuthConfig);
    await client.read('public', 'users');
    const opts = (globalThis.fetch as any).mock.calls[0][1];
    expect(opts.headers['Authorization']).toBeUndefined();
  });
});

describe('getResolveSpecClient singleton', () => {
  it('returns same instance for same baseUrl', () => {
    const a = getResolveSpecClient({ baseUrl: 'http://singleton-test:3000' });
    const b = getResolveSpecClient({ baseUrl: 'http://singleton-test:3000' });
    expect(a).toBe(b);
  });

  it('returns different instances for different baseUrls', () => {
    const a = getResolveSpecClient({ baseUrl: 'http://singleton-a:3000' });
    const b = getResolveSpecClient({ baseUrl: 'http://singleton-b:3000' });
    expect(a).not.toBe(b);
  });
});
 336  resolvespec-js/src/__tests__/websocketspec.test.ts  (new file)
@@ -0,0 +1,336 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { WebSocketClient, getWebSocketClient } from '../websocketspec/client';
import type { WebSocketClientConfig } from '../websocketspec/types';

// Mock uuid
vi.mock('uuid', () => ({
  v4: vi.fn(() => 'mock-uuid-1234'),
}));

// Mock WebSocket
class MockWebSocket {
  static OPEN = 1;
  static CLOSED = 3;

  url: string;
  readyState = MockWebSocket.OPEN;
  onopen: ((ev: any) => void) | null = null;
  onclose: ((ev: any) => void) | null = null;
  onmessage: ((ev: any) => void) | null = null;
  onerror: ((ev: any) => void) | null = null;

  private sentMessages: string[] = [];

  constructor(url: string) {
    this.url = url;
    // Simulate async open
    setTimeout(() => {
      this.onopen?.({});
    }, 0);
  }

  send(data: string) {
    this.sentMessages.push(data);
  }

  close() {
    this.readyState = MockWebSocket.CLOSED;
    this.onclose?.({ code: 1000, reason: 'Normal closure' } as any);
  }

  getSentMessages(): any[] {
    return this.sentMessages.map((m) => JSON.parse(m));
  }

  simulateMessage(data: any) {
    this.onmessage?.({ data: JSON.stringify(data) });
  }
}

let mockWsInstance: MockWebSocket | null = null;

beforeEach(() => {
  mockWsInstance = null;
  (globalThis as any).WebSocket = class extends MockWebSocket {
    constructor(url: string) {
      super(url);
      mockWsInstance = this;
    }
  };
  (globalThis as any).WebSocket.OPEN = MockWebSocket.OPEN;
  (globalThis as any).WebSocket.CLOSED = MockWebSocket.CLOSED;
});

afterEach(() => {
  vi.restoreAllMocks();
});

describe('WebSocketClient', () => {
  const wsConfig: WebSocketClientConfig = {
    url: 'ws://localhost:8080',
    reconnect: false,
    heartbeatInterval: 60000,
  };

  it('should connect and set state to connected', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();
    expect(client.getState()).toBe('connected');
    expect(client.isConnected()).toBe(true);
    client.disconnect();
  });

  it('should disconnect and set state to disconnected', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();
    client.disconnect();
    expect(client.getState()).toBe('disconnected');
    expect(client.isConnected()).toBe(false);
  });

  it('should send read request', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    const readPromise = client.read('users', {
      schema: 'public',
      filters: [{ column: 'active', operator: 'eq', value: true }],
      limit: 10,
    });

    // Simulate server response
    const sent = mockWsInstance!.getSentMessages();
    expect(sent.length).toBe(1);
    expect(sent[0].operation).toBe('read');
    expect(sent[0].entity).toBe('users');
    expect(sent[0].options.filters[0].column).toBe('active');

    mockWsInstance!.simulateMessage({
      id: sent[0].id,
      type: 'response',
      success: true,
      data: [{ id: 1 }],
      timestamp: new Date().toISOString(),
    });

    const result = await readPromise;
    expect(result).toEqual([{ id: 1 }]);

    client.disconnect();
  });

  it('should send create request', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    const createPromise = client.create('users', { name: 'Test' }, { schema: 'public' });

    const sent = mockWsInstance!.getSentMessages();
    expect(sent[0].operation).toBe('create');
    expect(sent[0].data.name).toBe('Test');

    mockWsInstance!.simulateMessage({
      id: sent[0].id,
      type: 'response',
      success: true,
      data: { id: 1, name: 'Test' },
      timestamp: new Date().toISOString(),
    });

    const result = await createPromise;
    expect(result.name).toBe('Test');

    client.disconnect();
  });

  it('should send update request with record_id', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    const updatePromise = client.update('users', '1', { name: 'Updated' });

    const sent = mockWsInstance!.getSentMessages();
    expect(sent[0].operation).toBe('update');
    expect(sent[0].record_id).toBe('1');

    mockWsInstance!.simulateMessage({
      id: sent[0].id,
      type: 'response',
      success: true,
      data: { id: 1, name: 'Updated' },
      timestamp: new Date().toISOString(),
    });

    await updatePromise;
    client.disconnect();
  });

  it('should send delete request', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    const deletePromise = client.delete('users', '1');

    const sent = mockWsInstance!.getSentMessages();
    expect(sent[0].operation).toBe('delete');
    expect(sent[0].record_id).toBe('1');

    mockWsInstance!.simulateMessage({
      id: sent[0].id,
      type: 'response',
      success: true,
      timestamp: new Date().toISOString(),
    });

    await deletePromise;
    client.disconnect();
  });

  it('should reject on failed request', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    const readPromise = client.read('users');

    const sent = mockWsInstance!.getSentMessages();
    mockWsInstance!.simulateMessage({
      id: sent[0].id,
      type: 'response',
      success: false,
      error: { code: 'not_found', message: 'Not found' },
      timestamp: new Date().toISOString(),
    });

    await expect(readPromise).rejects.toThrow('Not found');
    client.disconnect();
  });

  it('should handle subscriptions', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    const callback = vi.fn();
    const subPromise = client.subscribe('users', callback, {
      schema: 'public',
    });

    const sent = mockWsInstance!.getSentMessages();
    expect(sent[0].type).toBe('subscription');
    expect(sent[0].operation).toBe('subscribe');

    mockWsInstance!.simulateMessage({
      id: sent[0].id,
      type: 'response',
      success: true,
      data: { subscription_id: 'sub-1' },
      timestamp: new Date().toISOString(),
    });

    const subId = await subPromise;
    expect(subId).toBe('sub-1');
    expect(client.getSubscriptions()).toHaveLength(1);

    // Simulate notification
    mockWsInstance!.simulateMessage({
      type: 'notification',
      operation: 'create',
      subscription_id: 'sub-1',
      entity: 'users',
      data: { id: 2, name: 'New' },
      timestamp: new Date().toISOString(),
    });

    expect(callback).toHaveBeenCalledTimes(1);
    expect(callback.mock.calls[0][0].data.id).toBe(2);

    client.disconnect();
  });

  it('should handle unsubscribe', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    // Subscribe first
    const subPromise = client.subscribe('users', vi.fn());
    let sent = mockWsInstance!.getSentMessages();
    mockWsInstance!.simulateMessage({
      id: sent[0].id,
      type: 'response',
      success: true,
      data: { subscription_id: 'sub-1' },
      timestamp: new Date().toISOString(),
    });
    await subPromise;

    // Unsubscribe
    const unsubPromise = client.unsubscribe('sub-1');
    sent = mockWsInstance!.getSentMessages();
    mockWsInstance!.simulateMessage({
      id: sent[sent.length - 1].id,
      type: 'response',
      success: true,
      timestamp: new Date().toISOString(),
    });

    await unsubPromise;
    expect(client.getSubscriptions()).toHaveLength(0);

    client.disconnect();
  });

  it('should emit events', async () => {
    const client = new WebSocketClient(wsConfig);
    const connectCb = vi.fn();
    const stateChangeCb = vi.fn();

    client.on('connect', connectCb);
    client.on('stateChange', stateChangeCb);

    await client.connect();

    expect(connectCb).toHaveBeenCalledTimes(1);
    expect(stateChangeCb).toHaveBeenCalled();

    client.off('connect');
    client.disconnect();
  });

  it('should reject when sending without connection', async () => {
    const client = new WebSocketClient(wsConfig);
    await expect(client.read('users')).rejects.toThrow('WebSocket is not connected');
  });

  it('should handle pong messages without error', async () => {
    const client = new WebSocketClient(wsConfig);
    await client.connect();

    // Should not throw
    mockWsInstance!.simulateMessage({ type: 'pong' });

    client.disconnect();
  });

  it('should handle malformed messages gracefully', async () => {
    const client = new WebSocketClient({ ...wsConfig, debug: false });
    await client.connect();

    // Simulate non-JSON message
    mockWsInstance!.onmessage?.({ data: 'not-json' } as any);

    client.disconnect();
  });
});

describe('getWebSocketClient singleton', () => {
  it('returns same instance for same url', () => {
    const a = getWebSocketClient({ url: 'ws://ws-singleton:8080' });
|
||||
const b = getWebSocketClient({ url: 'ws://ws-singleton:8080' });
|
||||
expect(a).toBe(b);
|
||||
});
|
||||
|
||||
it('returns different instances for different urls', () => {
|
||||
const a = getWebSocketClient({ url: 'ws://ws-singleton-a:8080' });
|
||||
const b = getWebSocketClient({ url: 'ws://ws-singleton-b:8080' });
|
||||
expect(a).not.toBe(b);
|
||||
});
|
||||
});
|
||||
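The tests above all follow the same request/response correlation pattern: every outbound message carries an `id`, and `simulateMessage` delivering a response with the matching `id` resolves the pending promise. A minimal standalone sketch of that mechanism (names hypothetical, not the exported client):

```typescript
// Sketch of id-based request/response correlation, as exercised by the
// tests above: pending promises are keyed by message id.
const pending = new Map<string, (data: unknown) => void>();

function send(id: string): Promise<unknown> {
  // Register a resolver under the outbound message's id.
  return new Promise((resolve) => pending.set(id, resolve));
}

function onMessage(msg: { id: string; data: unknown }): void {
  // An inbound response resolves (and clears) the matching pending request.
  pending.get(msg.id)?.(msg.data);
  pending.delete(msg.id);
}

const p = send('req-1');
onMessage({ id: 'req-1', data: { subscription_id: 'sub-1' } });
p.then((d) => console.log(d));
```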
@@ -1,132 +0,0 @@
import { ClientConfig, APIResponse, TableMetadata, Options, RequestBody } from "./types";

// Helper functions
const getHeaders = (options?: Record<string, any>): HeadersInit => {
  const headers: HeadersInit = {
    'Content-Type': 'application/json',
  };

  if (options?.token) {
    headers['Authorization'] = `Bearer ${options.token}`;
  }

  return headers;
};

const buildUrl = (config: ClientConfig, schema: string, entity: string, id?: string): string => {
  let url = `${config.baseUrl}/${schema}/${entity}`;
  if (id) {
    url += `/${id}`;
  }
  return url;
};

const fetchWithError = async <T>(url: string, options: RequestInit): Promise<APIResponse<T>> => {
  const response = await fetch(url, options);
  const data = await response.json();

  if (!response.ok) {
    throw new Error(data.error?.message || 'An error occurred');
  }

  return data;
};

// API Functions
export const getMetadata = async (
  config: ClientConfig,
  schema: string,
  entity: string
): Promise<APIResponse<TableMetadata>> => {
  const url = buildUrl(config, schema, entity);
  return fetchWithError<TableMetadata>(url, {
    method: 'GET',
    headers: getHeaders(config),
  });
};

export const read = async <T = any>(
  config: ClientConfig,
  schema: string,
  entity: string,
  id?: string,
  options?: Options
): Promise<APIResponse<T>> => {
  const url = buildUrl(config, schema, entity, id);
  const body: RequestBody = {
    operation: 'read',
    options,
  };

  return fetchWithError<T>(url, {
    method: 'POST',
    headers: getHeaders(config),
    body: JSON.stringify(body),
  });
};

export const create = async <T = any>(
  config: ClientConfig,
  schema: string,
  entity: string,
  data: any | any[],
  options?: Options
): Promise<APIResponse<T>> => {
  const url = buildUrl(config, schema, entity);
  const body: RequestBody = {
    operation: 'create',
    data,
    options,
  };

  return fetchWithError<T>(url, {
    method: 'POST',
    headers: getHeaders(config),
    body: JSON.stringify(body),
  });
};

export const update = async <T = any>(
  config: ClientConfig,
  schema: string,
  entity: string,
  data: any | any[],
  id?: string | string[],
  options?: Options
): Promise<APIResponse<T>> => {
  const url = buildUrl(config, schema, entity, typeof id === 'string' ? id : undefined);
  const body: RequestBody = {
    operation: 'update',
    id: typeof id === 'string' ? undefined : id,
    data,
    options,
  };

  return fetchWithError<T>(url, {
    method: 'POST',
    headers: getHeaders(config),
    body: JSON.stringify(body),
  });
};

export const deleteEntity = async (
  config: ClientConfig,
  schema: string,
  entity: string,
  id: string
): Promise<APIResponse<void>> => {
  const url = buildUrl(config, schema, entity, id);
  const body: RequestBody = {
    operation: 'delete',
  };

  return fetchWithError<void>(url, {
    method: 'POST',
    headers: getHeaders(config),
    body: JSON.stringify(body),
  });
};
1 resolvespec-js/src/common/index.ts Normal file
@@ -0,0 +1 @@
export * from './types';
129 resolvespec-js/src/common/types.ts Normal file
@@ -0,0 +1,129 @@
// Types aligned with Go pkg/common/types.go

export type Operator =
  | 'eq' | 'neq' | 'gt' | 'gte' | 'lt' | 'lte'
  | 'like' | 'ilike' | 'in'
  | 'contains' | 'startswith' | 'endswith'
  | 'between' | 'between_inclusive'
  | 'is_null' | 'is_not_null';

export type Operation = 'read' | 'create' | 'update' | 'delete';
export type SortDirection = 'asc' | 'desc' | 'ASC' | 'DESC';

export interface Parameter {
  name: string;
  value: string;
  sequence?: number;
}

export interface PreloadOption {
  relation: string;
  table_name?: string;
  columns?: string[];
  omit_columns?: string[];
  sort?: SortOption[];
  filters?: FilterOption[];
  where?: string;
  limit?: number;
  offset?: number;
  updatable?: boolean;
  computed_ql?: Record<string, string>;
  recursive?: boolean;
  // Relationship keys
  primary_key?: string;
  related_key?: string;
  foreign_key?: string;
  recursive_child_key?: string;
  // Custom SQL JOINs
  sql_joins?: string[];
  join_aliases?: string[];
}

export interface FilterOption {
  column: string;
  operator: Operator | string;
  value: any;
  logic_operator?: 'AND' | 'OR';
}

export interface SortOption {
  column: string;
  direction: SortDirection;
}

export interface CustomOperator {
  name: string;
  sql: string;
}

export interface ComputedColumn {
  name: string;
  expression: string;
}

export interface Options {
  preload?: PreloadOption[];
  columns?: string[];
  omit_columns?: string[];
  filters?: FilterOption[];
  sort?: SortOption[];
  limit?: number;
  offset?: number;
  customOperators?: CustomOperator[];
  computedColumns?: ComputedColumn[];
  parameters?: Parameter[];
  cursor_forward?: string;
  cursor_backward?: string;
  fetch_row_number?: string;
}

export interface RequestBody {
  operation: Operation;
  id?: number | string | string[];
  data?: any | any[];
  options?: Options;
}

export interface Metadata {
  total: number;
  count: number;
  filtered: number;
  limit: number;
  offset: number;
  row_number?: number;
}

export interface APIError {
  code: string;
  message: string;
  details?: any;
  detail?: string;
}

export interface APIResponse<T = any> {
  success: boolean;
  data: T;
  metadata?: Metadata;
  error?: APIError;
}

export interface Column {
  name: string;
  type: string;
  is_nullable: boolean;
  is_primary: boolean;
  is_unique: boolean;
  has_index: boolean;
}

export interface TableMetadata {
  schema: string;
  table: string;
  columns: Column[];
  relations: string[];
}

export interface ClientConfig {
  baseUrl: string;
  token?: string;
}
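Each `FilterOption` carries its own `logic_operator`, so a filter list is folded left-to-right rather than grouped. A hypothetical standalone sketch of that folding (for illustration only; the real SQL generation lives in the Go handlers):

```typescript
// Hypothetical sketch: folding FilterOption-like objects into a readable
// condition string, applying each filter's own logic_operator (default AND).
type Filter = {
  column: string;
  operator: string;
  value: unknown;
  logic_operator?: 'AND' | 'OR';
};

function describeFilters(filters: Filter[]): string {
  return filters
    .map((f, i) => {
      // The first filter has no joining operator; later ones use their own.
      const prefix = i === 0 ? '' : `${f.logic_operator ?? 'AND'} `;
      return `${prefix}${f.column} ${f.operator} ${String(f.value)}`;
    })
    .join(' ');
}

console.log(describeFilters([
  { column: 'status', operator: 'eq', value: 'active' },
  { column: 'role', operator: 'eq', value: 'admin', logic_operator: 'OR' },
]));
// → status eq active OR role eq admin
```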
@@ -1,68 +0,0 @@
import { getMetadata, read, create, update, deleteEntity } from "./api";
import { ClientConfig } from "./types";

// Usage Examples
const config: ClientConfig = {
  baseUrl: 'http://api.example.com/v1',
  token: 'your-token-here'
};

// Example usage
const examples = async () => {
  // Get metadata
  const metadata = await getMetadata(config, 'test', 'employees');

  // Read with relations
  const employees = await read(config, 'test', 'employees', undefined, {
    preload: [
      {
        relation: 'department',
        columns: ['id', 'name']
      }
    ],
    filters: [
      {
        column: 'status',
        operator: 'eq',
        value: 'active'
      }
    ]
  });

  // Create single record
  const newEmployee = await create(config, 'test', 'employees', {
    first_name: 'John',
    last_name: 'Doe',
    email: 'john@example.com'
  });

  // Bulk create
  const newEmployees = await create(config, 'test', 'employees', [
    {
      first_name: 'Jane',
      last_name: 'Smith',
      email: 'jane@example.com'
    },
    {
      first_name: 'Bob',
      last_name: 'Johnson',
      email: 'bob@example.com'
    }
  ]);

  // Update single record
  const updatedEmployee = await update(config, 'test', 'employees',
    { status: 'inactive' },
    'emp123'
  );

  // Bulk update
  const updatedEmployees = await update(config, 'test', 'employees',
    { department_id: 'dept2' },
    ['emp1', 'emp2', 'emp3']
  );

  // Delete
  await deleteEntity(config, 'test', 'employees', 'emp123');
};
345 resolvespec-js/src/headerspec/client.ts Normal file
@@ -0,0 +1,345 @@
import type {
  APIResponse,
  ClientConfig,
  CustomOperator,
  FilterOption,
  Options,
  PreloadOption,
  SortOption,
} from "../common/types";

/**
 * Encode a value with base64 and ZIP_ prefix for complex header values.
 */
export function encodeHeaderValue(value: string): string {
  if (typeof btoa === "function") {
    return "ZIP_" + btoa(value);
  }
  return "ZIP_" + Buffer.from(value, "utf-8").toString("base64");
}

/**
 * Decode a header value that may be base64 encoded with ZIP_ or __ prefix.
 */
export function decodeHeaderValue(value: string): string {
  let code = value;

  if (code.startsWith("ZIP_")) {
    code = code.slice(4).replace(/[\n\r ]/g, "");
    code = decodeBase64(code);
  } else if (code.startsWith("__")) {
    code = code.slice(2).replace(/[\n\r ]/g, "");
    code = decodeBase64(code);
  }

  // Handle nested encoding
  if (code.startsWith("ZIP_") || code.startsWith("__")) {
    code = decodeHeaderValue(code);
  }

  return code;
}

function decodeBase64(str: string): string {
  if (typeof atob === "function") {
    return atob(str);
  }
  return Buffer.from(str, "base64").toString("utf-8");
}

/**
 * Build HTTP headers from Options, matching Go's restheadspec handler conventions.
 *
 * Header mapping:
 * - X-Select-Fields: comma-separated columns
 * - X-Not-Select-Fields: comma-separated omit_columns
 * - X-FieldFilter-{col}: exact match (eq)
 * - X-SearchOp-{operator}-{col}: AND filter
 * - X-SearchOr-{operator}-{col}: OR filter
 * - X-Sort: +col (asc), -col (desc)
 * - X-Limit, X-Offset: pagination
 * - X-Cursor-Forward, X-Cursor-Backward: cursor pagination
 * - X-Preload: RelationName:field1,field2 pipe-separated
 * - X-Fetch-RowNumber: row number fetch
 * - X-CQL-SEL-{col}: computed columns
 * - X-Custom-SQL-W: custom operators (AND)
 */
export function buildHeaders(options: Options): Record<string, string> {
  const headers: Record<string, string> = {};

  // Column selection
  if (options.columns?.length) {
    headers["X-Select-Fields"] = options.columns.join(",");
  }

  if (options.omit_columns?.length) {
    headers["X-Not-Select-Fields"] = options.omit_columns.join(",");
  }

  // Filters
  if (options.filters?.length) {
    for (const filter of options.filters) {
      const logicOp = filter.logic_operator ?? "AND";
      const op = mapOperatorToHeaderOp(filter.operator);
      const valueStr = formatFilterValue(filter);

      if (filter.operator === "eq" && logicOp === "AND") {
        // Simple field filter shorthand
        headers[`X-FieldFilter-${filter.column}`] = valueStr;
      } else if (logicOp === "OR") {
        headers[`X-SearchOr-${op}-${filter.column}`] = valueStr;
      } else {
        headers[`X-SearchOp-${op}-${filter.column}`] = valueStr;
      }
    }
  }

  // Sort
  if (options.sort?.length) {
    const sortParts = options.sort.map((s: SortOption) => {
      const dir = s.direction.toUpperCase();
      return dir === "DESC" ? `-${s.column}` : `+${s.column}`;
    });
    headers["X-Sort"] = sortParts.join(",");
  }

  // Pagination
  if (options.limit !== undefined) {
    headers["X-Limit"] = String(options.limit);
  }
  if (options.offset !== undefined) {
    headers["X-Offset"] = String(options.offset);
  }

  // Cursor pagination
  if (options.cursor_forward) {
    headers["X-Cursor-Forward"] = options.cursor_forward;
  }
  if (options.cursor_backward) {
    headers["X-Cursor-Backward"] = options.cursor_backward;
  }

  // Preload
  if (options.preload?.length) {
    const parts = options.preload.map((p: PreloadOption) => {
      if (p.columns?.length) {
        return `${p.relation}:${p.columns.join(",")}`;
      }
      return p.relation;
    });
    headers["X-Preload"] = parts.join("|");
  }

  // Fetch row number
  if (options.fetch_row_number) {
    headers["X-Fetch-RowNumber"] = options.fetch_row_number;
  }

  // Computed columns
  if (options.computedColumns?.length) {
    for (const cc of options.computedColumns) {
      headers[`X-CQL-SEL-${cc.name}`] = cc.expression;
    }
  }

  // Custom operators -> X-Custom-SQL-W
  if (options.customOperators?.length) {
    const sqlParts = options.customOperators.map(
      (co: CustomOperator) => co.sql,
    );
    headers["X-Custom-SQL-W"] = sqlParts.join(" AND ");
  }

  return headers;
}

function mapOperatorToHeaderOp(operator: string): string {
  switch (operator) {
    case "eq":
      return "equals";
    case "neq":
      return "notequals";
    case "gt":
      return "greaterthan";
    case "gte":
      return "greaterthanorequal";
    case "lt":
      return "lessthan";
    case "lte":
      return "lessthanorequal";
    case "like":
    case "ilike":
    case "contains":
      return "contains";
    case "startswith":
      return "beginswith";
    case "endswith":
      return "endswith";
    case "in":
      return "in";
    case "between":
      return "between";
    case "between_inclusive":
      return "betweeninclusive";
    case "is_null":
      return "empty";
    case "is_not_null":
      return "notempty";
    default:
      return operator;
  }
}

function formatFilterValue(filter: FilterOption): string {
  if (filter.value === null || filter.value === undefined) {
    return "";
  }
  if (Array.isArray(filter.value)) {
    return filter.value.join(",");
  }
  return String(filter.value);
}

const instances = new Map<string, HeaderSpecClient>();

export function getHeaderSpecClient(config: ClientConfig): HeaderSpecClient {
  const key = config.baseUrl;
  let instance = instances.get(key);
  if (!instance) {
    instance = new HeaderSpecClient(config);
    instances.set(key, instance);
  }
  return instance;
}

/**
 * HeaderSpec REST client.
 * Sends query options via HTTP headers instead of request body, matching the Go restheadspec handler.
 *
 * HTTP methods: GET=read, POST=create, PUT=update, DELETE=delete
 */
export class HeaderSpecClient {
  private config: ClientConfig;

  constructor(config: ClientConfig) {
    this.config = config;
  }

  private buildUrl(schema: string, entity: string, id?: string): string {
    let url = `${this.config.baseUrl}/${schema}/${entity}`;
    if (id) {
      url += `/${id}`;
    }
    return url;
  }

  private baseHeaders(): Record<string, string> {
    const headers: Record<string, string> = {
      "Content-Type": "application/json",
    };
    if (this.config.token) {
      headers["Authorization"] = `Bearer ${this.config.token}`;
    }
    return headers;
  }

  private async fetchWithError<T>(
    url: string,
    init: RequestInit,
  ): Promise<APIResponse<T>> {
    const response = await fetch(url, init);
    const data = await response.json();

    if (!response.ok) {
      throw new Error(
        data.error?.message || `${response.statusText} (${response.status})`,
      );
    }

    // Totals come from the Content-Range header ("start-end/total").
    const contentRange = response.headers.get("content-range");
    const total = contentRange ? Number(contentRange.split("/")[1]) : 0;

    return {
      data,
      success: true,
      error: data.error ? data.error : undefined,
      metadata: {
        count: total,
        total,
        filtered: total,
        offset: contentRange
          ? Number(contentRange.split("/")[0].split("-")[0])
          : 0,
        limit: response.headers.get("x-limit")
          ? Number(response.headers.get("x-limit"))
          : 0,
      },
    };
  }

  async read<T = any>(
    schema: string,
    entity: string,
    id?: string,
    options?: Options,
  ): Promise<APIResponse<T>> {
    const url = this.buildUrl(schema, entity, id);
    const optHeaders = options ? buildHeaders(options) : {};
    return this.fetchWithError<T>(url, {
      method: "GET",
      headers: { ...this.baseHeaders(), ...optHeaders },
    });
  }

  async create<T = any>(
    schema: string,
    entity: string,
    data: any,
    options?: Options,
  ): Promise<APIResponse<T>> {
    const url = this.buildUrl(schema, entity);
    const optHeaders = options ? buildHeaders(options) : {};
    return this.fetchWithError<T>(url, {
      method: "POST",
      headers: { ...this.baseHeaders(), ...optHeaders },
      body: JSON.stringify(data),
    });
  }

  async update<T = any>(
    schema: string,
    entity: string,
    id: string,
    data: any,
    options?: Options,
  ): Promise<APIResponse<T>> {
    const url = this.buildUrl(schema, entity, id);
    const optHeaders = options ? buildHeaders(options) : {};
    return this.fetchWithError<T>(url, {
      method: "PUT",
      headers: { ...this.baseHeaders(), ...optHeaders },
      body: JSON.stringify(data),
    });
  }

  async delete(
    schema: string,
    entity: string,
    id: string,
  ): Promise<APIResponse<void>> {
    const url = this.buildUrl(schema, entity, id);
    return this.fetchWithError<void>(url, {
      method: "DELETE",
      headers: this.baseHeaders(),
    });
  }
}
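The X-Sort convention described in the buildHeaders docstring ("+col" ascending, "-col" descending, comma-joined) can be sketched standalone. This is a re-implementation for illustration, not the exported function:

```typescript
// Standalone sketch of the X-Sort header encoding: "+col" for ascending,
// "-col" for descending, entries joined with commas.
type SortOption = {
  column: string;
  direction: 'asc' | 'desc' | 'ASC' | 'DESC';
};

function encodeSortHeader(sort: SortOption[]): string {
  return sort
    .map((s) =>
      s.direction.toUpperCase() === 'DESC' ? `-${s.column}` : `+${s.column}`,
    )
    .join(',');
}

console.log(encodeSortHeader([
  { column: 'name', direction: 'asc' },
  { column: 'created_at', direction: 'DESC' },
]));
// → +name,-created_at
```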
7 resolvespec-js/src/headerspec/index.ts Normal file
@@ -0,0 +1,7 @@
export {
  HeaderSpecClient,
  getHeaderSpecClient,
  buildHeaders,
  encodeHeaderValue,
  decodeHeaderValue,
} from './client';
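The ZIP_ prefix convention used by encodeHeaderValue/decodeHeaderValue is a simple base64 round trip. A simplified Node-only sketch (Buffer assumed, nested-prefix handling omitted):

```typescript
// Simplified sketch of the ZIP_ header-value round trip: base64-encode the
// value and mark it with a "ZIP_" prefix so the server knows to decode it.
function encodeZip(value: string): string {
  return 'ZIP_' + Buffer.from(value, 'utf-8').toString('base64');
}

function decodeZip(value: string): string {
  if (value.startsWith('ZIP_')) {
    return Buffer.from(value.slice(4), 'base64').toString('utf-8');
  }
  // Unprefixed values pass through unchanged.
  return value;
}

const encoded = encodeZip('status = active');
console.log(decodeZip(encoded));
// → status = active
```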
@@ -1,7 +1,11 @@
-// Types
-export * from './types';
-export * from './websocket-types';
+// Common types
+export * from './common';
 
-// WebSocket Client
-export { WebSocketClient } from './websocket-client';
-export type { WebSocketClient as default } from './websocket-client';
+// REST client (ResolveSpec)
+export * from './resolvespec';
+
+// WebSocket client
+export * from './websocketspec';
+
+// HeaderSpec client
+export * from './headerspec';
141 resolvespec-js/src/resolvespec/client.ts Normal file
@@ -0,0 +1,141 @@
import type { ClientConfig, APIResponse, TableMetadata, Options, RequestBody } from '../common/types';

const instances = new Map<string, ResolveSpecClient>();

export function getResolveSpecClient(config: ClientConfig): ResolveSpecClient {
  const key = config.baseUrl;
  let instance = instances.get(key);
  if (!instance) {
    instance = new ResolveSpecClient(config);
    instances.set(key, instance);
  }
  return instance;
}

export class ResolveSpecClient {
  private config: ClientConfig;

  constructor(config: ClientConfig) {
    this.config = config;
  }

  private buildUrl(schema: string, entity: string, id?: string): string {
    let url = `${this.config.baseUrl}/${schema}/${entity}`;
    if (id) {
      url += `/${id}`;
    }
    return url;
  }

  private baseHeaders(): HeadersInit {
    const headers: Record<string, string> = {
      'Content-Type': 'application/json',
    };

    if (this.config.token) {
      headers['Authorization'] = `Bearer ${this.config.token}`;
    }

    return headers;
  }

  private async fetchWithError<T>(url: string, options: RequestInit): Promise<APIResponse<T>> {
    const response = await fetch(url, options);
    const data = await response.json();

    if (!response.ok) {
      throw new Error(data.error?.message || 'An error occurred');
    }

    return data;
  }

  async getMetadata(schema: string, entity: string): Promise<APIResponse<TableMetadata>> {
    const url = this.buildUrl(schema, entity);
    return this.fetchWithError<TableMetadata>(url, {
      method: 'GET',
      headers: this.baseHeaders(),
    });
  }

  async read<T = any>(
    schema: string,
    entity: string,
    id?: number | string | string[],
    options?: Options
  ): Promise<APIResponse<T>> {
    const urlId = typeof id === 'number' || typeof id === 'string' ? String(id) : undefined;
    const url = this.buildUrl(schema, entity, urlId);
    const body: RequestBody = {
      operation: 'read',
      id: Array.isArray(id) ? id : undefined,
      options,
    };

    return this.fetchWithError<T>(url, {
      method: 'POST',
      headers: this.baseHeaders(),
      body: JSON.stringify(body),
    });
  }

  async create<T = any>(
    schema: string,
    entity: string,
    data: any | any[],
    options?: Options
  ): Promise<APIResponse<T>> {
    const url = this.buildUrl(schema, entity);
    const body: RequestBody = {
      operation: 'create',
      data,
      options,
    };

    return this.fetchWithError<T>(url, {
      method: 'POST',
      headers: this.baseHeaders(),
      body: JSON.stringify(body),
    });
  }

  async update<T = any>(
    schema: string,
    entity: string,
    data: any | any[],
    id?: number | string | string[],
    options?: Options
  ): Promise<APIResponse<T>> {
    const urlId = typeof id === 'number' || typeof id === 'string' ? String(id) : undefined;
    const url = this.buildUrl(schema, entity, urlId);
    const body: RequestBody = {
      operation: 'update',
      id: Array.isArray(id) ? id : undefined,
      data,
      options,
    };

    return this.fetchWithError<T>(url, {
      method: 'POST',
      headers: this.baseHeaders(),
      body: JSON.stringify(body),
    });
  }

  async delete(
    schema: string,
    entity: string,
    id: number | string
  ): Promise<APIResponse<void>> {
    const url = this.buildUrl(schema, entity, String(id));
    const body: RequestBody = {
      operation: 'delete',
    };

    return this.fetchWithError<void>(url, {
      method: 'POST',
      headers: this.baseHeaders(),
      body: JSON.stringify(body),
    });
  }
}
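The ResolveSpec client puts scalar ids in the URL and array ids in the POST body. A sketch of the body shape a bulk read would send (values are illustrative):

```typescript
// Sketch of the JSON body a bulk read posts: array ids travel in the body
// (scalar ids would go in the URL path instead), options alongside.
const body = {
  operation: 'read' as const,
  id: ['emp1', 'emp2'],
  options: { limit: 10, offset: 0 },
};

console.log(JSON.stringify(body));
// → {"operation":"read","id":["emp1","emp2"],"options":{"limit":10,"offset":0}}
```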
1 resolvespec-js/src/resolvespec/index.ts Normal file
@@ -0,0 +1 @@
export { ResolveSpecClient, getResolveSpecClient } from './client';
@@ -1,86 +0,0 @@
// Types
export type Operator = 'eq' | 'neq' | 'gt' | 'gte' | 'lt' | 'lte' | 'like' | 'ilike' | 'in';
export type Operation = 'read' | 'create' | 'update' | 'delete';
export type SortDirection = 'asc' | 'desc';

export interface PreloadOption {
  relation: string;
  columns?: string[];
  filters?: FilterOption[];
}

export interface FilterOption {
  column: string;
  operator: Operator;
  value: any;
}

export interface SortOption {
  column: string;
  direction: SortDirection;
}

export interface CustomOperator {
  name: string;
  sql: string;
}

export interface ComputedColumn {
  name: string;
  expression: string;
}

export interface Options {
  preload?: PreloadOption[];
  columns?: string[];
  filters?: FilterOption[];
  sort?: SortOption[];
  limit?: number;
  offset?: number;
  customOperators?: CustomOperator[];
  computedColumns?: ComputedColumn[];
}

export interface RequestBody {
  operation: Operation;
  id?: string | string[];
  data?: any | any[];
  options?: Options;
}

export interface APIResponse<T = any> {
  success: boolean;
  data: T;
  metadata?: {
    total: number;
    filtered: number;
    limit: number;
    offset: number;
  };
  error?: {
    code: string;
    message: string;
    details?: any;
  };
}

export interface Column {
  name: string;
  type: string;
  is_nullable: boolean;
  is_primary: boolean;
  is_unique: boolean;
  has_index: boolean;
}

export interface TableMetadata {
  schema: string;
  table: string;
  columns: Column[];
  relations: string[];
}

export interface ClientConfig {
  baseUrl: string;
  token?: string;
}
@@ -1,427 +0,0 @@
|
||||
import { WebSocketClient } from './websocket-client';
|
||||
import type { WSNotificationMessage } from './websocket-types';
|
||||
|
||||
/**
|
||||
* Example 1: Basic Usage
|
||||
*/
|
||||
export async function basicUsageExample() {
|
||||
// Create client
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
debug: true
|
||||
});
|
||||
|
||||
// Connect
|
||||
await client.connect();
|
||||
|
||||
// Read users
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
limit: 10,
|
||||
sort: [
|
||||
{ column: 'name', direction: 'asc' }
|
||||
]
|
||||
});
|
||||
|
||||
console.log('Users:', users);
|
||||
|
||||
// Create a user
|
||||
const newUser = await client.create('users', {
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
status: 'active'
|
||||
}, { schema: 'public' });
|
||||
|
||||
console.log('Created user:', newUser);
|
||||
|
||||
// Update user
|
||||
const updatedUser = await client.update('users', '123', {
|
||||
name: 'John Updated'
|
||||
}, { schema: 'public' });
|
||||
|
||||
console.log('Updated user:', updatedUser);
|
||||
|
||||
// Delete user
|
||||
await client.delete('users', '123', { schema: 'public' });
|
||||
|
||||
// Disconnect
|
||||
client.disconnect();
|
||||
}
|
||||
|
||||
/**
 * Example 2: Real-time Subscriptions
 */
export async function subscriptionExample() {
  const client = new WebSocketClient({
    url: 'ws://localhost:8080/ws',
    debug: true
  });

  await client.connect();

  // Subscribe to user changes
  const subscriptionId = await client.subscribe(
    'users',
    (notification: WSNotificationMessage) => {
      console.log('User changed:', notification.operation, notification.data);

      switch (notification.operation) {
        case 'create':
          console.log('New user created:', notification.data);
          break;
        case 'update':
          console.log('User updated:', notification.data);
          break;
        case 'delete':
          console.log('User deleted:', notification.data);
          break;
      }
    },
    {
      schema: 'public',
      filters: [
        { column: 'status', operator: 'eq', value: 'active' }
      ]
    }
  );

  console.log('Subscribed with ID:', subscriptionId);

  // Later: unsubscribe
  setTimeout(async () => {
    await client.unsubscribe(subscriptionId);
    console.log('Unsubscribed');
    client.disconnect();
  }, 60000);
}

/**
 * Example 3: Event Handling
 */
export async function eventHandlingExample() {
  const client = new WebSocketClient({
    url: 'ws://localhost:8080/ws'
  });

  // Listen to connection events
  client.on('connect', () => {
    console.log('Connected!');
  });

  client.on('disconnect', (event) => {
    console.log('Disconnected:', event.code, event.reason);
  });

  client.on('error', (error) => {
    console.error('WebSocket error:', error);
  });

  client.on('stateChange', (state) => {
    console.log('State changed to:', state);
  });

  client.on('message', (message) => {
    console.log('Received message:', message);
  });

  await client.connect();

  // Your operations here...
}

/**
 * Example 4: Multiple Subscriptions
 */
export async function multipleSubscriptionsExample() {
  const client = new WebSocketClient({
    url: 'ws://localhost:8080/ws',
    debug: true
  });

  await client.connect();

  // Subscribe to users
  const userSubId = await client.subscribe(
    'users',
    (notification) => {
      console.log('[Users]', notification.operation, notification.data);
    },
    { schema: 'public' }
  );

  // Subscribe to posts
  const postSubId = await client.subscribe(
    'posts',
    (notification) => {
      console.log('[Posts]', notification.operation, notification.data);
    },
    {
      schema: 'public',
      filters: [
        { column: 'status', operator: 'eq', value: 'published' }
      ]
    }
  );

  // Subscribe to comments
  const commentSubId = await client.subscribe(
    'comments',
    (notification) => {
      console.log('[Comments]', notification.operation, notification.data);
    },
    { schema: 'public' }
  );

  console.log('Active subscriptions:', client.getSubscriptions());

  // Clean up after 60 seconds
  setTimeout(async () => {
    await client.unsubscribe(userSubId);
    await client.unsubscribe(postSubId);
    await client.unsubscribe(commentSubId);
    client.disconnect();
  }, 60000);
}

/**
 * Example 5: Advanced Queries
 */
export async function advancedQueriesExample() {
  const client = new WebSocketClient({
    url: 'ws://localhost:8080/ws'
  });

  await client.connect();

  // Complex query with filters, sorting, pagination, and preloading
  const posts = await client.read('posts', {
    schema: 'public',
    filters: [
      { column: 'status', operator: 'eq', value: 'published' },
      { column: 'views', operator: 'gte', value: 100 }
    ],
    columns: ['id', 'title', 'content', 'user_id', 'created_at'],
    sort: [
      { column: 'created_at', direction: 'desc' },
      { column: 'views', direction: 'desc' }
    ],
    preload: [
      {
        relation: 'user',
        columns: ['id', 'name', 'email']
      },
      {
        relation: 'comments',
        columns: ['id', 'content', 'user_id'],
        filters: [
          { column: 'status', operator: 'eq', value: 'approved' }
        ]
      }
    ],
    limit: 20,
    offset: 0
  });

  console.log('Posts:', posts);

  // Get single record by ID
  const post = await client.read('posts', {
    schema: 'public',
    record_id: '123'
  });

  console.log('Single post:', post);

  client.disconnect();
}

/**
 * Example 6: Error Handling
 */
export async function errorHandlingExample() {
  const client = new WebSocketClient({
    url: 'ws://localhost:8080/ws',
    reconnect: true,
    maxReconnectAttempts: 5
  });

  client.on('error', (error) => {
    console.error('Connection error:', error);
  });

  client.on('stateChange', (state) => {
    console.log('Connection state:', state);
  });

  try {
    await client.connect();

    try {
      // Try to read non-existent entity
      await client.read('nonexistent', { schema: 'public' });
    } catch (error) {
      console.error('Read error:', error);
    }

    try {
      // Try to create invalid record
      await client.create('users', {
        // Missing required fields
      }, { schema: 'public' });
    } catch (error) {
      console.error('Create error:', error);
    }

  } catch (error) {
    console.error('Connection failed:', error);
  } finally {
    client.disconnect();
  }
}

/**
 * Example 7: React Integration
 */
export function reactIntegrationExample() {
  const exampleCode = `
import { useEffect, useState } from 'react';
import { WebSocketClient } from '@warkypublic/resolvespec-js';

export function useWebSocket(url: string) {
  const [client] = useState(() => new WebSocketClient({ url }));
  const [isConnected, setIsConnected] = useState(false);

  useEffect(() => {
    client.on('connect', () => setIsConnected(true));
    client.on('disconnect', () => setIsConnected(false));

    client.connect();

    return () => {
      client.disconnect();
    };
  }, [client]);

  return { client, isConnected };
}

export function UsersComponent() {
  const { client, isConnected } = useWebSocket('ws://localhost:8080/ws');
  const [users, setUsers] = useState([]);

  useEffect(() => {
    if (!isConnected) return;

    // Subscribe to user changes
    const subscribeToUsers = async () => {
      const subId = await client.subscribe('users', (notification) => {
        if (notification.operation === 'create') {
          setUsers(prev => [...prev, notification.data]);
        } else if (notification.operation === 'update') {
          setUsers(prev => prev.map(u =>
            u.id === notification.data.id ? notification.data : u
          ));
        } else if (notification.operation === 'delete') {
          setUsers(prev => prev.filter(u => u.id !== notification.data.id));
        }
      }, { schema: 'public' });

      // Load initial users
      const initialUsers = await client.read('users', {
        schema: 'public',
        filters: [{ column: 'status', operator: 'eq', value: 'active' }]
      });
      setUsers(initialUsers);

      return () => client.unsubscribe(subId);
    };

    subscribeToUsers();
  }, [client, isConnected]);

  const createUser = async (name: string, email: string) => {
    await client.create('users', { name, email, status: 'active' }, {
      schema: 'public'
    });
  };

  return (
    <div>
      <h2>Users ({users.length})</h2>
      {isConnected ? '🟢 Connected' : '🔴 Disconnected'}
      {/* Render users... */}
    </div>
  );
}
`;

  console.log(exampleCode);
}

/**
 * Example 8: TypeScript with Typed Models
 */
export async function typedModelsExample() {
  // Define your models
  interface User {
    id: number;
    name: string;
    email: string;
    status: 'active' | 'inactive';
    created_at: string;
  }

  interface Post {
    id: number;
    title: string;
    content: string;
    user_id: number;
    status: 'draft' | 'published';
    views: number;
    user?: User;
  }

  const client = new WebSocketClient({
    url: 'ws://localhost:8080/ws'
  });

  await client.connect();

  // Type-safe operations
  const users = await client.read<User[]>('users', {
    schema: 'public',
    filters: [{ column: 'status', operator: 'eq', value: 'active' }]
  });

  const newUser = await client.create<User>('users', {
    name: 'Alice',
    email: 'alice@example.com',
    status: 'active'
  }, { schema: 'public' });

  const posts = await client.read<Post[]>('posts', {
    schema: 'public',
    preload: [
      {
        relation: 'user',
        columns: ['id', 'name', 'email']
      }
    ]
  });

  // Type-safe subscriptions
  await client.subscribe(
    'users',
    (notification) => {
      const user = notification.data as User;
      console.log('User changed:', user.name, user.email);
    },
    { schema: 'public' }
  );

  client.disconnect();
}
@@ -8,10 +8,22 @@ import type {
  WSOperation,
  WSOptions,
  Subscription,
  SubscriptionOptions,
  ConnectionState,
  WebSocketClientEvents
} from './websocket-types';
} from './types';
import type { FilterOption, SortOption, PreloadOption } from '../common/types';

const instances = new Map<string, WebSocketClient>();

export function getWebSocketClient(config: WebSocketClientConfig): WebSocketClient {
  const key = config.url;
  let instance = instances.get(key);
  if (!instance) {
    instance = new WebSocketClient(config);
    instances.set(key, instance);
  }
  return instance;
}

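The getWebSocketClient factory caches one client instance per URL, so repeated calls with the same URL share a single connection. A minimal sketch of that per-URL caching behaviour, using a hypothetical FakeClient in place of WebSocketClient so it runs without a server:

```typescript
// Sketch of the per-URL singleton cache used by getWebSocketClient.
// FakeClient is a stand-in for WebSocketClient so no server is needed.
class FakeClient {
  constructor(public readonly url: string) {}
}

const instances = new Map<string, FakeClient>();

function getClient(url: string): FakeClient {
  let instance = instances.get(url);
  if (!instance) {
    // First request for this URL: create and cache the client
    instance = new FakeClient(url);
    instances.set(url, instance);
  }
  return instance;
}

const a = getClient('ws://localhost:8080/ws');
const b = getClient('ws://localhost:8080/ws');
const c = getClient('ws://other:9090/ws');
console.log(a === b, a === c); // prints "true false"
```

Same URL returns the cached instance; a different URL gets its own client.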
export class WebSocketClient {
  private ws: WebSocket | null = null;
@@ -36,9 +48,6 @@ export class WebSocketClient {
    };
  }

  /**
   * Connect to WebSocket server
   */
  async connect(): Promise<void> {
    if (this.ws?.readyState === WebSocket.OPEN) {
      this.log('Already connected');
@@ -78,7 +87,6 @@ export class WebSocketClient {
      this.setState('disconnected');
      this.emit('disconnect', event);

      // Attempt reconnection if enabled and not manually closed
      if (this.config.reconnect && !this.isManualClose && this.reconnectAttempts < this.config.maxReconnectAttempts) {
        this.reconnectAttempts++;
        this.log(`Reconnection attempt ${this.reconnectAttempts}/${this.config.maxReconnectAttempts}`);
@@ -97,9 +105,6 @@ export class WebSocketClient {
    });
  }

  /**
   * Disconnect from WebSocket server
   */
  disconnect(): void {
    this.isManualClose = true;

@@ -120,9 +125,6 @@ export class WebSocketClient {
    this.messageHandlers.clear();
  }

  /**
   * Send a CRUD request and wait for response
   */
  async request<T = any>(
    operation: WSOperation,
    entity: string,
@@ -148,7 +150,6 @@ export class WebSocketClient {
    };

    return new Promise((resolve, reject) => {
      // Set up response handler
      this.messageHandlers.set(id, (response: WSResponseMessage) => {
        if (response.success) {
          resolve(response.data);
@@ -157,10 +158,8 @@ export class WebSocketClient {
        }
      });

      // Send message
      this.send(message);

      // Timeout after 30 seconds
      setTimeout(() => {
        if (this.messageHandlers.has(id)) {
          this.messageHandlers.delete(id);
@@ -170,16 +169,13 @@ export class WebSocketClient {
    });
  }

  /**
   * Read records
   */
  async read<T = any>(entity: string, options?: {
    schema?: string;
    record_id?: string;
    filters?: import('./types').FilterOption[];
    filters?: FilterOption[];
    columns?: string[];
    sort?: import('./types').SortOption[];
    preload?: import('./types').PreloadOption[];
    sort?: SortOption[];
    preload?: PreloadOption[];
    limit?: number;
    offset?: number;
  }): Promise<T> {
@@ -197,9 +193,6 @@ export class WebSocketClient {
    });
  }

  /**
   * Create a record
   */
  async create<T = any>(entity: string, data: any, options?: {
    schema?: string;
  }): Promise<T> {
@@ -209,9 +202,6 @@ export class WebSocketClient {
    });
  }

  /**
   * Update a record
   */
  async update<T = any>(entity: string, id: string, data: any, options?: {
    schema?: string;
  }): Promise<T> {
@@ -222,9 +212,6 @@ export class WebSocketClient {
    });
  }

  /**
   * Delete a record
   */
  async delete(entity: string, id: string, options?: {
    schema?: string;
  }): Promise<void> {
@@ -234,9 +221,6 @@ export class WebSocketClient {
    });
  }

  /**
   * Get metadata for an entity
   */
  async meta<T = any>(entity: string, options?: {
    schema?: string;
  }): Promise<T> {
@@ -245,15 +229,12 @@ export class WebSocketClient {
    });
  }

  /**
   * Subscribe to entity changes
   */
  async subscribe(
    entity: string,
    callback: (notification: WSNotificationMessage) => void,
    options?: {
      schema?: string;
      filters?: import('./types').FilterOption[];
      filters?: FilterOption[];
    }
  ): Promise<string> {
    this.ensureConnected();
@@ -275,7 +256,6 @@ export class WebSocketClient {
    if (response.success && response.data?.subscription_id) {
      const subscriptionId = response.data.subscription_id;

      // Store subscription
      this.subscriptions.set(subscriptionId, {
        id: subscriptionId,
        entity,
@@ -293,7 +273,6 @@ export class WebSocketClient {

    this.send(message);

    // Timeout
    setTimeout(() => {
      if (this.messageHandlers.has(id)) {
        this.messageHandlers.delete(id);
@@ -303,9 +282,6 @@ export class WebSocketClient {
    });
  }

  /**
   * Unsubscribe from entity changes
   */
  async unsubscribe(subscriptionId: string): Promise<void> {
    this.ensureConnected();

@@ -330,7 +306,6 @@ export class WebSocketClient {

    this.send(message);

    // Timeout
    setTimeout(() => {
      if (this.messageHandlers.has(id)) {
        this.messageHandlers.delete(id);
@@ -340,37 +315,22 @@ export class WebSocketClient {
    });
  }

  /**
   * Get list of active subscriptions
   */
  getSubscriptions(): Subscription[] {
    return Array.from(this.subscriptions.values());
  }

  /**
   * Get connection state
   */
  getState(): ConnectionState {
    return this.state;
  }

  /**
   * Check if connected
   */
  isConnected(): boolean {
    return this.ws?.readyState === WebSocket.OPEN;
  }

  /**
   * Add event listener
   */
  on<K extends keyof WebSocketClientEvents>(event: K, callback: WebSocketClientEvents[K]): void {
    this.eventListeners[event] = callback as any;
  }

  /**
   * Remove event listener
   */
  off<K extends keyof WebSocketClientEvents>(event: K): void {
    delete this.eventListeners[event];
  }
@@ -384,7 +344,6 @@ export class WebSocketClient {

    this.emit('message', message);

    // Handle different message types
    switch (message.type) {
      case 'response':
        this.handleResponse(message as WSResponseMessage);
@@ -395,7 +354,6 @@ export class WebSocketClient {
        break;

      case 'pong':
        // Heartbeat response
        break;

      default:
resolvespec-js/src/websocketspec/index.ts (new file, 2 lines)
@@ -0,0 +1,2 @@
export * from './types';
export { WebSocketClient, getWebSocketClient } from './client';
@@ -1,17 +1,24 @@
import type { FilterOption, SortOption, PreloadOption, Parameter } from '../common/types';

// Re-export common types
export type { FilterOption, SortOption, PreloadOption, Operator, SortDirection } from '../common/types';

// WebSocket Message Types
export type MessageType = 'request' | 'response' | 'notification' | 'subscription' | 'error' | 'ping' | 'pong';
export type WSOperation = 'read' | 'create' | 'update' | 'delete' | 'subscribe' | 'unsubscribe' | 'meta';

// Re-export common types
export type { FilterOption, SortOption, PreloadOption, Operator, SortDirection } from './types';

export interface WSOptions {
  filters?: import('./types').FilterOption[];
  filters?: FilterOption[];
  columns?: string[];
  preload?: import('./types').PreloadOption[];
  sort?: import('./types').SortOption[];
  omit_columns?: string[];
  preload?: PreloadOption[];
  sort?: SortOption[];
  limit?: number;
  offset?: number;
  parameters?: Parameter[];
  cursor_forward?: string;
  cursor_backward?: string;
  fetch_row_number?: string;
}

export interface WSMessage {
@@ -78,7 +85,7 @@ export interface WSSubscriptionMessage {
}

export interface SubscriptionOptions {
  filters?: import('./types').FilterOption[];
  filters?: FilterOption[];
  onNotification?: (notification: WSNotificationMessage) => void;
}

resolvespec-js/tsconfig.json (new file, 21 lines)
@@ -0,0 +1,21 @@
{
  "compilerOptions": {
    "target": "ES2020",
    "module": "ESNext",
    "moduleResolution": "bundler",
    "strict": true,
    "declaration": true,
    "declarationMap": true,
    "sourceMap": true,
    "outDir": "dist",
    "rootDir": "src",
    "esModuleInterop": true,
    "skipLibCheck": true,
    "forceConsistentCasingInFileNames": true,
    "resolveJsonModule": true,
    "isolatedModules": true,
    "lib": ["ES2020", "DOM"]
  },
  "include": ["src"],
  "exclude": ["node_modules", "dist", "src/__tests__"]
}
resolvespec-js/vite.config.ts (new file, 20 lines)
@@ -0,0 +1,20 @@
import { defineConfig } from 'vite';
import dts from 'vite-plugin-dts';
import { resolve } from 'path';

export default defineConfig({
  plugins: [
    dts({ rollupTypes: true }),
  ],
  build: {
    lib: {
      entry: resolve(__dirname, 'src/index.ts'),
      name: 'ResolveSpec',
      formats: ['es', 'cjs'],
      fileName: (format) => `index.${format === 'es' ? 'js' : 'cjs'}`,
    },
    rollupOptions: {
      external: ['uuid', 'semver'],
    },
  },
});
@@ -1,5 +1,50 @@
# Python Implementation of the ResolveSpec API
# ResolveSpec Python Client - TODO

# Server
## Client Implementation & Testing

# Client
### 1. ResolveSpec Client API

- [ ] Core API implementation (read, create, update, delete, get_metadata)
- [ ] Unit tests for API functions
- [ ] Integration tests with server
- [ ] Error handling and edge cases

### 2. HeaderSpec Client API

- [ ] Client API implementation
- [ ] Unit tests
- [ ] Integration tests with server

### 3. FunctionSpec Client API

- [ ] Client API implementation
- [ ] Unit tests
- [ ] Integration tests with server

### 4. WebSocketSpec Client API

- [ ] WebSocketClient class implementation (read, create, update, delete, meta, subscribe, unsubscribe)
- [ ] Unit tests for WebSocketClient
- [ ] Connection handling tests
- [ ] Subscription tests
- [ ] Integration tests with server

### 5. Testing Infrastructure

- [ ] Set up test framework (pytest)
- [ ] Configure test coverage reporting (pytest-cov)
- [ ] Add test utilities and fixtures
- [ ] Create test documentation
- [ ] Package and publish to PyPI

## Documentation

- [ ] API reference documentation
- [ ] Usage examples for each client API
- [ ] Installation guide
- [ ] Contributing guidelines
- [ ] README with quick start

---

**Last Updated:** 2026-02-07
todo.md (114 lines)
@@ -2,36 +2,98 @@

This document tracks incomplete features and improvements for the ResolveSpec project.

## In Progress

### Database Layer

- [x] SQLite schema translation (schema.table → schema_table)
- [x] Driver name normalization across adapters
- [x] Database Connection Manager (dbmanager) package

### Documentation
- Ensure all new features are documented in README.md
- Update examples to showcase new functionality
- Add migration notes if any breaking changes are introduced

- [x] Add dbmanager to README
- [x] Add WebSocketSpec to top-level intro
- [x] Add MQTTSpec to top-level intro
- [x] Remove migration sections from README
- [ ] Complete API reference documentation
- [ ] Add examples for all supported databases

### 8.
## Planned Features

1. **Test Coverage**: Increase from 20% to 70%+
   - Add integration tests for CRUD operations
   - Add unit tests for security providers
   - Add concurrency tests for model registry
### ResolveSpec JS Client Implementation & Testing

1. **ResolveSpec Client API (resolvespec-js)**
   - [x] Core API implementation (read, create, update, delete, getMetadata)
   - [ ] Unit tests for API functions
   - [ ] Integration tests with server
   - [ ] Error handling and edge cases

2. **HeaderSpec Client API (resolvespec-js)**
   - [ ] Client API implementation
   - [ ] Unit tests
   - [ ] Integration tests with server

3. **FunctionSpec Client API (resolvespec-js)**
   - [ ] Client API implementation
   - [ ] Unit tests
   - [ ] Integration tests with server

4. **WebSocketSpec Client API (resolvespec-js)**
   - [x] WebSocketClient class implementation (read, create, update, delete, meta, subscribe, unsubscribe)
   - [ ] Unit tests for WebSocketClient
   - [ ] Connection handling tests
   - [ ] Subscription tests
   - [ ] Integration tests with server

5. **resolvespec-js Testing Infrastructure**
   - [ ] Set up test framework (Jest or Vitest)
   - [ ] Configure test coverage reporting
   - [ ] Add test utilities and mocks
   - [ ] Create test documentation

### ResolveSpec Python Client Implementation & Testing

See [`resolvespec-python/todo.md`](./resolvespec-python/todo.md) for detailed Python client implementation tasks.

### Core Functionality

1. **Enhanced Preload Filtering**
   - [ ] Column selection for nested preloads
   - [ ] Advanced filtering conditions for relations
   - [ ] Performance optimization for deep nesting

2. **Advanced Query Features**
   - [ ] Custom SQL join support
   - [ ] Computed column improvements
   - [ ] Recursive query support

3. **Testing & Quality**
   - [ ] Increase test coverage to 70%+
   - [ ] Add integration tests for all ORMs
   - [ ] Add concurrency tests for thread safety
   - [ ] Performance benchmarks

### Infrastructure

- [ ] Improved error handling and reporting
- [ ] Enhanced logging capabilities
- [ ] Additional monitoring metrics
- [ ] Performance profiling tools

## Documentation Tasks

- [ ] Complete API reference
- [ ] Add troubleshooting guides
- [ ] Create architecture diagrams
- [ ] Expand database adapter documentation

## Known Issues

- [ ] Long preload alias names may exceed PostgreSQL identifier limit
- [ ] Some edge cases in computed column handling

---

## Priority Ranking

1. **High Priority**
   - Column Selection and Filtering for Preloads (#1)
   - Proper Condition Handling for Bun Preloads (#4)

2. **Medium Priority**
   - Custom SQL Join Support (#3)
   - Recursive JSON Cleaning (#2)

3. **Low Priority**
   - Modernize Go Type Declarations (#5)

---

**Last Updated:** 2025-12-09
**Last Updated:** 2026-02-07
**Updated:** Added resolvespec-js client testing and implementation tasks