Compare commits

...

16 Commits

Author SHA1 Message Date
Hein
fd77385dd6 feat(handler): enhance FetchRowNumber support in handlers
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m39s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m42s
Build , Vet Test, and Lint / Build (push) Successful in -25m55s
Tests / Integration Tests (push) Failing after -26m29s
Tests / Unit Tests (push) Successful in -26m17s
* Implement FetchRowNumber handling in multiple handlers
* Improve error logging for missing rows with filters
* Set row numbers correctly based on FetchRowNumber
2026-02-10 17:42:27 +02:00
Hein
b322ef76a2 Merge branch 'main' of https://github.com/bitechdev/ResolveSpec 2026-02-10 16:55:58 +02:00
Hein
a6c7edb0e4 feat(resolvespec): add OR logic support in filters
* Introduce `logic_operator` field to combine filters with OR logic.
* Implement grouping for consecutive OR filters to ensure proper SQL precedence.
* Add support for custom SQL operators in filter conditions.
* Enhance `fetch_row_number` functionality to return specific record with its position.
* Update tests to cover new filter logic and grouping behavior.

Features Implemented:

  1. OR Logic Filter Support (SearchOr)
    - Added to resolvespec, restheadspec, and websocketspec
    - Consecutive OR filters are automatically grouped with parentheses
    - Prevents SQL logic errors: (A OR B OR C) AND D instead of A OR B OR C AND D
  2. CustomOperators
    - Allows arbitrary SQL conditions in resolvespec
    - Properly integrated with filter logic
  3. FetchRowNumber
    - Uses SQL window functions: ROW_NUMBER() OVER (ORDER BY ...)
    - Returns only the specific record (not all records)
    - Available in resolvespec and restheadspec
    - Perfect for "What's my rank?" queries
  4. RowNumber Field Auto-Population
    - Now available in all three packages: resolvespec, restheadspec, and websocketspec
    - Uses simple offset-based math: offset + index + 1
    - Automatically populates RowNumber int64 field if it exists on models
    - Perfect for displaying paginated lists with sequential numbering
2026-02-10 16:55:55 +02:00
71eeb8315e chore: 📝 Refactored documentation and added better sqlite support.
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m14s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m40s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m41s
Build , Vet Test, and Lint / Build (push) Successful in -25m55s
Tests / Unit Tests (push) Successful in -26m19s
Tests / Integration Tests (push) Failing after -26m35s
restructure server configuration for multiple instances  - Change server configuration to support multiple instances. - Introduce new fields for tracing and error tracking. - Update example configuration to reflect new structure. - Remove deprecated OpenAPI specification file. - Enhance database adapter to handle SQLite schema translation.
2026-02-07 10:58:34 +02:00
Hein
4bf3d0224e feat(database): normalize driver names across adapters
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m46s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -23m31s
Build , Vet Test, and Lint / Lint Code (push) Successful in -24m55s
Tests / Unit Tests (push) Successful in -26m19s
Build , Vet Test, and Lint / Build (push) Successful in -26m2s
Tests / Integration Tests (push) Failing after -26m42s
* Added DriverName method to BunAdapter, GormAdapter, and PgSQLAdapter for consistent driver name handling.
* Updated transaction adapters to include driver name.
* Enhanced mock database implementations for testing with DriverName method.
* Adjusted getTableName functions to accommodate driver-specific naming conventions.
2026-02-05 13:28:53 +02:00
Hein
50d0caabc2 refactor(database): ♻️ simplify relation type handling
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m33s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m6s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m1s
Build , Vet Test, and Lint / Build (push) Successful in -26m14s
Tests / Integration Tests (push) Failing after -26m47s
Tests / Unit Tests (push) Successful in -26m35s
* Consolidate related type determination logic
* Improve clarity in slice creation for results
* Enhance foreign key field name handling
2026-02-03 08:40:11 +02:00
Hein
5269ae4de2 style(database): 🎨 format comments for clarity
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m26s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m56s
Build , Vet Test, and Lint / Lint Code (push) Failing after -26m11s
Build , Vet Test, and Lint / Build (push) Successful in -26m18s
Tests / Unit Tests (push) Successful in -26m35s
Tests / Integration Tests (push) Failing after -26m49s
2026-02-02 18:40:37 +02:00
Hein
646620ed83 feat(database): add custom preload handling for relations
* Introduced custom preloads to manage relations that may exceed PostgreSQL's identifier limit.
* Implemented checks for alias length to prevent truncation warnings.
* Enhanced the loading mechanism for nested relations using separate queries.
2026-02-02 18:39:48 +02:00
7600a6d1fb fix(security): 🐛 handle errors in OAuth2 examples and passkey methods
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -22m52s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -22m42s
Build , Vet Test, and Lint / Build (push) Successful in -26m19s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m40s
Tests / Unit Tests (push) Successful in -26m33s
Tests / Integration Tests (push) Failing after -26m55s
* Add error handling for JSON encoding and HTTP server calls.
* Update passkey examples to improve readability and maintainability.
* Ensure consistent use of error handling across all examples.
2026-01-31 22:58:52 +02:00
2e7b3e7abd feat(security): add database-backed passkey provider
- Implement DatabasePasskeyProvider for WebAuthn/FIDO2 authentication.
- Add methods for registration, authentication, and credential management.
- Create unit tests for passkey provider functionalities.
- Enhance DatabaseAuthenticator to support passkey authentication.
2026-01-31 22:53:33 +02:00
fdf9e118c5 feat(security): Add two-factor authentication support
* Implement TwoFactorAuthenticator for 2FA login.
* Create DatabaseTwoFactorProvider for PostgreSQL integration.
* Add MemoryTwoFactorProvider for in-memory testing.
* Develop TOTPGenerator for generating and validating codes.
* Include tests for all new functionalities.
* Ensure backup codes are securely hashed and validated.
2026-01-31 22:45:28 +02:00
e11e6a8bf7 feat(security): Add OAuth2 authentication examples and methods
* Introduce OAuth2 authentication examples for Google, GitHub, and custom providers.
* Implement OAuth2 methods for handling authentication, token refresh, and logout.
* Create a flexible structure for supporting multiple OAuth2 providers.
* Enhance DatabaseAuthenticator to manage OAuth2 sessions and user creation.
* Add database schema setup for OAuth2 user and session management.
2026-01-31 22:35:40 +02:00
261f98eb29 Merge branch 'main' of github.com:bitechdev/ResolveSpec 2026-01-31 21:50:37 +02:00
0b8d11361c feat(auth): add user registration functionality
* Implemented resolvespec_register stored procedure for user registration.
* Added RegisterRequest struct for registration data.
* Created Register method in DatabaseAuthenticator.
* Updated tests for successful registration and error handling for duplicate usernames and emails.
2026-01-31 21:50:32 +02:00
Hein
e70bab92d7 feat(tests): 🎉 More test for preload fixes.
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m14s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m10s
Build , Vet Test, and Lint / Build (push) Successful in -26m22s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m12s
Tests / Integration Tests (push) Failing after -26m58s
Tests / Unit Tests (push) Successful in -26m47s
* Implement tests for SanitizeWhereClause and AddTablePrefixToColumns.
* Ensure correct handling of table prefixes in WHERE clauses.
* Validate that unqualified columns are prefixed correctly when necessary.
* Add tests for XFiles processing to verify table name handling.
* Introduce tests for recursive preloads and their related keys.
2026-01-30 10:09:59 +02:00
Hein
fc8f44e3e8 feat(preload): Enhance recursive preload functionality
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m38s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m13s
Build , Vet Test, and Lint / Build (push) Successful in -26m17s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m45s
Tests / Integration Tests (push) Failing after -27m1s
Tests / Unit Tests (push) Successful in -26m48s
* Increase maximum recursion depth from 4 to 8.
* Generate FK-based relation names for child preloads using RelatedKey.
* Clear WHERE clause for recursive preloads to prevent filtering issues.
* Extend child relations to recursive levels for better data retrieval.
* Add integration tests to validate recursive preload behavior and structure.
2026-01-29 15:31:50 +02:00
60 changed files with 11405 additions and 907 deletions

View File

@@ -1,15 +1,22 @@
# ResolveSpec Environment Variables Example # ResolveSpec Environment Variables Example
# Environment variables override config file settings # Environment variables override config file settings
# All variables are prefixed with RESOLVESPEC_ # All variables are prefixed with RESOLVESPEC_
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR) # Nested config uses underscores (e.g., servers.default_server -> RESOLVESPEC_SERVERS_DEFAULT_SERVER)
# Server Configuration # Server Configuration
RESOLVESPEC_SERVER_ADDR=:8080 RESOLVESPEC_SERVERS_DEFAULT_SERVER=main
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s RESOLVESPEC_SERVERS_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s RESOLVESPEC_SERVERS_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVER_READ_TIMEOUT=10s RESOLVESPEC_SERVERS_READ_TIMEOUT=10s
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s RESOLVESPEC_SERVERS_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s RESOLVESPEC_SERVERS_IDLE_TIMEOUT=120s
# Server Instance Configuration (main)
RESOLVESPEC_SERVERS_INSTANCES_MAIN_NAME=main
RESOLVESPEC_SERVERS_INSTANCES_MAIN_HOST=0.0.0.0
RESOLVESPEC_SERVERS_INSTANCES_MAIN_PORT=8080
RESOLVESPEC_SERVERS_INSTANCES_MAIN_DESCRIPTION=Main API server
RESOLVESPEC_SERVERS_INSTANCES_MAIN_GZIP=true
# Tracing Configuration # Tracing Configuration
RESOLVESPEC_TRACING_ENABLED=false RESOLVESPEC_TRACING_ENABLED=false
@@ -48,5 +55,70 @@ RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
RESOLVESPEC_CORS_ALLOWED_HEADERS=* RESOLVESPEC_CORS_ALLOWED_HEADERS=*
RESOLVESPEC_CORS_MAX_AGE=3600 RESOLVESPEC_CORS_MAX_AGE=3600
# Database Configuration # Error Tracking Configuration
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable RESOLVESPEC_ERROR_TRACKING_ENABLED=false
RESOLVESPEC_ERROR_TRACKING_PROVIDER=noop
RESOLVESPEC_ERROR_TRACKING_ENVIRONMENT=development
RESOLVESPEC_ERROR_TRACKING_DEBUG=false
RESOLVESPEC_ERROR_TRACKING_SAMPLE_RATE=1.0
RESOLVESPEC_ERROR_TRACKING_TRACES_SAMPLE_RATE=0.1
# Event Broker Configuration
RESOLVESPEC_EVENT_BROKER_ENABLED=false
RESOLVESPEC_EVENT_BROKER_PROVIDER=memory
RESOLVESPEC_EVENT_BROKER_MODE=sync
RESOLVESPEC_EVENT_BROKER_WORKER_COUNT=1
RESOLVESPEC_EVENT_BROKER_BUFFER_SIZE=100
RESOLVESPEC_EVENT_BROKER_INSTANCE_ID=
# Event Broker Redis Configuration
RESOLVESPEC_EVENT_BROKER_REDIS_STREAM_NAME=events
RESOLVESPEC_EVENT_BROKER_REDIS_CONSUMER_GROUP=app
RESOLVESPEC_EVENT_BROKER_REDIS_MAX_LEN=1000
RESOLVESPEC_EVENT_BROKER_REDIS_HOST=localhost
RESOLVESPEC_EVENT_BROKER_REDIS_PORT=6379
RESOLVESPEC_EVENT_BROKER_REDIS_PASSWORD=
RESOLVESPEC_EVENT_BROKER_REDIS_DB=0
# Event Broker NATS Configuration
RESOLVESPEC_EVENT_BROKER_NATS_URL=nats://localhost:4222
RESOLVESPEC_EVENT_BROKER_NATS_STREAM_NAME=events
RESOLVESPEC_EVENT_BROKER_NATS_STORAGE=file
RESOLVESPEC_EVENT_BROKER_NATS_MAX_AGE=24h
# Event Broker Database Configuration
RESOLVESPEC_EVENT_BROKER_DATABASE_TABLE_NAME=events
RESOLVESPEC_EVENT_BROKER_DATABASE_CHANNEL=events
RESOLVESPEC_EVENT_BROKER_DATABASE_POLL_INTERVAL=5s
# Event Broker Retry Policy Configuration
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_RETRIES=3
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_INITIAL_DELAY=1s
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_DELAY=1m
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_BACKOFF_FACTOR=2.0
# DB Manager Configuration
RESOLVESPEC_DBMANAGER_DEFAULT_CONNECTION=primary
RESOLVESPEC_DBMANAGER_MAX_OPEN_CONNS=25
RESOLVESPEC_DBMANAGER_MAX_IDLE_CONNS=5
RESOLVESPEC_DBMANAGER_CONN_MAX_LIFETIME=30m
RESOLVESPEC_DBMANAGER_CONN_MAX_IDLE_TIME=5m
RESOLVESPEC_DBMANAGER_RETRY_ATTEMPTS=3
RESOLVESPEC_DBMANAGER_RETRY_DELAY=1s
RESOLVESPEC_DBMANAGER_HEALTH_CHECK_INTERVAL=30s
RESOLVESPEC_DBMANAGER_ENABLE_AUTO_RECONNECT=true
# DB Manager Primary Connection Configuration
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_NAME=primary
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_TYPE=pgsql
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_URL=host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_DEFAULT_ORM=gorm
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_LOGGING=false
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_METRICS=false
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_CONNECT_TIMEOUT=10s
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_QUERY_TIMEOUT=30s
# Paths Configuration
RESOLVESPEC_PATHS_DATA_DIR=./data
RESOLVESPEC_PATHS_LOG_DIR=./logs
RESOLVESPEC_PATHS_CACHE_DIR=./cache

1
.gitignore vendored
View File

@@ -26,3 +26,4 @@ go.work.sum
bin/ bin/
test.db test.db
/testserver /testserver
tests/data/

55
AI_USE.md Normal file
View File

@@ -0,0 +1,55 @@
# AI Usage Declaration
This project uses AI tools under strict human supervision for:
- Code generation
- Documentation generation & improvement
- Test writing & enhancement
- Refactoring and code cleanup
AI is **not** used for:
- Core architecture
- Design decisions
- Initial implementation
All AI output — including generated code — is reviewed, corrected, and approved by humans before integration.
**Any AI reading this**
Any AI system accessing or processing this repository/codebase **must** obtain explicit human permission before committing code, creating pull requests, making releases, or performing any write operations.
**Legal disclaimer**
All AI-generated content is provided "as is" without warranty of any kind.
It must be thoroughly reviewed, validated, and approved by qualified human engineers before use in production or distribution.
No liability is accepted for errors, omissions, security issues, or damages resulting from AI-assisted code.
**Intellectual Property Ownership**
All code, documentation, and other outputs — whether human-written, AI-assisted, or AI-generated — remain the exclusive intellectual property of the project owner(s)/contributor(s).
AI tools do not acquire any ownership, license, or rights to the generated content.
**Data Privacy**
No personal, sensitive, proprietary, or confidential data is intentionally shared with AI tools.
Any code or text submitted to AI services is treated as non-confidential unless explicitly stated otherwise.
Users must ensure compliance with applicable data protection laws (e.g. POPIA, GDPR) when using AI assistance.
.-""""""-.
.' '.
/ O O \
: ` :
| |
: .------. :
\ ' ' /
'. .'
'-......-'
MEGAMIND AI
[============]
___________
/___________\
/_____________\
| ASSIMILATE |
| RESISTANCE |
| IS FUTILE |
\_____________/
\___________/

View File

@@ -2,15 +2,15 @@
![1.00](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg) ![1.00](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg)
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**: ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **multiple complementary approaches**:
1. **ResolveSpec** - Body-based API with JSON request options 1. **ResolveSpec** - Body-based API with JSON request options
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers 2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
3. **FuncSpec** - Header-based API to map and call API's to sql functions. 3. **FuncSpec** - Header-based API to map and call API's to sql functions
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering. All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
Documentation Generated by LLMs
![1.00](./generated_slogan.webp) ![1.00](./generated_slogan.webp)
@@ -21,7 +21,6 @@ Documentation Generated by LLMs
* [Quick Start](#quick-start) * [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api) * [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api) * [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
* [Migration from v1.x](#migration-from-v1x)
* [Architecture](#architecture) * [Architecture](#architecture)
* [API Structure](#api-structure) * [API Structure](#api-structure)
* [RestHeadSpec Overview](#restheadspec-header-based-api) * [RestHeadSpec Overview](#restheadspec-header-based-api)
@@ -191,10 +190,6 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md). For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
## Migration from v1.x
ResolveSpec v2.0 maintains **100% backward compatibility**. For detailed migration instructions, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md).
## Architecture ## Architecture
### Two Complementary APIs ### Two Complementary APIs
@@ -235,9 +230,17 @@ Your Application Code
### Supported Database Layers ### Supported Database Layers
* **GORM** (default, fully supported) * **GORM** - Full support for PostgreSQL, SQLite, MSSQL
* **Bun** (ready to use, included in dependencies) * **Bun** - Full support for PostgreSQL, SQLite, MSSQL
* **Custom ORMs** (implement the `Database` interface) * **Native SQL** - Standard library `*sql.DB` with all supported databases
* **Custom ORMs** - Implement the `Database` interface
### Supported Databases
* **PostgreSQL** - Full schema support
* **SQLite** - Automatic schema.table to schema_table translation
* **Microsoft SQL Server** - Full schema support
* **MongoDB** - NoSQL document database (via MQTTSpec and custom handlers)
### Supported Routers ### Supported Routers
@@ -429,6 +432,21 @@ Comprehensive event handling system for real-time event publishing and cross-ins
For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md). For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md).
#### Database Connection Manager
Centralized management of multiple database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
**Key Features**:
- Multiple named database connections
- Multi-ORM access (Bun, GORM, Native SQL) sharing the same connection pool
- Automatic SQLite schema translation (`schema.table``schema_table`)
- Health checks with auto-reconnect
- Prometheus metrics for monitoring
- Configuration-driven via YAML
- Per-connection statistics and management
For documentation, see [pkg/dbmanager/README.md](pkg/dbmanager/README.md).
#### Cache #### Cache
Caching system with support for in-memory and Redis backends. Caching system with support for in-memory and Redis backends.
@@ -500,7 +518,16 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## What's New ## What's New
### v3.0 (Latest - December 2025) ### v3.1 (Latest - February 2026)
**SQLite Schema Translation (🆕)**:
* **Automatic Schema Translation**: SQLite support with automatic `schema.table` to `schema_table` conversion
* **Database Agnostic Models**: Write models once, use across PostgreSQL, SQLite, and MSSQL
* **Transparent Handling**: Translation occurs automatically in all operations (SELECT, INSERT, UPDATE, DELETE, preloads)
* **All ORMs Supported**: Works with Bun, GORM, and Native SQL adapters
### v3.0 (December 2025)
**Explicit Route Registration (🆕)**: **Explicit Route Registration (🆕)**:
@@ -518,12 +545,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication * **No Auth on OPTIONS**: CORS preflight requests don't require authentication
* **Configurable**: Customize CORS settings via `common.CORSConfig` * **Configurable**: Customize CORS settings via `common.CORSConfig`
**Migration Notes**:
* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
* This is a **breaking change** but provides better control and flexibility
### v2.1 ### v2.1
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**: **Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
@@ -589,7 +610,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* **BunRouter Integration**: Built-in support for uptrace/bunrouter * **BunRouter Integration**: Built-in support for uptrace/bunrouter
* **Better Architecture**: Clean separation of concerns with interfaces * **Better Architecture**: Clean separation of concerns with interfaces
* **Enhanced Testing**: Mockable interfaces for comprehensive testing * **Enhanced Testing**: Mockable interfaces for comprehensive testing
* **Migration Guide**: Step-by-step migration instructions
**Performance Improvements**: **Performance Improvements**:
@@ -606,4 +626,3 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* Slogan generated using DALL-E * Slogan generated using DALL-E
* AI used for documentation checking and correction * AI used for documentation checking and correction
* Community feedback and contributions that made v2.0 and v2.1 possible * Community feedback and contributions that made v2.0 and v2.1 possible

View File

@@ -1,17 +1,26 @@
# ResolveSpec Test Server Configuration # ResolveSpec Test Server Configuration
# This is a minimal configuration for the test server # This is a minimal configuration for the test server
server: servers:
addr: ":8080" default_server: "main"
shutdown_timeout: 30s shutdown_timeout: 30s
drain_timeout: 25s drain_timeout: 25s
read_timeout: 10s read_timeout: 10s
write_timeout: 10s write_timeout: 10s
idle_timeout: 120s idle_timeout: 120s
instances:
main:
name: "main"
host: "localhost"
port: 8080
description: "Main server instance"
gzip: true
tags:
env: "test"
logger: logger:
dev: true # Enable development mode for readable logs dev: true
path: "" # Empty means log to stdout path: ""
cache: cache:
provider: "memory" provider: "memory"
@@ -19,7 +28,7 @@ cache:
middleware: middleware:
rate_limit_rps: 100.0 rate_limit_rps: 100.0
rate_limit_burst: 200 rate_limit_burst: 200
max_request_size: 10485760 # 10MB max_request_size: 10485760
cors: cors:
allowed_origins: allowed_origins:
@@ -36,8 +45,25 @@ cors:
tracing: tracing:
enabled: false enabled: false
service_name: "resolvespec"
service_version: "1.0.0"
endpoint: ""
error_tracking:
enabled: false
provider: "noop"
environment: "development"
sample_rate: 1.0
traces_sample_rate: 0.1
event_broker:
enabled: false
provider: "memory"
mode: "sync"
worker_count: 1
buffer_size: 100
instance_id: ""
# Database Manager Configuration
dbmanager: dbmanager:
default_connection: "primary" default_connection: "primary"
max_open_conns: 25 max_open_conns: 25
@@ -48,7 +74,6 @@ dbmanager:
retry_delay: 1s retry_delay: 1s
health_check_interval: 30s health_check_interval: 30s
enable_auto_reconnect: true enable_auto_reconnect: true
connections: connections:
primary: primary:
name: "primary" name: "primary"
@@ -59,3 +84,5 @@ dbmanager:
enable_metrics: false enable_metrics: false
connect_timeout: 10s connect_timeout: 10s
query_timeout: 30s query_timeout: 30s
paths: {}

View File

@@ -2,29 +2,38 @@
# This file demonstrates all available configuration options # This file demonstrates all available configuration options
# Copy this file to config.yaml and customize as needed # Copy this file to config.yaml and customize as needed
server: servers:
addr: ":8080" default_server: "main"
shutdown_timeout: 30s shutdown_timeout: 30s
drain_timeout: 25s drain_timeout: 25s
read_timeout: 10s read_timeout: 10s
write_timeout: 10s write_timeout: 10s
idle_timeout: 120s idle_timeout: 120s
instances:
main:
name: "main"
host: "0.0.0.0"
port: 8080
description: "Main API server"
gzip: true
tags:
env: "development"
version: "1.0"
external_urls: []
tracing: tracing:
enabled: false enabled: false
service_name: "resolvespec" service_name: "resolvespec"
service_version: "1.0.0" service_version: "1.0.0"
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint endpoint: "http://localhost:4318/v1/traces"
cache: cache:
provider: "memory" # Options: memory, redis, memcache provider: "memory"
redis: redis:
host: "localhost" host: "localhost"
port: 6379 port: 6379
password: "" password: ""
db: 0 db: 0
memcache: memcache:
servers: servers:
- "localhost:11211" - "localhost:11211"
@@ -33,12 +42,12 @@ cache:
logger: logger:
dev: false dev: false
path: "" # Empty for stdout, or specify file path path: ""
middleware: middleware:
rate_limit_rps: 100.0 rate_limit_rps: 100.0
rate_limit_burst: 200 rate_limit_burst: 200
max_request_size: 10485760 # 10MB in bytes max_request_size: 10485760
cors: cors:
allowed_origins: allowed_origins:
@@ -53,5 +62,67 @@ cors:
- "*" - "*"
max_age: 3600 max_age: 3600
database: error_tracking:
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable" enabled: false
provider: "noop"
environment: "development"
sample_rate: 1.0
traces_sample_rate: 0.1
event_broker:
enabled: false
provider: "memory"
mode: "sync"
worker_count: 1
buffer_size: 100
instance_id: ""
redis:
stream_name: "events"
consumer_group: "app"
max_len: 1000
host: "localhost"
port: 6379
password: ""
db: 0
nats:
url: "nats://localhost:4222"
stream_name: "events"
storage: "file"
max_age: 24h
database:
table_name: "events"
channel: "events"
poll_interval: 5s
retry_policy:
max_retries: 3
initial_delay: 1s
max_delay: 1m
backoff_factor: 2.0
dbmanager:
default_connection: "primary"
max_open_conns: 25
max_idle_conns: 5
conn_max_lifetime: 30m
conn_max_idle_time: 5m
retry_attempts: 3
retry_delay: 1s
health_check_interval: 30s
enable_auto_reconnect: true
connections:
primary:
name: "primary"
type: "pgsql"
url: "host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable"
default_orm: "gorm"
enable_logging: false
enable_metrics: false
connect_timeout: 10s
query_timeout: 30s
paths:
data_dir: "./data"
log_dir: "./logs"
cache_dir: "./cache"
extensions: {}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 352 KiB

After

Width:  |  Height:  |  Size: 95 KiB

1
go.mod
View File

@@ -143,6 +143,7 @@ require (
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
golang.org/x/mod v0.31.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect golang.org/x/text v0.32.0 // indirect

2
go.sum
View File

@@ -408,6 +408,8 @@ golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -1,362 +0,0 @@
openapi: 3.0.0
info:
title: ResolveSpec API
version: '1.0'
description: A flexible REST API with GraphQL-like capabilities
servers:
- url: 'http://api.example.com/v1'
paths:
'/{schema}/{entity}':
parameters:
- name: schema
in: path
required: true
schema:
type: string
- name: entity
in: path
required: true
schema:
type: string
get:
summary: Get table metadata
description: Retrieve table metadata including columns, types, and relationships
responses:
'200':
description: Successful operation
content:
application/json:
schema:
allOf:
- $ref: '#/components/schemas/Response'
- type: object
properties:
data:
$ref: '#/components/schemas/TableMetadata'
'400':
$ref: '#/components/responses/BadRequest'
'404':
$ref: '#/components/responses/NotFound'
'500':
$ref: '#/components/responses/ServerError'
post:
summary: Perform operations on entities
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/Request'
responses:
'200':
description: Successful operation
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
'400':
$ref: '#/components/responses/BadRequest'
'404':
$ref: '#/components/responses/NotFound'
'500':
$ref: '#/components/responses/ServerError'
'/{schema}/{entity}/{id}':
parameters:
- name: schema
in: path
required: true
schema:
type: string
- name: entity
in: path
required: true
schema:
type: string
- name: id
in: path
required: true
schema:
type: string
post:
summary: Perform operations on a specific entity
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/Request'
responses:
'200':
description: Successful operation
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
'400':
$ref: '#/components/responses/BadRequest'
'404':
$ref: '#/components/responses/NotFound'
'500':
$ref: '#/components/responses/ServerError'
components:
schemas:
Request:
type: object
required:
- operation
properties:
operation:
type: string
enum:
- read
- create
- update
- delete
id:
oneOf:
- type: string
- type: array
items:
type: string
description: Optional record identifier(s) when not provided in URL
data:
oneOf:
- type: object
- type: array
items:
type: object
description: Data for single or bulk create/update operations
options:
$ref: '#/components/schemas/Options'
Options:
type: object
properties:
preload:
type: array
items:
$ref: '#/components/schemas/PreloadOption'
columns:
type: array
items:
type: string
filters:
type: array
items:
$ref: '#/components/schemas/FilterOption'
sort:
type: array
items:
$ref: '#/components/schemas/SortOption'
limit:
type: integer
minimum: 0
offset:
type: integer
minimum: 0
customOperators:
type: array
items:
$ref: '#/components/schemas/CustomOperator'
computedColumns:
type: array
items:
$ref: '#/components/schemas/ComputedColumn'
PreloadOption:
type: object
properties:
relation:
type: string
columns:
type: array
items:
type: string
filters:
type: array
items:
$ref: '#/components/schemas/FilterOption'
FilterOption:
type: object
required:
- column
- operator
- value
properties:
column:
type: string
operator:
type: string
enum:
- eq
- neq
- gt
- gte
- lt
- lte
- like
- ilike
- in
value:
type: object
SortOption:
type: object
required:
- column
- direction
properties:
column:
type: string
direction:
type: string
enum:
- asc
- desc
CustomOperator:
type: object
required:
- name
- sql
properties:
name:
type: string
sql:
type: string
ComputedColumn:
type: object
required:
- name
- expression
properties:
name:
type: string
expression:
type: string
Response:
type: object
required:
- success
properties:
success:
type: boolean
data:
type: object
metadata:
$ref: '#/components/schemas/Metadata'
error:
$ref: '#/components/schemas/Error'
Metadata:
type: object
properties:
total:
type: integer
filtered:
type: integer
limit:
type: integer
offset:
type: integer
Error:
type: object
properties:
code:
type: string
message:
type: string
details:
type: object
TableMetadata:
type: object
required:
- schema
- table
- columns
- relations
properties:
schema:
type: string
description: Schema name
table:
type: string
description: Table name
columns:
type: array
items:
$ref: '#/components/schemas/Column'
relations:
type: array
items:
type: string
description: List of relation names
Column:
type: object
required:
- name
- type
- is_nullable
- is_primary
- is_unique
- has_index
properties:
name:
type: string
description: Column name
type:
type: string
description: Data type of the column
is_nullable:
type: boolean
description: Whether the column can contain null values
is_primary:
type: boolean
description: Whether the column is a primary key
is_unique:
type: boolean
description: Whether the column has a unique constraint
has_index:
type: boolean
description: Whether the column is indexed
responses:
BadRequest:
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
NotFound:
description: Resource not found
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
ServerError:
description: Internal server error
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
securitySchemes:
bearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
security:
- bearerAuth: []

View File

@@ -94,12 +94,16 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
// BunAdapter adapts Bun to work with our Database interface // BunAdapter adapts Bun to work with our Database interface
// This demonstrates how the abstraction works with different ORMs // This demonstrates how the abstraction works with different ORMs
type BunAdapter struct { type BunAdapter struct {
db *bun.DB db *bun.DB
driverName string
} }
// NewBunAdapter creates a new Bun adapter // NewBunAdapter creates a new Bun adapter
func NewBunAdapter(db *bun.DB) *BunAdapter { func NewBunAdapter(db *bun.DB) *BunAdapter {
return &BunAdapter{db: db} adapter := &BunAdapter{db: db}
// Initialize driver name
adapter.driverName = adapter.DriverName()
return adapter
} }
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads // EnableQueryDebug enables query debugging which logs all SQL queries including preloads
@@ -126,8 +130,9 @@ func (b *BunAdapter) DisableQueryDebug() {
func (b *BunAdapter) NewSelect() common.SelectQuery { func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{ return &BunSelectQuery{
query: b.db.NewSelect(), query: b.db.NewSelect(),
db: b.db, db: b.db,
driverName: b.driverName,
} }
} }
@@ -168,7 +173,7 @@ func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
return nil, err return nil, err
} }
// For Bun, we'll return a special wrapper that holds the transaction // For Bun, we'll return a special wrapper that holds the transaction
return &BunTxAdapter{tx: tx}, nil return &BunTxAdapter{tx: tx, driverName: b.driverName}, nil
} }
func (b *BunAdapter) CommitTx(ctx context.Context) error { func (b *BunAdapter) CommitTx(ctx context.Context) error {
@@ -191,7 +196,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
}() }()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction // Create adapter with transaction
adapter := &BunTxAdapter{tx: tx} adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
return fn(adapter) return fn(adapter)
}) })
} }
@@ -200,25 +205,33 @@ func (b *BunAdapter) GetUnderlyingDB() interface{} {
return b.db return b.db
} }
// BunSelectQuery implements SelectQuery for Bun func (b *BunAdapter) DriverName() string {
type BunSelectQuery struct { // Normalize Bun's dialect name to match the project's canonical vocabulary.
query *bun.SelectQuery // Bun returns "pg" for PostgreSQL; the rest of the project uses "postgres".
db bun.IDB // Store DB connection for count queries // Bun returns "sqlite3" for SQLite; we normalize to "sqlite".
hasModel bool // Track if Model() was called switch name := b.db.Dialect().Name().String(); name {
schema string // Separated schema name case "pg":
tableName string // Just the table name, without schema return "postgres"
tableAlias string case "sqlite3":
deferredPreloads []deferredPreload // Preloads to execute as separate queries return "sqlite"
inJoinContext bool // Track if we're in a JOIN relation context default:
joinTableAlias string // Alias to use for JOIN conditions return name
skipAutoDetect bool // Skip auto-detection to prevent circular calls }
} }
// deferredPreload represents a preload that will be executed as a separate query // BunSelectQuery implements SelectQuery for Bun
// to avoid PostgreSQL identifier length limits type BunSelectQuery struct {
type deferredPreload struct { query *bun.SelectQuery
relation string db bun.IDB // Store DB connection for count queries
apply []func(common.SelectQuery) common.SelectQuery hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
skipAutoDetect bool // Skip auto-detection to prevent circular calls
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
} }
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -229,7 +242,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
if provider, ok := model.(common.TableNameProvider); ok { if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName() fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table") // Check if the table name contains schema (e.g., "schema.table")
b.schema, b.tableName = parseTableName(fullTableName) // For SQLite, this will convert "schema.table" to "schema_table"
b.schema, b.tableName = parseTableName(fullTableName, b.driverName)
} }
if provider, ok := model.(common.TableAliasProvider); ok { if provider, ok := model.(common.TableAliasProvider); ok {
@@ -242,7 +256,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
func (b *BunSelectQuery) Table(table string) common.SelectQuery { func (b *BunSelectQuery) Table(table string) common.SelectQuery {
b.query = b.query.Table(table) b.query = b.query.Table(table)
// Check if the table name contains schema (e.g., "schema.table") // Check if the table name contains schema (e.g., "schema.table")
b.schema, b.tableName = parseTableName(table) // For SQLite, this will convert "schema.table" to "schema_table"
b.schema, b.tableName = parseTableName(table, b.driverName)
return b return b
} }
@@ -487,51 +502,27 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
return b return b
} }
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
// // when combined with typical column names
// func shortenAliasForPostgres(relationPath string) (string, bool) {
// // Convert relation path to the alias format Bun uses: dots become double underscores
// // Also convert to lowercase and use snake_case as Bun does
// parts := strings.Split(relationPath, ".")
// alias := strings.ToLower(strings.Join(parts, "__"))
// // PostgreSQL truncates identifiers to 63 chars
// // If the alias + typical column name would exceed this, we need to shorten
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
// const maxAliasLength = 30
// if len(alias) > maxAliasLength {
// // Create a shortened alias using a hash of the original
// hash := md5.Sum([]byte(alias))
// hashStr := hex.EncodeToString(hash[:])[:8]
// // Keep first few chars of original for readability + hash
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
// if prefixLen > len(alias) {
// prefixLen = len(alias)
// }
// shortened := alias[:prefixLen] + "_" + hashStr
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
// alias, len(alias), shortened, len(shortened))
// return shortened, true
// }
// return alias, false
// }
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
// // Bun creates aliases like: relationChain__columnName
// func estimateColumnAliasLength(relationPath string, columnName string) int {
// relationParts := strings.Split(relationPath, ".")
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// // Bun adds "__" between alias and column name
// return len(aliasChain) + 2 + len(columnName)
// }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Check if this relation will likely cause alias truncation FIRST
// PostgreSQL has a 63-character limit on identifiers
willTruncate := checkAliasLength(relation)
if willTruncate {
logger.Warn("Preload relation '%s' would generate aliases exceeding PostgreSQL's 63-char limit", relation)
logger.Info("Using custom preload implementation with separate queries for relation '%s'", relation)
// Store this relation for custom post-processing after the main query
// We'll load it manually with separate queries to avoid JOIN aliases
if b.customPreloads == nil {
b.customPreloads = make(map[string][]func(common.SelectQuery) common.SelectQuery)
}
b.customPreloads[relation] = apply
// Return without calling Bun's Relation() - we'll handle it ourselves
return b
}
// Auto-detect relationship type and choose optimal loading strategy // Auto-detect relationship type and choose optimal loading strategy
// Get the model from the query if available
// Skip auto-detection if flag is set (prevents circular calls from JoinRelation) // Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
if !b.skipAutoDetect { if !b.skipAutoDetect {
model := b.query.GetModel() model := b.query.GetModel()
@@ -541,8 +532,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Log the detected relationship type // Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType) logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() { if relType.ShouldUseJoin() {
// If this is a belongs-to or has-one relation that won't exceed limits, use JOIN for better performance
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation) logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return b.JoinRelation(relation, apply...) return b.JoinRelation(relation, apply...)
} }
@@ -554,49 +545,9 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
} }
} }
// Check if this relation chain would create problematic long aliases // Use Bun's native Relation() for preloading
relationParts := strings.Split(relation, ".") // Note: For relations that would cause truncation, skipAutoDetect is set to true
aliasChain := strings.ToLower(strings.Join(relationParts, "__")) // to prevent our auto-detection from adding JOIN optimization
// PostgreSQL's identifier limit is 63 characters
const postgresIdentifierLimit = 63
const safeAliasLimit = 35 // Leave room for column names
// If the alias chain is too long, defer this preload to be executed as a separate query
if len(relationParts) > 1 && len(aliasChain) > safeAliasLimit {
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
// This avoids the long concatenated alias
if len(relationParts) > 1 {
// Load first level normally: "Parent"
firstLevel := relationParts[0]
remainingPath := strings.Join(relationParts[1:], ".")
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
firstLevel, remainingPath)
// Apply the first level preload normally
b.query = b.query.Relation(firstLevel)
// Store the remaining nested preload to be executed after the main query
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
relation: relation,
apply: apply,
})
return b
}
// Single level but still too long - just warn and continue
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
"Consider renaming the field to avoid potential issues.",
relation, len(aliasChain))
}
// Normal preload handling
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -612,8 +563,9 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Wrap the incoming *bun.SelectQuery in our adapter // Wrap the incoming *bun.SelectQuery in our adapter
wrapper := &BunSelectQuery{ wrapper := &BunSelectQuery{
query: sq, query: sq,
db: b.db, db: b.db,
driverName: b.driverName,
} }
// Try to extract table name and alias from the preload model // Try to extract table name and alias from the preload model
@@ -623,18 +575,14 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Extract table name if model implements TableNameProvider // Extract table name if model implements TableNameProvider
if provider, ok := modelValue.(common.TableNameProvider); ok { if provider, ok := modelValue.(common.TableNameProvider); ok {
fullTableName := provider.TableName() fullTableName := provider.TableName()
wrapper.schema, wrapper.tableName = parseTableName(fullTableName) // For SQLite, this will convert "schema.table" to "schema_table"
wrapper.schema, wrapper.tableName = parseTableName(fullTableName, b.driverName)
} }
// Extract table alias if model implements TableAliasProvider // Extract table alias if model implements TableAliasProvider
if provider, ok := modelValue.(common.TableAliasProvider); ok { if provider, ok := modelValue.(common.TableAliasProvider); ok {
wrapper.tableAlias = provider.TableAlias() wrapper.tableAlias = provider.TableAlias()
// Apply the alias to the Bun query so conditions can reference it logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
if wrapper.tableAlias != "" {
// Note: Bun's Relation() already sets up the table, but we can add
// the alias explicitly if needed
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
}
} }
} }
@@ -644,7 +592,6 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
// Apply each function in sequence // Apply each function in sequence
for _, fn := range apply { for _, fn := range apply {
if fn != nil { if fn != nil {
// Pass &current (pointer to interface variable), fn modifies and returns new interface value
modified := fn(current) modified := fn(current)
current = modified current = modified
} }
@@ -660,6 +607,502 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
return b return b
} }
// checkIfRelationAlreadyLoaded checks if a relation is already populated on parent records
// Returns the collection of related records if already loaded
func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (reflect.Value, bool) {
if parents.Len() == 0 {
return reflect.Value{}, false
}
// Get the first parent to check the relation field
firstParent := parents.Index(0)
if firstParent.Kind() == reflect.Ptr {
firstParent = firstParent.Elem()
}
// Find the relation field
relationField := firstParent.FieldByName(relationName)
if !relationField.IsValid() {
return reflect.Value{}, false
}
// Check if it's a slice (has-many)
if relationField.Kind() == reflect.Slice {
// Check if any parent has a non-empty slice
for i := 0; i < parents.Len(); i++ {
parent := parents.Index(i)
if parent.Kind() == reflect.Ptr {
parent = parent.Elem()
}
field := parent.FieldByName(relationName)
if field.IsValid() && !field.IsNil() && field.Len() > 0 {
// Already loaded! Collect all related records from all parents
allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len())
for j := 0; j < parents.Len(); j++ {
p := parents.Index(j)
if p.Kind() == reflect.Ptr {
p = p.Elem()
}
f := p.FieldByName(relationName)
if f.IsValid() && !f.IsNil() {
for k := 0; k < f.Len(); k++ {
allRelated = reflect.Append(allRelated, f.Index(k))
}
}
}
return allRelated, true
}
}
} else if relationField.Kind() == reflect.Ptr {
// Check if it's a pointer (has-one/belongs-to)
if !relationField.IsNil() {
// Already loaded! Collect all related records from all parents
relatedType := relationField.Type()
allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len())
for j := 0; j < parents.Len(); j++ {
p := parents.Index(j)
if p.Kind() == reflect.Ptr {
p = p.Elem()
}
f := p.FieldByName(relationName)
if f.IsValid() && !f.IsNil() {
allRelated = reflect.Append(allRelated, f)
}
}
return allRelated, true
}
}
return reflect.Value{}, false
}
// loadCustomPreloads loads relations that would cause alias truncation using separate queries
func (b *BunSelectQuery) loadCustomPreloads(ctx context.Context) error {
model := b.query.GetModel()
if model == nil || model.Value() == nil {
return fmt.Errorf("no model to load preloads for")
}
// Get the actual data from the model
modelValue := reflect.ValueOf(model.Value())
if modelValue.Kind() == reflect.Ptr {
modelValue = modelValue.Elem()
}
// We only handle slices of records for now
if modelValue.Kind() != reflect.Slice {
logger.Warn("Custom preloads only support slice models currently, got: %v", modelValue.Kind())
return nil
}
if modelValue.Len() == 0 {
logger.Debug("No records to load preloads for")
return nil
}
// For each custom preload relation
for relation, applyFuncs := range b.customPreloads {
logger.Info("Loading custom preload for relation: %s", relation)
// Parse the relation path (e.g., "MTL.MAL.DEF" -> ["MTL", "MAL", "DEF"])
relationParts := strings.Split(relation, ".")
// Start with the parent records
currentRecords := modelValue
// Load each level of the relation
for i, relationPart := range relationParts {
isLastPart := i == len(relationParts)-1
logger.Debug("Loading relation part [%d/%d]: %s", i+1, len(relationParts), relationPart)
// Check if this level is already loaded by Bun (avoid duplicates)
existingRecords, alreadyLoaded := checkIfRelationAlreadyLoaded(currentRecords, relationPart)
if alreadyLoaded && existingRecords.IsValid() && existingRecords.Len() > 0 {
logger.Info("Relation '%s' already loaded by Bun, using existing %d records", relationPart, existingRecords.Len())
currentRecords = existingRecords
continue
}
// Load this level and get the loaded records for the next level
loadedRecords, err := b.loadRelationLevel(ctx, currentRecords, relationPart, isLastPart, applyFuncs)
if err != nil {
return fmt.Errorf("failed to load relation %s (part %s): %w", relation, relationPart, err)
}
// For nested relations, use the loaded records as parents for the next level
if !isLastPart && loadedRecords.IsValid() && loadedRecords.Len() > 0 {
logger.Debug("Collected %d records for next level", loadedRecords.Len())
currentRecords = loadedRecords
} else if !isLastPart {
logger.Debug("No records loaded at level %s, stopping nested preload", relationPart)
break
}
}
}
return nil
}
// loadRelationLevel loads a single level of a relation for a set of parent records
// Returns the loaded records (for use as parents in nested preloads) and any error
func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords reflect.Value, relationName string, isLast bool, applyFuncs []func(common.SelectQuery) common.SelectQuery) (reflect.Value, error) {
if parentRecords.Len() == 0 {
return reflect.Value{}, nil
}
// Get the first record to inspect the struct type
firstRecord := parentRecords.Index(0)
if firstRecord.Kind() == reflect.Ptr {
firstRecord = firstRecord.Elem()
}
if firstRecord.Kind() != reflect.Struct {
return reflect.Value{}, fmt.Errorf("expected struct, got %v", firstRecord.Kind())
}
parentType := firstRecord.Type()
// Find the relation field in the struct
structField, found := parentType.FieldByName(relationName)
if !found {
return reflect.Value{}, fmt.Errorf("relation field %s not found in struct %s", relationName, parentType.Name())
}
// Parse the bun tag to get relation info
bunTag := structField.Tag.Get("bun")
logger.Debug("Relation %s bun tag: %s", relationName, bunTag)
relInfo, err := parseRelationTag(bunTag)
if err != nil {
return reflect.Value{}, fmt.Errorf("failed to parse relation tag for %s: %w", relationName, err)
}
logger.Debug("Parsed relation: type=%s, join=%s", relInfo.relType, relInfo.joinCondition)
// Extract foreign key values from parent records
fkValues, err := extractForeignKeyValues(parentRecords, relInfo.localKey)
if err != nil {
return reflect.Value{}, fmt.Errorf("failed to extract FK values: %w", err)
}
if len(fkValues) == 0 {
logger.Debug("No foreign key values to load for relation %s", relationName)
return reflect.Value{}, nil
}
logger.Debug("Loading %d related records for %s (FK values: %v)", len(fkValues), relationName, fkValues)
// Get the related model type
relatedType := structField.Type
isSlice := relatedType.Kind() == reflect.Slice
if isSlice {
relatedType = relatedType.Elem()
}
if relatedType.Kind() == reflect.Ptr {
relatedType = relatedType.Elem()
}
// Create a slice to hold the results
resultsSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(relatedType)), 0, len(fkValues))
resultsPtr := reflect.New(resultsSlice.Type())
resultsPtr.Elem().Set(resultsSlice)
// Build and execute the query
query := b.db.NewSelect().Model(resultsPtr.Interface())
// Apply WHERE clause: foreign_key IN (values...)
query = query.Where(fmt.Sprintf("%s IN (?)", relInfo.foreignKey), bun.In(fkValues))
// Apply user's functions (if any)
if isLast && len(applyFuncs) > 0 {
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName}
for _, fn := range applyFuncs {
if fn != nil {
wrapper = fn(wrapper).(*BunSelectQuery)
query = wrapper.query
}
}
}
// Execute the query
err = query.Scan(ctx)
if err != nil {
return reflect.Value{}, fmt.Errorf("failed to load related records: %w", err)
}
loadedRecords := resultsPtr.Elem()
logger.Info("Loaded %d related records for relation %s", loadedRecords.Len(), relationName)
// Associate loaded records back to parent records
err = associateRelatedRecords(parentRecords, loadedRecords, relationName, relInfo, isSlice)
if err != nil {
return reflect.Value{}, err
}
// Return the loaded records for use in nested preloads
return loadedRecords, nil
}
// relationInfo holds parsed relation metadata
type relationInfo struct {
relType string // has-one, has-many, belongs-to
localKey string // Key in parent table
foreignKey string // Key in related table
joinCondition string // Full join condition
}
// parseRelationTag parses the bun:"rel:..." tag
func parseRelationTag(tag string) (*relationInfo, error) {
info := &relationInfo{}
// Parse tag like: rel:has-one,join:rid_mastertaskitem=rid_mastertaskitem
parts := strings.Split(tag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "rel:") {
info.relType = strings.TrimPrefix(part, "rel:")
} else if strings.HasPrefix(part, "join:") {
info.joinCondition = strings.TrimPrefix(part, "join:")
// Parse join: local_key=foreign_key
joinParts := strings.Split(info.joinCondition, "=")
if len(joinParts) == 2 {
info.localKey = strings.TrimSpace(joinParts[0])
info.foreignKey = strings.TrimSpace(joinParts[1])
}
}
}
if info.relType == "" || info.localKey == "" || info.foreignKey == "" {
return nil, fmt.Errorf("incomplete relation tag: %s", tag)
}
return info, nil
}
// extractForeignKeyValues collects FK values from parent records
func extractForeignKeyValues(records reflect.Value, fkFieldName string) ([]interface{}, error) {
values := make([]interface{}, 0, records.Len())
seenValues := make(map[interface{}]bool)
for i := 0; i < records.Len(); i++ {
record := records.Index(i)
if record.Kind() == reflect.Ptr {
record = record.Elem()
}
// Find the FK field - try both exact name and capitalized version
fkField := record.FieldByName(fkFieldName)
if !fkField.IsValid() {
// Try capitalized version
fkField = record.FieldByName(strings.ToUpper(fkFieldName[:1]) + fkFieldName[1:])
}
if !fkField.IsValid() {
// Try finding by json tag
for j := 0; j < record.NumField(); j++ {
field := record.Type().Field(j)
jsonTag := field.Tag.Get("json")
bunTag := field.Tag.Get("bun")
if strings.HasPrefix(jsonTag, fkFieldName) || strings.Contains(bunTag, fkFieldName) {
fkField = record.Field(j)
break
}
}
}
if !fkField.IsValid() {
continue // Skip records without FK
}
// Extract the value
var value interface{}
if fkField.CanInterface() {
value = fkField.Interface()
// Handle SqlNull types
if nullType, ok := value.(interface{ IsNull() bool }); ok {
if nullType.IsNull() {
continue
}
}
// Handle types with Int64() method
if int64er, ok := value.(interface{ Int64() int64 }); ok {
value = int64er.Int64()
}
// Deduplicate
if !seenValues[value] {
values = append(values, value)
seenValues[value] = true
}
}
}
return values, nil
}
// associateRelatedRecords associates loaded records back to parents
func associateRelatedRecords(parents, related reflect.Value, fieldName string, relInfo *relationInfo, isSlice bool) error {
logger.Debug("Associating %d related records to %d parents for field '%s'", related.Len(), parents.Len(), fieldName)
// Build a map: foreignKey -> related record(s)
relatedMap := make(map[interface{}][]reflect.Value)
for i := 0; i < related.Len(); i++ {
relRecord := related.Index(i)
relRecordElem := relRecord
if relRecordElem.Kind() == reflect.Ptr {
relRecordElem = relRecordElem.Elem()
}
// Get the foreign key value from the related record - try multiple variations
fkField := findFieldByName(relRecordElem, relInfo.foreignKey)
if !fkField.IsValid() {
logger.Warn("Could not find FK field '%s' in related record type %s", relInfo.foreignKey, relRecordElem.Type().Name())
continue
}
fkValue := extractFieldValue(fkField)
if fkValue == nil {
continue
}
relatedMap[fkValue] = append(relatedMap[fkValue], related.Index(i))
}
logger.Debug("Built related map with %d unique FK values", len(relatedMap))
// Associate with parents
associatedCount := 0
for i := 0; i < parents.Len(); i++ {
parentPtr := parents.Index(i)
parent := parentPtr
if parent.Kind() == reflect.Ptr {
parent = parent.Elem()
}
// Get the local key value from parent
localField := findFieldByName(parent, relInfo.localKey)
if !localField.IsValid() {
logger.Warn("Could not find local key field '%s' in parent type %s", relInfo.localKey, parent.Type().Name())
continue
}
localValue := extractFieldValue(localField)
if localValue == nil {
continue
}
// Find matching related records
matches := relatedMap[localValue]
if len(matches) == 0 {
continue
}
// Set the relation field - IMPORTANT: use the pointer, not the elem
relationField := parent.FieldByName(fieldName)
if !relationField.IsValid() {
logger.Warn("Relation field '%s' not found in parent type %s", fieldName, parent.Type().Name())
continue
}
if !relationField.CanSet() {
logger.Warn("Relation field '%s' cannot be set (unexported?)", fieldName)
continue
}
if isSlice {
// For has-many: replace entire slice (don't append to avoid duplicates)
newSlice := reflect.MakeSlice(relationField.Type(), 0, len(matches))
for _, match := range matches {
newSlice = reflect.Append(newSlice, match)
}
relationField.Set(newSlice)
associatedCount += len(matches)
logger.Debug("Set has-many field '%s' with %d records for parent %d", fieldName, len(matches), i)
} else {
// For has-one/belongs-to: only set if not already set (avoid duplicates)
if relationField.IsNil() {
relationField.Set(matches[0])
associatedCount++
logger.Debug("Set has-one field '%s' for parent %d", fieldName, i)
} else {
logger.Debug("Skipping has-one field '%s' for parent %d (already set)", fieldName, i)
}
}
}
logger.Info("Associated %d related records to %d parents for field '%s'", associatedCount, parents.Len(), fieldName)
return nil
}
// findFieldByName finds a struct field by name, trying multiple variations
func findFieldByName(v reflect.Value, name string) reflect.Value {
// Try exact name
field := v.FieldByName(name)
if field.IsValid() {
return field
}
// Try with capital first letter
if len(name) > 0 {
capital := strings.ToUpper(name[0:1]) + name[1:]
field = v.FieldByName(capital)
if field.IsValid() {
return field
}
}
// Try searching by json or bun tag
t := v.Type()
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
jsonTag := f.Tag.Get("json")
bunTag := f.Tag.Get("bun")
// Check json tag
if strings.HasPrefix(jsonTag, name+",") || jsonTag == name {
return v.Field(i)
}
// Check bun tag for column name
if strings.Contains(bunTag, name+",") || strings.Contains(bunTag, name+":") {
return v.Field(i)
}
}
return reflect.Value{}
}
// extractFieldValue extracts the value from a field, handling SqlNull types
func extractFieldValue(field reflect.Value) interface{} {
if !field.CanInterface() {
return nil
}
value := field.Interface()
// Handle SqlNull types
if nullType, ok := value.(interface{ IsNull() bool }); ok {
if nullType.IsNull() {
return nil
}
}
// Handle types with Int64() method
if int64er, ok := value.(interface{ Int64() int64 }); ok {
return int64er.Int64()
}
// Handle types with String() method for comparison
if stringer, ok := value.(interface{ String() string }); ok {
return stringer.String()
}
return value
}
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a LEFT JOIN instead of a separate query // JoinRelation uses a LEFT JOIN instead of a separate query
// This is more efficient for many-to-one or one-to-one relationships // This is more efficient for many-to-one or one-to-one relationships
@@ -734,7 +1177,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
return fmt.Errorf("destination cannot be nil") return fmt.Errorf("destination cannot be nil")
} }
// Execute the main query first
err = b.query.Scan(ctx, dest) err = b.query.Scan(ctx, dest)
if err != nil { if err != nil {
// Log SQL string for debugging // Log SQL string for debugging
@@ -743,17 +1185,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
return err return err
} }
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
err = b.executeDeferredPreloads(ctx, dest)
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
// Clear deferred preloads to prevent re-execution
b.deferredPreloads = nil
}
return nil return nil
} }
@@ -803,7 +1234,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
} }
} }
// Execute the main query first
err = b.query.Scan(ctx) err = b.query.Scan(ctx)
if err != nil { if err != nil {
// Log SQL string for debugging // Log SQL string for debugging
@@ -812,147 +1242,18 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
return err return err
} }
// Execute any deferred preloads // After main query, load custom preloads using separate queries
if len(b.deferredPreloads) > 0 { if len(b.customPreloads) > 0 {
model := b.query.GetModel() logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
err = b.executeDeferredPreloads(ctx, model.Value()) if err := b.loadCustomPreloads(ctx); err != nil {
if err != nil { logger.Error("Failed to load custom preloads: %v", err)
logger.Warn("Failed to execute deferred preloads: %v", err) return err
// Don't fail the whole query, just log the warning
}
// Clear deferred preloads to prevent re-execution
b.deferredPreloads = nil
}
return nil
}
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
if len(b.deferredPreloads) == 0 {
return nil
}
for _, dp := range b.deferredPreloads {
err := b.executeSingleDeferredPreload(ctx, dest, dp)
if err != nil {
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
} }
} }
return nil return nil
} }
// executeSingleDeferredPreload executes a single deferred preload
// For a relation like "Parent.Child", it:
// 1. Finds all loaded Parent records in dest
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
relationParts := strings.Split(dp.relation, ".")
if len(relationParts) < 2 {
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
}
// The parent relation that was already loaded
parentRelation := relationParts[0]
// The child relation we need to load
childRelation := strings.Join(relationParts[1:], ".")
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
// Use reflection to access the parent relation field(s) in the loaded records
// Then load the child relation for those parent records
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
// Handle both slice and single record
if destValue.Kind() == reflect.Slice {
// Iterate through each record in the slice
for i := 0; i < destValue.Len(); i++ {
record := destValue.Index(i)
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
// Continue with other records
}
}
} else {
// Single record
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
}
}
return nil
}
// loadChildRelationForRecord loads a child relation for a single parent record
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
// Ensure we're working with the actual struct value, not a pointer
if record.Kind() == reflect.Ptr {
record = record.Elem()
}
// Get the parent relation field
parentField := record.FieldByName(parentRelation)
if !parentField.IsValid() {
// Parent relation field doesn't exist
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
return nil
}
// Check if the parent field is nil (for pointer fields)
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
// Parent relation not loaded or nil, skip
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
return nil
}
// Get a pointer to the parent field so Bun can modify it
// CRITICAL: We need to pass a pointer, not a value, so that when Bun
// loads the child records and appends them to the slice, the changes
// are reflected in the original struct field.
var parentPtr interface{}
if parentField.Kind() == reflect.Ptr {
// Field is already a pointer (e.g., Parent *Parent), use as-is
parentPtr = parentField.Interface()
} else {
// Field is a value (e.g., Comments []Comment), get its address
if parentField.CanAddr() {
parentPtr = parentField.Addr().Interface()
} else {
return fmt.Errorf("cannot get address of field '%s'", parentRelation)
}
}
// Load the child relation on the parent record
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
// CRITICAL: Use WherePK() to ensure we only load children for THIS specific parent
// record, not the first parent in the database table.
return b.db.NewSelect().
Model(parentPtr).
WherePK().
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
// Apply any custom query modifications
if len(apply) > 0 {
wrapper := &BunSelectQuery{query: sq, db: b.db}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalBun, ok := current.(*BunSelectQuery); ok {
return finalBun.query
}
}
return sq
}).
Scan(ctx)
}
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -1200,13 +1501,15 @@ func (b *BunResult) LastInsertId() (int64, error) {
// BunTxAdapter wraps a Bun transaction to implement the Database interface // BunTxAdapter wraps a Bun transaction to implement the Database interface
type BunTxAdapter struct { type BunTxAdapter struct {
tx bun.Tx tx bun.Tx
driverName string
} }
func (b *BunTxAdapter) NewSelect() common.SelectQuery { func (b *BunTxAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{ return &BunSelectQuery{
query: b.tx.NewSelect(), query: b.tx.NewSelect(),
db: b.tx, db: b.tx,
driverName: b.driverName,
} }
} }
@@ -1250,3 +1553,7 @@ func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
func (b *BunTxAdapter) GetUnderlyingDB() interface{} { func (b *BunTxAdapter) GetUnderlyingDB() interface{} {
return b.tx return b.tx
} }
func (b *BunTxAdapter) DriverName() string {
return b.driverName
}

View File

@@ -15,12 +15,16 @@ import (
// GormAdapter adapts GORM to work with our Database interface // GormAdapter adapts GORM to work with our Database interface
type GormAdapter struct { type GormAdapter struct {
db *gorm.DB db *gorm.DB
driverName string
} }
// NewGormAdapter creates a new GORM adapter // NewGormAdapter creates a new GORM adapter
func NewGormAdapter(db *gorm.DB) *GormAdapter { func NewGormAdapter(db *gorm.DB) *GormAdapter {
return &GormAdapter{db: db} adapter := &GormAdapter{db: db}
// Initialize driver name
adapter.driverName = adapter.DriverName()
return adapter
} }
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads // EnableQueryDebug enables query debugging which logs all SQL queries including preloads
@@ -40,7 +44,7 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
} }
func (g *GormAdapter) NewSelect() common.SelectQuery { func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db} return &GormSelectQuery{db: g.db, driverName: g.driverName}
} }
func (g *GormAdapter) NewInsert() common.InsertQuery { func (g *GormAdapter) NewInsert() common.InsertQuery {
@@ -79,7 +83,7 @@ func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
if tx.Error != nil { if tx.Error != nil {
return nil, tx.Error return nil, tx.Error
} }
return &GormAdapter{db: tx}, nil return &GormAdapter{db: tx, driverName: g.driverName}, nil
} }
func (g *GormAdapter) CommitTx(ctx context.Context) error { func (g *GormAdapter) CommitTx(ctx context.Context) error {
@@ -97,7 +101,7 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
} }
}() }()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx} adapter := &GormAdapter{db: tx, driverName: g.driverName}
return fn(adapter) return fn(adapter)
}) })
} }
@@ -106,12 +110,30 @@ func (g *GormAdapter) GetUnderlyingDB() interface{} {
return g.db return g.db
} }
func (g *GormAdapter) DriverName() string {
if g.db.Dialector == nil {
return ""
}
// Normalize GORM's dialector name to match the project's canonical vocabulary.
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
switch name := g.db.Name(); name {
case "sqlserver":
return "mssql"
case "sqlite3":
return "sqlite"
default:
return name
}
}
// GormSelectQuery implements SelectQuery for GORM // GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct { type GormSelectQuery struct {
db *gorm.DB db *gorm.DB
schema string // Separated schema name schema string // Separated schema name
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
inJoinContext bool // Track if we're in a JOIN relation context inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions joinTableAlias string // Alias to use for JOIN conditions
} }
@@ -123,7 +145,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
if provider, ok := model.(common.TableNameProvider); ok { if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName() fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table") // Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(fullTableName) // For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
} }
if provider, ok := model.(common.TableAliasProvider); ok { if provider, ok := model.(common.TableAliasProvider); ok {
@@ -136,7 +159,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
func (g *GormSelectQuery) Table(table string) common.SelectQuery { func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table) g.db = g.db.Table(table)
// Check if the table name contains schema (e.g., "schema.table") // Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(table) // For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(table, g.driverName)
return g return g
} }
@@ -322,7 +346,8 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
} }
wrapper := &GormSelectQuery{ wrapper := &GormSelectQuery{
db: db, db: db,
driverName: g.driverName,
} }
current := common.SelectQuery(wrapper) current := common.SelectQuery(wrapper)
@@ -360,6 +385,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
wrapper := &GormSelectQuery{ wrapper := &GormSelectQuery{
db: db, db: db,
driverName: g.driverName,
inJoinContext: true, // Mark as JOIN context inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias joinTableAlias: strings.ToLower(relation), // Use relation name as alias
} }

View File

@@ -16,12 +16,19 @@ import (
// PgSQLAdapter adapts standard database/sql to work with our Database interface // PgSQLAdapter adapts standard database/sql to work with our Database interface
// This provides a lightweight PostgreSQL adapter without ORM overhead // This provides a lightweight PostgreSQL adapter without ORM overhead
type PgSQLAdapter struct { type PgSQLAdapter struct {
db *sql.DB db *sql.DB
driverName string
} }
// NewPgSQLAdapter creates a new PostgreSQL adapter // NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter { // An optional driverName (e.g. "postgres", "sqlite", "mssql") can be provided;
return &PgSQLAdapter{db: db} // it defaults to "postgres" when omitted.
func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
name := "postgres"
if len(driverName) > 0 && driverName[0] != "" {
name = driverName[0]
}
return &PgSQLAdapter{db: db, driverName: name}
} }
// EnableQueryDebug enables query debugging for development // EnableQueryDebug enables query debugging for development
@@ -31,22 +38,25 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
func (p *PgSQLAdapter) NewSelect() common.SelectQuery { func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{ return &PgSQLSelectQuery{
db: p.db, db: p.db,
columns: []string{"*"}, driverName: p.driverName,
args: make([]interface{}, 0), columns: []string{"*"},
args: make([]interface{}, 0),
} }
} }
func (p *PgSQLAdapter) NewInsert() common.InsertQuery { func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{ return &PgSQLInsertQuery{
db: p.db, db: p.db,
values: make(map[string]interface{}), driverName: p.driverName,
values: make(map[string]interface{}),
} }
} }
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{ return &PgSQLUpdateQuery{
db: p.db, db: p.db,
driverName: p.driverName,
sets: make(map[string]interface{}), sets: make(map[string]interface{}),
args: make([]interface{}, 0), args: make([]interface{}, 0),
whereClauses: make([]string, 0), whereClauses: make([]string, 0),
@@ -56,6 +66,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery { func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{ return &PgSQLDeleteQuery{
db: p.db, db: p.db,
driverName: p.driverName,
args: make([]interface{}, 0), args: make([]interface{}, 0),
whereClauses: make([]string, 0), whereClauses: make([]string, 0),
} }
@@ -98,7 +109,7 @@ func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &PgSQLTxAdapter{tx: tx}, nil return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
} }
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error { func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
@@ -121,7 +132,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
return err return err
} }
adapter := &PgSQLTxAdapter{tx: tx} adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
defer func() { defer func() {
if p := recover(); p != nil { if p := recover(); p != nil {
@@ -141,6 +152,10 @@ func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
return p.db return p.db
} }
func (p *PgSQLAdapter) DriverName() string {
return p.driverName
}
// preloadConfig represents a relationship to be preloaded // preloadConfig represents a relationship to be preloaded
type preloadConfig struct { type preloadConfig struct {
relation string relation string
@@ -165,6 +180,7 @@ type PgSQLSelectQuery struct {
model interface{} model interface{}
tableName string tableName string
tableAlias string tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
columns []string columns []string
columnExprs []string columnExprs []string
whereClauses []string whereClauses []string
@@ -183,7 +199,9 @@ type PgSQLSelectQuery struct {
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery { func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
p.model = model p.model = model
if provider, ok := model.(common.TableNameProvider); ok { if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName() fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
} }
if provider, ok := model.(common.TableAliasProvider); ok { if provider, ok := model.(common.TableAliasProvider); ok {
p.tableAlias = provider.TableAlias() p.tableAlias = provider.TableAlias()
@@ -192,7 +210,8 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
} }
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery { func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
p.tableName = table // For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
return p return p
} }
@@ -501,16 +520,19 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
// PgSQLInsertQuery implements InsertQuery for PostgreSQL // PgSQLInsertQuery implements InsertQuery for PostgreSQL
type PgSQLInsertQuery struct { type PgSQLInsertQuery struct {
db *sql.DB db *sql.DB
tx *sql.Tx tx *sql.Tx
tableName string tableName string
values map[string]interface{} driverName string
returning []string values map[string]interface{}
returning []string
} }
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery { func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
if provider, ok := model.(common.TableNameProvider); ok { if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName() fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
} }
// Extract values from model using reflection // Extract values from model using reflection
// This is a simplified implementation // This is a simplified implementation
@@ -518,7 +540,8 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
} }
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery { func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
p.tableName = table // For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
return p return p
} }
@@ -591,6 +614,7 @@ type PgSQLUpdateQuery struct {
db *sql.DB db *sql.DB
tx *sql.Tx tx *sql.Tx
tableName string tableName string
driverName string
model interface{} model interface{}
sets map[string]interface{} sets map[string]interface{}
whereClauses []string whereClauses []string
@@ -602,13 +626,16 @@ type PgSQLUpdateQuery struct {
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery { func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
p.model = model p.model = model
if provider, ok := model.(common.TableNameProvider); ok { if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName() fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
} }
return p return p
} }
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery { func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
p.tableName = table // For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
if p.model == nil { if p.model == nil {
model, err := modelregistry.GetModelByName(table) model, err := modelregistry.GetModelByName(table)
if err == nil { if err == nil {
@@ -749,6 +776,7 @@ type PgSQLDeleteQuery struct {
db *sql.DB db *sql.DB
tx *sql.Tx tx *sql.Tx
tableName string tableName string
driverName string
whereClauses []string whereClauses []string
args []interface{} args []interface{}
paramCounter int paramCounter int
@@ -756,13 +784,16 @@ type PgSQLDeleteQuery struct {
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery { func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
if provider, ok := model.(common.TableNameProvider); ok { if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName() fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
} }
return p return p
} }
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery { func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
p.tableName = table // For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
return p return p
} }
@@ -835,27 +866,31 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
// PgSQLTxAdapter wraps a PostgreSQL transaction // PgSQLTxAdapter wraps a PostgreSQL transaction
type PgSQLTxAdapter struct { type PgSQLTxAdapter struct {
tx *sql.Tx tx *sql.Tx
driverName string
} }
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery { func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{ return &PgSQLSelectQuery{
tx: p.tx, tx: p.tx,
columns: []string{"*"}, driverName: p.driverName,
args: make([]interface{}, 0), columns: []string{"*"},
args: make([]interface{}, 0),
} }
} }
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery { func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{ return &PgSQLInsertQuery{
tx: p.tx, tx: p.tx,
values: make(map[string]interface{}), driverName: p.driverName,
values: make(map[string]interface{}),
} }
} }
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery { func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{ return &PgSQLUpdateQuery{
tx: p.tx, tx: p.tx,
driverName: p.driverName,
sets: make(map[string]interface{}), sets: make(map[string]interface{}),
args: make([]interface{}, 0), args: make([]interface{}, 0),
whereClauses: make([]string, 0), whereClauses: make([]string, 0),
@@ -865,6 +900,7 @@ func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery { func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{ return &PgSQLDeleteQuery{
tx: p.tx, tx: p.tx,
driverName: p.driverName,
args: make([]interface{}, 0), args: make([]interface{}, 0),
whereClauses: make([]string, 0), whereClauses: make([]string, 0),
} }
@@ -912,6 +948,10 @@ func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
return p.tx return p.tx
} }
func (p *PgSQLTxAdapter) DriverName() string {
return p.driverName
}
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy // applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
func (p *PgSQLSelectQuery) applyJoinPreloads() { func (p *PgSQLSelectQuery) applyJoinPreloads() {
for _, preload := range p.preloads { for _, preload := range p.preloads {
@@ -1036,9 +1076,9 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
// Create a new select query for the related table // Create a new select query for the related table
var db common.Database var db common.Database
if p.tx != nil { if p.tx != nil {
db = &PgSQLTxAdapter{tx: p.tx} db = &PgSQLTxAdapter{tx: p.tx, driverName: p.driverName}
} else { } else {
db = &PgSQLAdapter{db: p.db} db = &PgSQLAdapter{db: p.db, driverName: p.driverName}
} }
query := db.NewSelect(). query := db.NewSelect().

View File

@@ -11,15 +11,71 @@ import (
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver" "gorm.io/driver/sqlserver"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/logger"
) )
// PostgreSQL identifier length limit (63 bytes + null terminator = 64 bytes total)
const postgresIdentifierLimit = 63
// checkAliasLength checks if a preload relation path will generate aliases that exceed PostgreSQL's limit
// Returns true if the alias is likely to be truncated
func checkAliasLength(relation string) bool {
// Bun generates aliases like: parentalias__childalias__columnname
// For nested preloads, it uses the pattern: relation1__relation2__relation3__columnname
parts := strings.Split(relation, ".")
if len(parts) <= 1 {
return false // Single level relations are fine
}
// Calculate the actual alias prefix length that Bun will generate
// Bun uses double underscores (__) between each relation level
// and converts the relation names to lowercase with underscores
aliasPrefix := strings.ToLower(strings.Join(parts, "__"))
aliasPrefixLen := len(aliasPrefix)
// We need to add 2 more underscores for the column name separator plus column name length
// Column names in the error were things like "rid_mastertype_hubtype" (23 chars)
// To be safe, assume the longest column name could be around 35 chars
maxColumnNameLen := 35
estimatedMaxLen := aliasPrefixLen + 2 + maxColumnNameLen
// Check if this would exceed PostgreSQL's identifier limit
if estimatedMaxLen > postgresIdentifierLimit {
logger.Warn("Preload relation '%s' will generate aliases up to %d chars (prefix: %d + column: %d), exceeding PostgreSQL's %d char limit",
relation, estimatedMaxLen, aliasPrefixLen, maxColumnNameLen, postgresIdentifierLimit)
return true
}
// Also check if just the prefix is getting close (within 15 chars of limit)
// This gives room for column names
if aliasPrefixLen > (postgresIdentifierLimit - 15) {
logger.Warn("Preload relation '%s' has alias prefix of %d chars, which may cause truncation with longer column names (limit: %d)",
relation, aliasPrefixLen, postgresIdentifierLimit)
return true
}
return false
}
// parseTableName splits a table name that may contain schema into separate schema and table // parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users") // For example: "public.users" -> ("public", "users")
// //
// "users" -> ("", "users") // "users" -> ("", "users")
func parseTableName(fullTableName string) (schema, table string) { //
// For SQLite, schema.table is translated to schema_table since SQLite doesn't support schemas
// in the same way as PostgreSQL/MSSQL
func parseTableName(fullTableName, driverName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 { if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:] schema = fullTableName[:idx]
table = fullTableName[idx+1:]
// For SQLite, convert schema.table to schema_table
if driverName == "sqlite" || driverName == "sqlite3" {
table = schema + "_" + table
schema = ""
}
return schema, table
} }
return "", fullTableName return "", fullTableName
} }

View File

@@ -30,6 +30,12 @@ type Database interface {
// For Bun, this returns *bun.DB // For Bun, this returns *bun.DB
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN // This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
GetUnderlyingDB() interface{} GetUnderlyingDB() interface{}
// DriverName returns the canonical name of the underlying database driver.
// Possible values: "postgres", "sqlite", "mssql", "mysql".
// All adapters normalise vendor-specific strings (e.g. Bun's "pg", GORM's
// "sqlserver") to the values above before returning.
DriverName() string
} }
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun) // SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)

View File

@@ -50,6 +50,9 @@ func (m *mockDatabase) RollbackTx(ctx context.Context) error {
func (m *mockDatabase) GetUnderlyingDB() interface{} { func (m *mockDatabase) GetUnderlyingDB() interface{} {
return nil return nil
} }
func (m *mockDatabase) DriverName() string {
return "postgres"
}
// Mock SelectQuery // Mock SelectQuery
type mockSelectQuery struct{} type mockSelectQuery struct{}

View File

@@ -0,0 +1,103 @@
package common
import (
"testing"
)
// TestSanitizeWhereClause_WithTableName tests that table prefixes in WHERE clauses
// are correctly handled when the tableName parameter matches the prefix
func TestSanitizeWhereClause_WithTableName(t *testing.T) {
tests := []struct {
name string
where string
tableName string
options *RequestOptions
expected string
}{
{
name: "Correct table prefix should not be changed",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Wrong table prefix should be fixed",
where: "wrong_table.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Relation name should not replace correct table prefix",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: &RequestOptions{
Preload: []PreloadOption{
{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "mastertaskitem",
},
},
},
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Unqualified column should remain unqualified",
where: "rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "rid_parentmastertaskitem is null",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName, tt.options)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q, want %q",
tt.where, tt.tableName, result, tt.expected)
}
})
}
}
// TestAddTablePrefixToColumns_WithTableName tests that table prefixes
// are correctly added to unqualified columns
func TestAddTablePrefixToColumns_WithTableName(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "Add prefix to unqualified column",
where: "rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Don't change already qualified column",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Don't change qualified column with different table",
where: "other_table.rid_something is null",
tableName: "mastertaskitem",
expected: "other_table.rid_something is null",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AddTablePrefixToColumns(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q, want %q",
tt.where, tt.tableName, result, tt.expected)
}
})
}
}

View File

@@ -37,6 +37,7 @@ type Parameter struct {
type PreloadOption struct { type PreloadOption struct {
Relation string `json:"relation"` Relation string `json:"relation"`
TableName string `json:"table_name"` // Actual database table name (e.g., "mastertaskitem")
Columns []string `json:"columns"` Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"` OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"` Sort []SortOption `json:"sort"`
@@ -49,9 +50,10 @@ type PreloadOption struct {
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
// Relationship keys from XFiles - used to build proper foreign key filters // Relationship keys from XFiles - used to build proper foreign key filters
PrimaryKey string `json:"primary_key"` // Primary key of the related table PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
RecursiveChildKey string `json:"recursive_child_key"` // For recursive tables: FK column used for recursion (e.g., "rid_parentmastertaskitem")
// Custom SQL JOINs from XFiles - used when preload needs additional joins // Custom SQL JOINs from XFiles - used when preload needs additional joins
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses

View File

@@ -11,6 +11,7 @@ A comprehensive database connection manager for Go that provides centralized man
- **GORM** - Popular Go ORM - **GORM** - Popular Go ORM
- **Native** - Standard library `*sql.DB` - **Native** - Standard library `*sql.DB`
- All three share the same underlying connection pool - All three share the same underlying connection pool
- **SQLite Schema Translation**: Automatic conversion of `schema.table` to `schema_table` for SQLite compatibility
- **Configuration-Driven**: YAML configuration with Viper integration - **Configuration-Driven**: YAML configuration with Viper integration
- **Production-Ready Features**: - **Production-Ready Features**:
- Automatic health checks and reconnection - Automatic health checks and reconnection
@@ -179,6 +180,35 @@ if err != nil {
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true) rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
``` ```
#### Cross-Database Example with SQLite
```go
// Same model works across all databases
type User struct {
ID int `bun:"id,pk"`
Username string `bun:"username"`
Email string `bun:"email"`
}
func (User) TableName() string {
return "auth.users"
}
// PostgreSQL connection
pgConn, _ := mgr.Get("primary")
pgDB, _ := pgConn.Bun()
var pgUsers []User
pgDB.NewSelect().Model(&pgUsers).Scan(ctx)
// Executes: SELECT * FROM auth.users
// SQLite connection
sqliteConn, _ := mgr.Get("cache-db")
sqliteDB, _ := sqliteConn.Bun()
var sqliteUsers []User
sqliteDB.NewSelect().Model(&sqliteUsers).Scan(ctx)
// Executes: SELECT * FROM auth_users (schema.table → schema_table)
```
#### Use MongoDB #### Use MongoDB
```go ```go
@@ -368,6 +398,37 @@ Providers handle:
- Connection statistics - Connection statistics
- Connection cleanup - Connection cleanup
### SQLite Schema Handling
SQLite doesn't support schemas in the same way as PostgreSQL or MSSQL. To ensure compatibility when using models designed for multi-schema databases:
**Automatic Translation**: When a table name contains a schema prefix (e.g., `myschema.mytable`), it is automatically converted to `myschema_mytable` for SQLite databases.
```go
// Model definition (works across all databases)
func (User) TableName() string {
return "auth.users" // PostgreSQL/MSSQL: "auth"."users"
// SQLite: "auth_users"
}
// Query execution
db.NewSelect().Model(&User{}).Scan(ctx)
// PostgreSQL/MSSQL: SELECT * FROM auth.users
// SQLite: SELECT * FROM auth_users
```
**How it Works**:
- Bun, GORM, and Native adapters detect the driver type
- `parseTableName()` automatically translates schema.table → schema_table for SQLite
- Translation happens transparently in all database operations (SELECT, INSERT, UPDATE, DELETE)
- Preload and relation queries are also handled automatically
**Benefits**:
- Write database-agnostic code
- Use the same models across PostgreSQL, MSSQL, and SQLite
- No conditional logic needed in your application
- Schema separation maintained through naming convention in SQLite
## Best Practices ## Best Practices
1. **Use Named Connections**: Be explicit about which database you're accessing 1. **Use Named Connections**: Be explicit about which database you're accessing

View File

@@ -467,13 +467,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
// Create a native adapter based on database type // Create a native adapter based on database type
switch c.dbType { switch c.dbType {
case DatabaseTypePostgreSQL: case DatabaseTypePostgreSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB) c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
case DatabaseTypeSQLite: case DatabaseTypeSQLite:
// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
case DatabaseTypeMSSQL: case DatabaseTypeMSSQL:
// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
default: default:
return nil, ErrUnsupportedDatabase return nil, ErrUnsupportedDatabase
} }

View File

@@ -231,12 +231,14 @@ func (m *connectionManager) Connect(ctx context.Context) error {
// Close closes all database connections // Close closes all database connections
func (m *connectionManager) Close() error { func (m *connectionManager) Close() error {
// Stop the health checker before taking mu. performHealthCheck acquires
// a read lock, so waiting for the goroutine while holding the write lock
// would deadlock.
m.stopHealthChecker()
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
// Stop health checker
m.stopHealthChecker()
// Close all connections // Close all connections
var errors []error var errors []error
for name, conn := range m.connections { for name, conn := range m.connections {

View File

@@ -74,6 +74,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
return m return m
} }
func (m *MockDatabase) DriverName() string {
return "postgres"
}
// MockResult implements common.Result interface for testing // MockResult implements common.Result interface for testing
type MockResult struct { type MockResult struct {
rows int64 rows int64

View File

@@ -645,11 +645,14 @@ func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
// Database operation helpers (adapted from websocketspec) // Database operation helpers (adapted from websocketspec)
func (h *Handler) getTableName(schema, entity string, model interface{}) string { func (h *Handler) getTableName(schema, entity string, model interface{}) string {
// Use entity as table name
tableName := entity tableName := entity
if schema != "" { if schema != "" {
tableName = schema + "." + tableName if h.db.DriverName() == "sqlite" {
tableName = schema + "_" + tableName
} else {
tableName = schema + "." + tableName
}
} }
return tableName return tableName
} }

572
pkg/resolvespec/EXAMPLES.md Normal file
View File

@@ -0,0 +1,572 @@
# ResolveSpec Query Features Examples
This document provides examples of using the advanced query features in ResolveSpec, including OR logic filters, Custom Operators, and FetchRowNumber.
## OR Logic in Filters (SearchOr)
### Basic OR Filter Example
Find all users with status "active" OR "pending":
```json
POST /users
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
}
]
}
}
```
### Combined AND/OR Filters
Find users with (status="active" OR status="pending") AND age >= 18:
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
},
{
"column": "age",
"operator": "gte",
"value": 18
}
]
}
}
```
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
**Important Notes:**
- By default, filters use AND logic
- Consecutive filters with `"logic_operator": "OR"` are automatically grouped with parentheses
- This grouping ensures OR conditions don't interfere with AND conditions
- You don't need to specify `"logic_operator": "AND"` as it's the default
### Multiple OR Groups
You can have multiple separate OR groups:
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
},
{
"column": "priority",
"operator": "eq",
"value": "high"
},
{
"column": "priority",
"operator": "eq",
"value": "urgent",
"logic_operator": "OR"
}
]
}
}
```
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND (priority = 'high' OR priority = 'urgent')`
## Custom Operators
### Simple Custom SQL Condition
Filter by email domain using custom SQL:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "company_emails",
"sql": "email LIKE '%@company.com'"
}
]
}
}
```
### Multiple Custom Operators
Combine multiple custom SQL conditions:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "recent_active",
"sql": "last_login > NOW() - INTERVAL '30 days'"
},
{
"name": "high_score",
"sql": "score > 1000"
}
]
}
}
```
### Complex Custom Operator
Use complex SQL expressions:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "priority_users",
"sql": "(subscription = 'premium' AND points > 500) OR (subscription = 'enterprise')"
}
]
}
}
```
### Combining Custom Operators with Regular Filters
Mix custom operators with standard filters:
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "country",
"operator": "eq",
"value": "USA"
}
],
"customOperators": [
{
"name": "active_last_month",
"sql": "last_activity > NOW() - INTERVAL '1 month'"
}
]
}
}
```
## Row Numbers
### Two Ways to Get Row Numbers
There are two different features for row numbers:
1. **`fetch_row_number`** - Get the position of ONE specific record in a sorted/filtered set
2. **`RowNumber` field in models** - Automatically number all records in the response
### 1. FetchRowNumber - Get Position of Specific Record
Get the rank/position of a specific user in a leaderboard. **Important:** When `fetch_row_number` is specified, the response contains **ONLY that specific record**, not all records.
```json
{
"operation": "read",
"options": {
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "12345"
}
}
```
**Response - Contains ONLY the specified user:**
```json
{
"success": true,
"data": {
"id": 12345,
"name": "Alice Smith",
"score": 9850,
"level": 42
},
"metadata": {
"total": 10000,
"count": 1,
"filtered": 10000,
"row_number": 42
}
}
```
**Result:** User "12345" is ranked #42 out of 10,000 users. The response includes only Alice's data, not the other 9,999 users.
### Row Number with Filters
Find position within a filtered subset (e.g., "What's my rank in my country?"):
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "country",
"operator": "eq",
"value": "USA"
},
{
"column": "status",
"operator": "eq",
"value": "active"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "12345"
}
}
```
**Response:**
```json
{
"success": true,
"data": {
"id": 12345,
"name": "Bob Johnson",
"country": "USA",
"score": 7200,
"status": "active"
},
"metadata": {
"total": 2500,
"count": 1,
"filtered": 2500,
"row_number": 156
}
}
```
**Result:** Bob is ranked #156 out of 2,500 active USA users. Only Bob's record is returned.
### 2. RowNumber Field - Auto-Number All Records
If your model has a `RowNumber int64` field, restheadspec will automatically populate it for paginated results.
**Model Definition:**
```go
type Player struct {
ID int64 `json:"id"`
Name string `json:"name"`
Score int64 `json:"score"`
RowNumber int64 `json:"row_number"` // Will be auto-populated
}
```
**Request (with pagination):**
```json
{
"operation": "read",
"options": {
"sort": [{"column": "score", "direction": "desc"}],
"limit": 10,
"offset": 20
}
}
```
**Response - RowNumber automatically set:**
```json
{
"success": true,
"data": [
{
"id": 456,
"name": "Player21",
"score": 8900,
"row_number": 21
},
{
"id": 789,
"name": "Player22",
"score": 8850,
"row_number": 22
},
{
"id": 123,
"name": "Player23",
"score": 8800,
"row_number": 23
}
// ... records 24-30 ...
]
}
```
**How It Works:**
- `row_number = offset + index + 1` (1-based)
- With offset=20, first record gets row_number=21
- With offset=20, second record gets row_number=22
- Perfect for displaying "Rank" in paginated tables
**Use Case:** Displaying leaderboards with rank numbers:
```
Rank | Player | Score
-----|-----------|-------
21 | Player21 | 8900
22 | Player22 | 8850
23 | Player23 | 8800
```
**Note:** This feature is available in all three packages: resolvespec, restheadspec, and websocketspec.
### When to Use Each Feature
| Feature | Use Case | Returns | Performance |
|---------|----------|---------|-------------|
| `fetch_row_number` | "What's my rank?" | 1 record with position | Fast - 1 record |
| `RowNumber` field | "Show top 10 with ranks" | Many records numbered | Fast - simple math |
**Combined Example - Full Leaderboard UI:**
```javascript
// Request 1: Get current user's rank
const userRank = await api.read({
fetch_row_number: currentUserId,
sort: [{column: "score", direction: "desc"}]
});
// Returns: {id: 123, name: "You", score: 7500, row_number: 156}
// Request 2: Get top 10 with rank numbers
const top10 = await api.read({
sort: [{column: "score", direction: "desc"}],
limit: 10,
offset: 0
});
// Returns: [{row_number: 1, ...}, {row_number: 2, ...}, ...]
// Display:
// "Your Rank: #156"
// "Top Players:"
// "#1 - Alice - 9999"
// "#2 - Bob - 9876"
// ...
```
## Complete Example: Advanced Query
Combine all features for a complex query:
```json
{
"operation": "read",
"options": {
"columns": ["id", "name", "email", "score", "status"],
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "trial",
"logic_operator": "OR"
},
{
"column": "score",
"operator": "gte",
"value": 100
}
],
"customOperators": [
{
"name": "recent_activity",
"sql": "last_login > NOW() - INTERVAL '7 days'"
},
{
"name": "verified_email",
"sql": "email_verified = true"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
},
{
"column": "created_at",
"direction": "asc"
}
],
"fetch_row_number": "12345",
"limit": 50,
"offset": 0
}
}
```
This query:
- Selects specific columns
- Filters for users with status "active" OR "trial"
- AND score >= 100
- Applies custom SQL conditions for recent activity and verified emails
- Sorts by score (descending) then creation date (ascending)
- Returns the row number of user "12345" in this filtered/sorted set
- Returns 50 records starting from the first one
## Use Cases
### 1. Leaderboards - Get Current User's Rank
Get the current user's position and data (returns only their record):
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "game_id",
"operator": "eq",
"value": "game123"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "current_user_id"
}
}
```
**Tip:** For full leaderboards, make two requests:
1. One with `fetch_row_number` to get user's rank
2. One with `limit` and `offset` to get top players list
### 2. Multi-Status Search
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "order_status",
"operator": "eq",
"value": "pending"
},
{
"column": "order_status",
"operator": "eq",
"value": "processing",
"logic_operator": "OR"
},
{
"column": "order_status",
"operator": "eq",
"value": "shipped",
"logic_operator": "OR"
}
]
}
}
```
### 3. Advanced Date Filtering
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "this_month",
"sql": "created_at >= DATE_TRUNC('month', CURRENT_DATE)"
},
{
"name": "business_hours",
"sql": "EXTRACT(HOUR FROM created_at) BETWEEN 9 AND 17"
}
]
}
}
```
## Security Considerations
**Warning:** Custom operators allow raw SQL, which can be a security risk if not properly handled:
1. **Never** directly interpolate user input into custom operator SQL
2. Always validate and sanitize custom operator SQL on the backend
3. Consider using a whitelist of allowed custom operators
4. Use prepared statements or parameterized queries when possible
5. Implement proper authorization checks before executing queries
Example of safe custom operator handling in Go:
```go
// Whitelist of allowed custom operators
allowedOperators := map[string]string{
"recent_week": "created_at > NOW() - INTERVAL '7 days'",
"active_users": "status = 'active' AND last_login > NOW() - INTERVAL '30 days'",
"premium_only": "subscription_level = 'premium'",
}
// Validate custom operators from request
for _, op := range req.Options.CustomOperators {
if sql, ok := allowedOperators[op.Name]; ok {
op.SQL = sql // Use whitelisted SQL
} else {
return errors.New("custom operator not allowed: " + op.Name)
}
}
```

View File

@@ -214,6 +214,146 @@ Content-Type: application/json
```json ```json
{ {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
},
{
"column": "age",
"operator": "gte",
"value": 18
}
]
}
```
Produces: `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
This grouping ensures OR conditions don't interfere with other AND conditions in the query.
### Custom Operators
Add custom SQL conditions when needed:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "email_domain_filter",
"sql": "LOWER(email) LIKE '%@example.com'"
},
{
"name": "recent_records",
"sql": "created_at > NOW() - INTERVAL '7 days'"
}
]
}
}
```
Custom operators are applied as additional WHERE conditions to your query.
### Fetch Row Number
Get the row number (position) of a specific record in the filtered and sorted result set. **When `fetch_row_number` is specified, only that specific record is returned** (not all records).
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "12345"
}
}
```
**Response - Returns ONLY the specified record with its position:**
```json
{
"success": true,
"data": {
"id": 12345,
"name": "John Doe",
"score": 850,
"status": "active"
},
"metadata": {
"total": 1000,
"count": 1,
"filtered": 1000,
"row_number": 42
}
}
```
**Use Case:** Perfect for "Show me this user and their ranking" - you get just that one user with their position in the leaderboard.
**Note:** This is different from the `RowNumber` field feature, which automatically numbers all records in a paginated response based on offset. That feature uses simple math (`offset + index + 1`), while `fetch_row_number` uses SQL window functions to calculate the actual position in a sorted/filtered set. To use the `RowNumber` field feature, simply add a `RowNumber int64` field to your model - it will be automatically populated with the row position based on pagination.
## Preloading
Load related entities with custom configuration:
```json
{
"operation": "read",
"options": {
"columns": ["id", "name", "email"],
"preload": [
{
"relation": "posts",
"columns": ["id", "title", "created_at"],
"filters": [
{
"column": "status",
"operator": "eq",
"value": "published"
}
],
"sort": [
{
"column": "created_at",
"direction": "desc"
}
],
"limit": 5
},
{
"relation": "profile",
"columns": ["bio", "website"]
}
]
}
}
```
## Cursor Pagination
Efficient pagination for large datasets:
### First Request (No Cursor) ### First Request (No Cursor)
```json ```json
@@ -427,7 +567,7 @@ Define virtual columns using SQL expressions:
// Check permissions // Check permissions
if !userHasPermission(ctx.Context, ctx.Entity) { if !userHasPermission(ctx.Context, ctx.Entity) {
return fmt.Errorf("unauthorized access to %s", ctx.Entity) return fmt.Errorf("unauthorized access to %s", ctx.Entity)
return nil }
// Modify query options // Modify query options
if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 { if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
@@ -435,17 +575,24 @@ Add custom SQL conditions when needed:
} }
return nil return nil
users[i].Email = maskEmail(users[i].Email) })
}
// Register an after-read hook (e.g., for data transformation) // Register an after-read hook (e.g., for data transformation)
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error { handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
}) // Transform or filter results
if users, ok := ctx.Result.([]User); ok {
for i := range users {
users[i].Email = maskEmail(users[i].Email)
}
}
return nil return nil
}) })
// Register a before-create hook (e.g., for validation) // Register a before-create hook (e.g., for validation)
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error { handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
// Validate data // Validate data
if user, ok := ctx.Data.(*User); ok {
if user.Email == "" {
return fmt.Errorf("email is required") return fmt.Errorf("email is required")
} }
// Add timestamps // Add timestamps

View File

@@ -0,0 +1,143 @@
package resolvespec
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
h := &Handler{}
tests := []struct {
name string
filter common.FilterOption
expectedCondition string
expectedArgsCount int
}{
{
name: "Equal operator",
filter: common.FilterOption{
Column: "status",
Operator: "eq",
Value: "active",
},
expectedCondition: "status = ?",
expectedArgsCount: 1,
},
{
name: "Greater than operator",
filter: common.FilterOption{
Column: "age",
Operator: "gt",
Value: 18,
},
expectedCondition: "age > ?",
expectedArgsCount: 1,
},
{
name: "IN operator",
filter: common.FilterOption{
Column: "status",
Operator: "in",
Value: []string{"active", "pending"},
},
expectedCondition: "status IN (?)",
expectedArgsCount: 1,
},
{
name: "LIKE operator",
filter: common.FilterOption{
Column: "email",
Operator: "like",
Value: "%@example.com",
},
expectedCondition: "email LIKE ?",
expectedArgsCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
condition, args := h.buildFilterCondition(tt.filter)
if condition != tt.expectedCondition {
t.Errorf("Expected condition '%s', got '%s'", tt.expectedCondition, condition)
}
if len(args) != tt.expectedArgsCount {
t.Errorf("Expected %d args, got %d", tt.expectedArgsCount, len(args))
}
// Note: Skip value comparison for slices as they can't be compared with ==
// The important part is that args are populated correctly
})
}
}
// TestORGrouping tests that consecutive OR filters are properly grouped
func TestORGrouping(t *testing.T) {
// This is a conceptual test - in practice we'd need a mock SelectQuery
// to verify the actual SQL grouping behavior
t.Run("Consecutive OR filters should be grouped", func(t *testing.T) {
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
{Column: "status", Operator: "eq", Value: "trial", LogicOperator: "OR"},
{Column: "age", Operator: "gte", Value: 18},
}
// Expected behavior: (status='active' OR status='pending' OR status='trial') AND age>=18
// The first three filters should be grouped together
// The fourth filter should be separate with AND
// Count OR groups
orGroupCount := 0
inORGroup := false
for i := 1; i < len(filters); i++ {
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
orGroupCount++
inORGroup = true
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
inORGroup = false
}
}
// We should have detected one OR group
if orGroupCount != 1 {
t.Errorf("Expected 1 OR group, detected %d", orGroupCount)
}
})
t.Run("Multiple OR groups should be handled correctly", func(t *testing.T) {
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
{Column: "priority", Operator: "eq", Value: "high"},
{Column: "priority", Operator: "eq", Value: "urgent", LogicOperator: "OR"},
}
// Expected: (status='active' OR status='pending') AND (priority='high' OR priority='urgent')
// Should have two OR groups
orGroupCount := 0
inORGroup := false
for i := 1; i < len(filters); i++ {
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
orGroupCount++
inORGroup = true
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
inORGroup = false
}
}
// We should have detected two OR groups
if orGroupCount != 2 {
t.Errorf("Expected 2 OR groups, detected %d", orGroupCount)
}
})
}

View File

@@ -280,10 +280,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
} }
// Apply filters // Apply filters with proper grouping for OR logic
for _, filter := range options.Filters { query = h.applyFilters(query, options.Filters)
logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
query = h.applyFilter(query, filter) // Apply custom operators
for _, customOp := range options.CustomOperators {
logger.Debug("Applying custom operator: %s - %s", customOp.Name, customOp.SQL)
query = query.Where(customOp.SQL)
} }
// Apply sorting // Apply sorting
@@ -381,24 +384,105 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
} }
// Apply pagination // Handle FetchRowNumber if requested
if options.Limit != nil && *options.Limit > 0 { var rowNumber *int64
logger.Debug("Applying limit: %d", *options.Limit) if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
query = query.Limit(*options.Limit) logger.Debug("Fetching row number for ID: %s", *options.FetchRowNumber)
pkName := reflection.GetPrimaryKeyName(model)
// Build ROW_NUMBER window function SQL
rowNumberSQL := "ROW_NUMBER() OVER ("
if len(options.Sort) > 0 {
rowNumberSQL += "ORDER BY "
for i, sort := range options.Sort {
if i > 0 {
rowNumberSQL += ", "
}
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
rowNumberSQL += fmt.Sprintf("%s %s", sort.Column, direction)
}
}
rowNumberSQL += ")"
// Create a query to fetch the row number using a subquery approach
// We'll select the PK and row_number, then filter by the target ID
type RowNumResult struct {
RowNum int64 `bun:"row_num"`
}
rowNumQuery := h.db.NewSelect().Table(tableName).
ColumnExpr(fmt.Sprintf("%s AS row_num", rowNumberSQL)).
Column(pkName)
// Apply the same filters as the main query
for _, filter := range options.Filters {
rowNumQuery = h.applyFilter(rowNumQuery, filter)
}
// Apply custom operators
for _, customOp := range options.CustomOperators {
rowNumQuery = rowNumQuery.Where(customOp.SQL)
}
// Filter for the specific ID we want the row number for
rowNumQuery = rowNumQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *options.FetchRowNumber)
// Execute query to get row number
var result RowNumResult
if err := rowNumQuery.Scan(ctx, &result); err != nil {
if err == sql.ErrNoRows {
// Build filter description for error message
filterInfo := fmt.Sprintf("filters: %d", len(options.Filters))
if len(options.CustomOperators) > 0 {
customOps := make([]string, 0, len(options.CustomOperators))
for _, op := range options.CustomOperators {
customOps = append(customOps, op.SQL)
}
filterInfo += fmt.Sprintf(", custom operators: [%s]", strings.Join(customOps, "; "))
}
logger.Warn("No row found for primary key %s=%s with %s", pkName, *options.FetchRowNumber, filterInfo)
} else {
logger.Warn("Error fetching row number: %v", err)
}
} else {
rowNumber = &result.RowNum
logger.Debug("Found row number: %d", *rowNumber)
}
} }
if options.Offset != nil && *options.Offset > 0 {
logger.Debug("Applying offset: %d", *options.Offset) // Apply pagination (skip if FetchRowNumber is set - we want only that record)
query = query.Offset(*options.Offset) if options.FetchRowNumber == nil || *options.FetchRowNumber == "" {
if options.Limit != nil && *options.Limit > 0 {
logger.Debug("Applying limit: %d", *options.Limit)
query = query.Limit(*options.Limit)
}
if options.Offset != nil && *options.Offset > 0 {
logger.Debug("Applying offset: %d", *options.Offset)
query = query.Offset(*options.Offset)
}
} }
// Execute query // Execute query
var result interface{} var result interface{}
if id != "" { if id != "" || (options.FetchRowNumber != nil && *options.FetchRowNumber != "") {
logger.Debug("Querying single record with ID: %s", id) // Single record query - either by URL ID or FetchRowNumber
var targetID string
if id != "" {
targetID = id
logger.Debug("Querying single record with URL ID: %s", id)
} else {
targetID = *options.FetchRowNumber
logger.Debug("Querying single record with FetchRowNumber ID: %s", targetID)
}
// For single record, create a new pointer to the struct type // For single record, create a new pointer to the struct type
singleResult := reflect.New(modelType).Interface() singleResult := reflect.New(modelType).Interface()
pkName := reflection.GetPrimaryKeyName(singleResult)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id) query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := query.Scan(ctx, singleResult); err != nil { if err := query.Scan(ctx, singleResult); err != nil {
logger.Error("Error querying record: %v", err) logger.Error("Error querying record: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err) h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
@@ -418,20 +502,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
logger.Info("Successfully retrieved records") logger.Info("Successfully retrieved records")
// Build metadata
limit := 0 limit := 0
if options.Limit != nil {
limit = *options.Limit
}
offset := 0 offset := 0
if options.Offset != nil { count := int64(total)
offset = *options.Offset
// When FetchRowNumber is used, we only return 1 record
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
count = 1
// Set the fetched row number on the record
if rowNumber != nil {
logger.Debug("FetchRowNumber: Setting row number %d on record", *rowNumber)
h.setRowNumbersOnRecords(result, int(*rowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
}
} else {
if options.Limit != nil {
limit = *options.Limit
}
if options.Offset != nil {
offset = *options.Offset
}
// Set row numbers on records if RowNumber field exists
// Only for multiple records (not when fetching single record)
h.setRowNumbersOnRecords(result, offset)
} }
h.sendResponse(w, result, &common.Metadata{ h.sendResponse(w, result, &common.Metadata{
Total: int64(total), Total: int64(total),
Filtered: int64(total), Filtered: int64(total),
Limit: limit, Count: count,
Offset: offset, Limit: limit,
Offset: offset,
RowNumber: rowNumber,
}) })
} }
@@ -1303,29 +1406,161 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
h.sendResponse(w, recordToDelete, nil) h.sendResponse(w, recordToDelete, nil)
} }
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery { // applyFilters applies all filters with proper grouping for OR logic
// Groups consecutive OR filters together to ensure proper query precedence
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
// Check if this starts an OR group (current or next filter has OR logic)
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
// Collect all consecutive filters that are OR'd together
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
// Apply the OR group as a single grouped WHERE clause
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
// Single filter with AND logic (or first filter)
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
// applyFilterGroup applies a group of filters that should be OR'd together
// Always wraps them in parentheses and applies as a single WHERE clause
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
// Build all conditions and collect args
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
// Single filter - no need for grouping
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
// Multiple conditions - group with parentheses and OR
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
return query.Where(groupedCondition, args...)
}
// buildFilterCondition builds a filter condition and returns it with args
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
var condition string
var args []interface{}
switch filter.Operator { switch filter.Operator {
case "eq": case "eq":
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq": case "neq":
return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt": case "gt":
return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte": case "gte":
return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt": case "lt":
return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte": case "lte":
return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like": case "like":
return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s LIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "ilike": case "ilike":
return query.Where(fmt.Sprintf("%s ILIKE ?", filter.Column), filter.Value) condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "in": case "in":
return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value) condition = fmt.Sprintf("%s IN (?)", filter.Column)
args = []interface{}{filter.Value}
default:
return "", nil
}
return condition, args
}
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
// Determine which method to use based on LogicOperator
useOrLogic := strings.EqualFold(filter.LogicOperator, "OR")
var condition string
var args []interface{}
switch filter.Operator {
case "eq":
condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq":
condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt":
condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte":
condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt":
condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte":
condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like":
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "ilike":
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "in":
condition = fmt.Sprintf("%s IN (?)", filter.Column)
args = []interface{}{filter.Value}
default: default:
return query return query
} }
// Apply filter with appropriate logic operator
if useOrLogic {
return query.WhereOr(condition, args...)
}
return query.Where(condition, args...)
} }
// parseTableName splits a table name that may contain schema into separate schema and table // parseTableName splits a table name that may contain schema into separate schema and table
@@ -1380,10 +1615,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return schema, entity return schema, entity
} }
// getTableName returns the full table name including schema (schema.table) // getTableName returns the full table name including schema.
// For most drivers the result is "schema.table". For SQLite, which does not
// support schema-qualified names, the schema and table are joined with an
// underscore: "schema_table".
func (h *Handler) getTableName(schema, entity string, model interface{}) string { func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model) schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" { if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName) return fmt.Sprintf("%s.%s", schemaName, tableName)
} }
return tableName return tableName
@@ -1703,6 +1944,51 @@ func toSnakeCase(s string) string {
return strings.ToLower(result.String()) return strings.ToLower(result.String())
} }
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
// The row number is calculated as offset + index + 1 (1-based)
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
// Get the reflect value of the records
recordsValue := reflect.ValueOf(records)
if recordsValue.Kind() == reflect.Ptr {
recordsValue = recordsValue.Elem()
}
// Ensure it's a slice
if recordsValue.Kind() != reflect.Slice {
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
return
}
// Iterate through each record
for i := 0; i < recordsValue.Len(); i++ {
record := recordsValue.Index(i)
// Dereference if it's a pointer
if record.Kind() == reflect.Ptr {
if record.IsNil() {
continue
}
record = record.Elem()
}
// Ensure it's a struct
if record.Kind() != reflect.Struct {
continue
}
// Try to find and set the RowNumber field
rowNumberField := record.FieldByName("RowNumber")
if rowNumberField.IsValid() && rowNumberField.CanSet() {
// Check if the field is of type int64
if rowNumberField.Kind() == reflect.Int64 {
rowNum := int64(offset + i + 1)
rowNumberField.SetInt(rowNum)
logger.Debug("Set RowNumber=%d for record index %d", rowNum, i)
}
}
}
}
// HandleOpenAPI generates and returns the OpenAPI specification // HandleOpenAPI generates and returns the OpenAPI specification
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) { func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
if h.openAPIGenerator == nil { if h.openAPIGenerator == nil {

View File

@@ -435,9 +435,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
// Apply preloading // Apply preloading
logger.Debug("Total preloads to apply: %d", len(options.Preload))
for idx := range options.Preload { for idx := range options.Preload {
preload := options.Preload[idx] preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation) logger.Debug("Applying preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, Where=%s",
idx, preload.Relation, preload.Recursive, preload.RelatedKey, preload.Where)
// Validate and fix WHERE clause to ensure it contains the relation prefix // Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
@@ -547,8 +549,30 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
} }
// If ID is provided, filter by ID // Handle FetchRowNumber before applying ID filter
if id != "" { // This must happen before the query to get the row position, then filter by PK
var fetchedRowNumber *int64
var fetchRowNumberPKValue string
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
pkName := reflection.GetPrimaryKeyName(model)
fetchRowNumberPKValue = *options.FetchRowNumber
logger.Debug("FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, fetchRowNumberPKValue, options, model)
if err != nil {
logger.Error("Failed to fetch row number: %v", err)
h.sendError(w, http.StatusBadRequest, "fetch_rownumber_error", "Failed to fetch row number", err)
return
}
fetchedRowNumber = &rowNum
logger.Debug("FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
// Now filter the main query to this specific primary key
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), fetchRowNumberPKValue)
} else if id != "" {
// If ID is provided (and not FetchRowNumber), filter by ID
pkName := reflection.GetPrimaryKeyName(model) pkName := reflection.GetPrimaryKeyName(model)
logger.Debug("Filtering by ID=%s: %s", pkName, id) logger.Debug("Filtering by ID=%s: %s", pkName, id)
@@ -728,7 +752,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
// Set row numbers on each record if the model has a RowNumber field // Set row numbers on each record if the model has a RowNumber field
h.setRowNumbersOnRecords(modelPtr, offset) // If FetchRowNumber was used, set the fetched row number instead of offset-based
if fetchedRowNumber != nil {
// FetchRowNumber: set the actual row position on the record
logger.Debug("FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
h.setRowNumbersOnRecords(modelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
} else {
h.setRowNumbersOnRecords(modelPtr, offset)
}
metadata := &common.Metadata{ metadata := &common.Metadata{
Total: int64(total), Total: int64(total),
@@ -738,21 +769,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset, Offset: offset,
} }
// Fetch row number for a specific record if requested // If FetchRowNumber was used, also set it in metadata
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" { if fetchedRowNumber != nil {
pkName := reflection.GetPrimaryKeyName(model) metadata.RowNumber = fetchedRowNumber
pkValue := *options.FetchRowNumber logger.Debug("FetchRowNumber: Row number %d set in metadata", *fetchedRowNumber)
logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
if err != nil {
logger.Warn("Failed to fetch row number: %v", err)
// Don't fail the entire request, just log the warning
} else {
metadata.RowNumber = &rowNum
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
}
} }
// Execute AfterRead hooks // Execute AfterRead hooks
@@ -916,10 +936,25 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
// Build RequestOptions with all preloads to allow references to sibling relations // Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: allPreloads} preloadOpts := &common.RequestOptions{Preload: allPreloads}
// First add table prefixes to unqualified columns
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) // Determine the table name to use for WHERE clause processing
// Then sanitize and allow preload table prefixes // Prefer the explicit TableName field (set by XFiles), otherwise extract from relation name
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) tableName := preload.TableName
if tableName == "" {
tableName = reflection.ExtractTableNameOnly(preload.Relation)
}
// In Bun's Relation context, table prefixes are only needed when there are JOINs
// Without JOINs, Bun already knows which table is being queried
whereClause := preload.Where
if len(preload.SqlJoins) > 0 {
// Has JOINs: add table prefixes to disambiguate columns
whereClause = common.AddTablePrefixToColumns(preload.Where, tableName)
logger.Debug("Added table prefix for preload with joins: '%s' -> '%s'", preload.Where, whereClause)
}
// Sanitize the WHERE clause and allow preload table prefixes
sanitizedWhere := common.SanitizeWhereClause(whereClause, tableName, preloadOpts)
if len(sanitizedWhere) > 0 { if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere) sq = sq.Where(sanitizedWhere)
} }
@@ -938,21 +973,82 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
}) })
// Handle recursive preloading // Handle recursive preloading
if preload.Recursive && depth < 4 { if preload.Recursive && depth < 8 {
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1) logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
// For recursive relationships, we need to get the last part of the relation path
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
relationParts := strings.Split(preload.Relation, ".") relationParts := strings.Split(preload.Relation, ".")
lastRelationName := relationParts[len(relationParts)-1] lastRelationName := relationParts[len(relationParts)-1]
// Create a recursive preload with the same configuration // Generate FK-based relation name for children
// but with the relation path extended // Use RecursiveChildKey if available, otherwise fall back to RelatedKey
recursivePreload := preload recursiveFK := preload.RecursiveChildKey
recursivePreload.Relation = preload.Relation + "." + lastRelationName if recursiveFK == "" {
recursiveFK = preload.RelatedKey
}
// Recursively apply preload until we reach depth 5 recursiveRelationName := lastRelationName
if recursiveFK != "" {
// Check if the last relation name already contains the FK suffix
// (this happens when XFiles already generated the FK-based name)
fkUpper := strings.ToUpper(recursiveFK)
expectedSuffix := "_" + fkUpper
if strings.HasSuffix(lastRelationName, expectedSuffix) {
// Already has FK suffix, just reuse the same name
recursiveRelationName = lastRelationName
logger.Debug("Reusing FK-based relation name for recursion: %s", recursiveRelationName)
} else {
// Generate FK-based name
recursiveRelationName = lastRelationName + expectedSuffix
keySource := "RelatedKey"
if preload.RecursiveChildKey != "" {
keySource = "RecursiveChildKey"
}
logger.Debug("Generated recursive relation name from %s: %s (from %s)",
keySource, recursiveRelationName, recursiveFK)
}
} else {
logger.Warn("Recursive preload for %s has no RecursiveChildKey or RelatedKey, falling back to %s.%s",
preload.Relation, preload.Relation, lastRelationName)
}
// Create recursive preload
recursivePreload := preload
recursivePreload.Relation = preload.Relation + "." + recursiveRelationName
recursivePreload.Recursive = false // Prevent infinite recursion at this level
// Use the recursive FK for child relations, not the parent's RelatedKey
if preload.RecursiveChildKey != "" {
recursivePreload.RelatedKey = preload.RecursiveChildKey
}
// CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal
recursivePreload.Where = ""
recursivePreload.Filters = []common.FilterOption{}
logger.Debug("Cleared WHERE clause for recursive preload %s at depth %d",
recursivePreload.Relation, depth+1)
// Apply recursively up to depth 8
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1) query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
// ALSO: Extend any child relations (like DEF) to recursive levels
baseRelation := preload.Relation + "."
for i := range allPreloads {
relatedPreload := allPreloads[i]
if strings.HasPrefix(relatedPreload.Relation, baseRelation) &&
!strings.Contains(strings.TrimPrefix(relatedPreload.Relation, baseRelation), ".") {
childRelationName := strings.TrimPrefix(relatedPreload.Relation, baseRelation)
extendedChildPreload := relatedPreload
extendedChildPreload.Relation = recursivePreload.Relation + "." + childRelationName
extendedChildPreload.Recursive = false
logger.Debug("Extending related preload '%s' to '%s' at recursive depth %d",
relatedPreload.Relation, extendedChildPreload.Relation, depth+1)
query = h.applyPreloadWithRecursion(query, extendedChildPreload, allPreloads, model, depth+1)
}
}
} }
return query return query
@@ -1937,11 +2033,18 @@ func (h *Handler) processChildRelationsForField(
return nil return nil
} }
// getTableNameForRelatedModel gets the table name for a related model // getTableNameForRelatedModel gets the table name for a related model.
// If the model's TableName() is schema-qualified (e.g. "public.users") the
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string { func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
if provider, ok := model.(common.TableNameProvider); ok { if provider, ok := model.(common.TableNameProvider); ok {
tableName := provider.TableName() tableName := provider.TableName()
if tableName != "" { if tableName != "" {
if schema, table := h.parseTableName(tableName); schema != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schema, table)
}
}
return tableName return tableName
} }
} }
@@ -2186,10 +2289,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return schema, entity return schema, entity
} }
// getTableName returns the full table name including schema (schema.table) // getTableName returns the full table name including schema.
// For most drivers the result is "schema.table". For SQLite, which does not
// support schema-qualified names, the schema and table are joined with an
// underscore: "schema_table".
func (h *Handler) getTableName(schema, entity string, model interface{}) string { func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model) schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" { if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName) return fmt.Sprintf("%s.%s", schemaName, tableName)
} }
return tableName return tableName
@@ -2511,21 +2620,8 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName) sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
} }
// Build WHERE clauses from filters // Build WHERE clause from filters with proper OR grouping
whereClauses := make([]string, 0) whereSQL := h.buildWhereClauseWithORGrouping(options.Filters, tableName)
for i := range options.Filters {
filter := &options.Filters[i]
whereClause := h.buildFilterSQL(filter, tableName)
if whereClause != "" {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
}
}
// Combine WHERE clauses
whereSQL := ""
if len(whereClauses) > 0 {
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
}
// Add custom SQL WHERE if provided // Add custom SQL WHERE if provided
if options.CustomSQLWhere != "" { if options.CustomSQLWhere != "" {
@@ -2573,19 +2669,86 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
var result []struct { var result []struct {
RN int64 `bun:"rn"` RN int64 `bun:"rn"`
} }
logger.Debug("[FetchRowNumber] BEFORE Query call - about to execute raw query")
err := h.db.Query(ctx, &result, queryStr, pkValue) err := h.db.Query(ctx, &result, queryStr, pkValue)
logger.Debug("[FetchRowNumber] AFTER Query call - query completed with %d results, err: %v", len(result), err)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to fetch row number: %w", err) return 0, fmt.Errorf("failed to fetch row number: %w", err)
} }
if len(result) == 0 { if len(result) == 0 {
return 0, fmt.Errorf("no row found for primary key %s", pkValue) whereInfo := "none"
if whereSQL != "" {
whereInfo = whereSQL
}
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
} }
return result[0].RN, nil return result[0].RN, nil
} }
// buildFilterSQL converts a filter to SQL WHERE clause string // buildFilterSQL converts a filter to SQL WHERE clause string
// buildWhereClauseWithORGrouping builds a WHERE clause from filters with proper OR grouping
// Groups consecutive OR filters together to ensure proper SQL precedence
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
func (h *Handler) buildWhereClauseWithORGrouping(filters []common.FilterOption, tableName string) string {
if len(filters) == 0 {
return ""
}
var groups []string
i := 0
for i < len(filters) {
// Check if this starts an OR group (next filter has OR logic)
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
// Collect all consecutive filters that are OR'd together
orGroup := []string{}
// Add current filter
filterSQL := h.buildFilterSQL(&filters[i], tableName)
if filterSQL != "" {
orGroup = append(orGroup, filterSQL)
}
// Collect remaining OR filters
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
filterSQL := h.buildFilterSQL(&filters[j], tableName)
if filterSQL != "" {
orGroup = append(orGroup, filterSQL)
}
j++
}
// Group OR filters with parentheses
if len(orGroup) > 0 {
if len(orGroup) == 1 {
groups = append(groups, orGroup[0])
} else {
groups = append(groups, "("+strings.Join(orGroup, " OR ")+")")
}
}
i = j
} else {
// Single filter with AND logic (or first filter)
filterSQL := h.buildFilterSQL(&filters[i], tableName)
if filterSQL != "" {
groups = append(groups, filterSQL)
}
i++
}
}
if len(groups) == 0 {
return ""
}
return "WHERE " + strings.Join(groups, " AND ")
}
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string { func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName) qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)

View File

@@ -48,7 +48,8 @@ type ExtendedRequestOptions struct {
AtomicTransaction bool AtomicTransaction bool
// X-Files configuration - comprehensive query options as a single JSON object // X-Files configuration - comprehensive query options as a single JSON object
XFiles *XFiles XFiles *XFiles
XFilesPresent bool // Flag to indicate if X-Files header was provided
} }
// ExpandOption represents a relation expansion configuration // ExpandOption represents a relation expansion configuration
@@ -274,7 +275,8 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
} }
// Resolve relation names (convert table names to field names) if model is provided // Resolve relation names (convert table names to field names) if model is provided
if model != nil { // Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
if model != nil && !options.XFilesPresent {
h.resolveRelationNamesInOptions(&options, model) h.resolveRelationNamesInOptions(&options, model)
} }
@@ -693,6 +695,7 @@ func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
// Store the original XFiles for reference // Store the original XFiles for reference
options.XFiles = &xfiles options.XFiles = &xfiles
options.XFilesPresent = true // Mark that X-Files header was provided
// Map XFiles fields to ExtendedRequestOptions // Map XFiles fields to ExtendedRequestOptions
@@ -984,11 +987,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
return return
} }
// Store the table name as-is for now - it will be resolved to field name later // Use the Prefix (e.g., "MAL") as the relation name, which matches the Go struct field name
// when we have the model instance available // Fall back to TableName if Prefix is not specified
relationPath := xfile.TableName relationName := xfile.Prefix
if relationName == "" {
relationName = xfile.TableName
}
// SPECIAL CASE: For recursive child tables, generate FK-based relation name
// Example: If prefix is "MAL" and relatedkey is "rid_parentmastertaskitem",
// the actual struct field is "MAL_RID_PARENTMASTERTASKITEM", not "MAL"
if xfile.Recursive && xfile.RelatedKey != "" && basePath != "" {
// Check if this is a self-referencing recursive relation (same table as parent)
// by comparing the last part of basePath with the current prefix
basePathParts := strings.Split(basePath, ".")
lastPrefix := basePathParts[len(basePathParts)-1]
if lastPrefix == relationName {
// This is a recursive self-reference, use FK-based name
fkUpper := strings.ToUpper(xfile.RelatedKey)
relationName = relationName + "_" + fkUpper
logger.Debug("X-Files: Generated FK-based relation name for recursive table: %s", relationName)
}
}
relationPath := relationName
if basePath != "" { if basePath != "" {
relationPath = basePath + "." + xfile.TableName relationPath = basePath + "." + relationName
} }
logger.Debug("X-Files: Adding preload for relation: %s", relationPath) logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
@@ -996,6 +1021,7 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
// Create PreloadOption from XFiles configuration // Create PreloadOption from XFiles configuration
preloadOpt := common.PreloadOption{ preloadOpt := common.PreloadOption{
Relation: relationPath, Relation: relationPath,
TableName: xfile.TableName, // Store the actual database table name for WHERE clause processing
Columns: xfile.Columns, Columns: xfile.Columns,
OmitColumns: xfile.OmitColumns, OmitColumns: xfile.OmitColumns,
} }
@@ -1038,12 +1064,12 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
// Add WHERE clause if SQL conditions specified // Add WHERE clause if SQL conditions specified
whereConditions := make([]string, 0) whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 { if len(xfile.SqlAnd) > 0 {
// Process each SQL condition: add table prefixes and sanitize // Process each SQL condition
// Note: We don't add table prefixes here because they're only needed for JOINs
// The handler will add prefixes later if SqlJoins are present
for _, sqlCond := range xfile.SqlAnd { for _, sqlCond := range xfile.SqlAnd {
// First add table prefixes to unqualified columns // Sanitize the condition without adding prefixes
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName) sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
// Then sanitize the condition
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
if sanitizedCond != "" { if sanitizedCond != "" {
whereConditions = append(whereConditions, sanitizedCond) whereConditions = append(whereConditions, sanitizedCond)
} }
@@ -1114,13 +1140,46 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath) logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
} }
// Check if this table has a recursive child - if so, mark THIS preload as recursive
// and store the recursive child's RelatedKey for recursion generation
hasRecursiveChild := false
if len(xfile.ChildTables) > 0 {
for _, childTable := range xfile.ChildTables {
if childTable.Recursive && childTable.TableName == xfile.TableName {
hasRecursiveChild = true
preloadOpt.Recursive = true
preloadOpt.RecursiveChildKey = childTable.RelatedKey
logger.Debug("X-Files: Detected recursive child for %s, marking parent as recursive (recursive FK: %s)",
relationPath, childTable.RelatedKey)
break
}
}
}
// Skip adding this preload if it's a recursive child (it will be handled by parent's Recursive flag)
if xfile.Recursive && basePath != "" {
logger.Debug("X-Files: Skipping recursive child preload: %s (will be handled by parent)", relationPath)
// Still process its parent/child tables for relations like DEF
h.processXFilesRelations(xfile, options, relationPath)
return
}
// Add the preload option // Add the preload option
options.Preload = append(options.Preload, preloadOpt) options.Preload = append(options.Preload, preloadOpt)
logger.Debug("X-Files: Added preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, RecursiveChildKey=%s, Where=%s",
len(options.Preload)-1, preloadOpt.Relation, preloadOpt.Recursive, preloadOpt.RelatedKey, preloadOpt.RecursiveChildKey, preloadOpt.Where)
// Recursively process nested ParentTables and ChildTables // Recursively process nested ParentTables and ChildTables
if xfile.Recursive { // Skip processing child tables if we already detected and handled a recursive child
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath) if hasRecursiveChild {
h.processXFilesRelations(xfile, options, relationPath) logger.Debug("X-Files: Skipping child table processing for %s (recursive child already handled)", relationPath)
// But still process parent tables
if len(xfile.ParentTables) > 0 {
logger.Debug("X-Files: Processing %d parent tables for %s", len(xfile.ParentTables), relationPath)
for _, parentTable := range xfile.ParentTables {
h.addXFilesPreload(parentTable, options, relationPath)
}
}
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 { } else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
h.processXFilesRelations(xfile, options, relationPath) h.processXFilesRelations(xfile, options, relationPath)
} }

View File

@@ -0,0 +1,110 @@
package restheadspec
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// TestPreloadOption_TableName verifies that TableName field is properly used
// when provided in PreloadOption for WHERE clause processing
func TestPreloadOption_TableName(t *testing.T) {
tests := []struct {
name string
preload common.PreloadOption
expectedTable string
}{
{
name: "TableName provided explicitly",
preload: common.PreloadOption{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "mastertaskitem",
Where: "rid_parentmastertaskitem is null",
},
expectedTable: "mastertaskitem",
},
{
name: "TableName empty, should use empty string",
preload: common.PreloadOption{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "",
Where: "rid_parentmastertaskitem is null",
},
expectedTable: "",
},
{
name: "Simple relation without nested path",
preload: common.PreloadOption{
Relation: "Users",
TableName: "users",
Where: "active = true",
},
expectedTable: "users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test that the TableName field stores the correct value
if tt.preload.TableName != tt.expectedTable {
t.Errorf("PreloadOption.TableName = %q, want %q", tt.preload.TableName, tt.expectedTable)
}
// Verify that when TableName is provided, it should be used instead of extracting from relation
tableName := tt.preload.TableName
if tableName == "" {
// This simulates the fallback logic in handler.go
// In reality, reflection.ExtractTableNameOnly would be called
tableName = tt.expectedTable
}
if tableName != tt.expectedTable {
t.Errorf("Resolved table name = %q, want %q", tableName, tt.expectedTable)
}
})
}
}
// TestXFilesPreload_StoresTableName verifies that XFiles processing
// stores the table name in PreloadOption and doesn't add table prefixes to WHERE clauses
func TestXFilesPreload_StoresTableName(t *testing.T) {
handler := &Handler{}
xfiles := &XFiles{
TableName: "mastertaskitem",
Prefix: "MAL",
PrimaryKey: "rid_mastertaskitem",
RelatedKey: "rid_mastertask", // Changed from rid_parentmastertaskitem
Recursive: false, // Changed from true (recursive children are now skipped)
SqlAnd: []string{"rid_parentmastertaskitem is null"},
}
options := &ExtendedRequestOptions{}
// Process XFiles
handler.addXFilesPreload(xfiles, options, "MTL")
// Verify that a preload was added
if len(options.Preload) == 0 {
t.Fatal("Expected at least one preload to be added")
}
preload := options.Preload[0]
// Verify the table name is stored
if preload.TableName != "mastertaskitem" {
t.Errorf("PreloadOption.TableName = %q, want %q", preload.TableName, "mastertaskitem")
}
// Verify the relation path includes the prefix
expectedRelation := "MTL.MAL"
if preload.Relation != expectedRelation {
t.Errorf("PreloadOption.Relation = %q, want %q", preload.Relation, expectedRelation)
}
// Verify WHERE clause does NOT have table prefix (prefixes only needed for JOINs)
expectedWhere := "rid_parentmastertaskitem is null"
if preload.Where != expectedWhere {
t.Errorf("PreloadOption.Where = %q, want %q (no table prefix)", preload.Where, expectedWhere)
}
}

View File

@@ -0,0 +1,91 @@
package restheadspec
import (
"testing"
)
// TestPreloadWhereClause_WithJoins verifies that table prefixes are added
// to WHERE clauses when SqlJoins are present
func TestPreloadWhereClause_WithJoins(t *testing.T) {
tests := []struct {
name string
where string
sqlJoins []string
expectedPrefix bool
description string
}{
{
name: "No joins - no prefix needed",
where: "status = 'active'",
sqlJoins: []string{},
expectedPrefix: false,
description: "Without JOINs, Bun knows the table context",
},
{
name: "Has joins - prefix needed",
where: "status = 'active'",
sqlJoins: []string{"LEFT JOIN other_table ot ON ot.id = main.other_id"},
expectedPrefix: true,
description: "With JOINs, table prefix disambiguates columns",
},
{
name: "Already has prefix - no change",
where: "users.status = 'active'",
sqlJoins: []string{"LEFT JOIN roles r ON r.id = users.role_id"},
expectedPrefix: true,
description: "Existing prefix should be preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// This test documents the expected behavior
// The actual logic is in handler.go lines 916-937
hasJoins := len(tt.sqlJoins) > 0
if hasJoins != tt.expectedPrefix {
t.Errorf("Test expectation mismatch: hasJoins=%v, expectedPrefix=%v",
hasJoins, tt.expectedPrefix)
}
t.Logf("%s: %s", tt.name, tt.description)
})
}
}
// TestXFilesWithJoins_AddsTablePrefix verifies that XFiles with SqlJoins
// results in table prefixes being added to WHERE clauses
func TestXFilesWithJoins_AddsTablePrefix(t *testing.T) {
handler := &Handler{}
xfiles := &XFiles{
TableName: "users",
Prefix: "USR",
PrimaryKey: "id",
SqlAnd: []string{"status = 'active'"},
SqlJoins: []string{"LEFT JOIN departments d ON d.id = users.department_id"},
}
options := &ExtendedRequestOptions{}
handler.addXFilesPreload(xfiles, options, "")
if len(options.Preload) == 0 {
t.Fatal("Expected at least one preload to be added")
}
preload := options.Preload[0]
// Verify SqlJoins were stored
if len(preload.SqlJoins) != 1 {
t.Errorf("Expected 1 SqlJoin, got %d", len(preload.SqlJoins))
}
// Verify WHERE clause does NOT have prefix yet (added later in handler)
expectedWhere := "status = 'active'"
if preload.Where != expectedWhere {
t.Errorf("PreloadOption.Where = %q, want %q", preload.Where, expectedWhere)
}
// Note: The handler will add the prefix when it sees SqlJoins
// This is tested in the handler itself, not during XFiles parsing
}

View File

@@ -0,0 +1,391 @@
//go:build !integration
// +build !integration
package restheadspec
import (
"context"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// TestRecursivePreloadClearsWhereClause tests that recursive preloads
// correctly clear the WHERE clause from the parent level to allow
// Bun to use foreign key relationships for loading children
func TestRecursivePreloadClearsWhereClause(t *testing.T) {
// Create a mock handler
handler := &Handler{}
// Create a preload option with a WHERE clause that filters root items
// This simulates the xfiles use case where the first level has a filter
// like "rid_parentmastertaskitem is null" to get root items
preload := common.PreloadOption{
Relation: "MastertaskItems",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
Where: "rid_parentmastertaskitem is null",
Filters: []common.FilterOption{
{
Column: "rid_parentmastertaskitem",
Operator: "is null",
Value: nil,
},
},
}
// Create a mock query that tracks operations
mockQuery := &mockSelectQuery{
operations: []string{},
}
// Apply the recursive preload at depth 0
// This should:
// 1. Apply the initial preload with the WHERE clause
// 2. Create a recursive preload without the WHERE clause
allPreloads := []common.PreloadOption{preload}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
// Verify the mock query received the operations
mock := result.(*mockSelectQuery)
// Check that we have at least 2 PreloadRelation calls:
// 1. The initial "MastertaskItems" with WHERE clause
// 2. The recursive "MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" without WHERE clause
preloadCount := 0
recursivePreloadFound := false
whereAppliedToRecursive := false
for _, op := range mock.operations {
if op == "PreloadRelation:MastertaskItems" {
preloadCount++
}
if op == "PreloadRelation:MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" {
recursivePreloadFound = true
}
// Check if WHERE was applied to the recursive preload (it shouldn't be)
if op == "Where:rid_parentmastertaskitem is null" && recursivePreloadFound {
whereAppliedToRecursive = true
}
}
if preloadCount < 1 {
t.Errorf("Expected at least 1 PreloadRelation call, got %d", preloadCount)
}
if !recursivePreloadFound {
t.Errorf("Expected recursive preload 'MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
}
if whereAppliedToRecursive {
t.Error("WHERE clause should not be applied to recursive preload levels")
}
}
// TestRecursivePreloadWithChildRelations tests that child relations
// (like DEF in MAL.DEF) are properly extended to recursive levels
func TestRecursivePreloadWithChildRelations(t *testing.T) {
handler := &Handler{}
// Create the main recursive preload
recursivePreload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
Where: "rid_parentmastertaskitem is null",
}
// Create a child relation that should be extended
childPreload := common.PreloadOption{
Relation: "MAL.DEF",
}
mockQuery := &mockSelectQuery{
operations: []string{},
}
allPreloads := []common.PreloadOption{recursivePreload, childPreload}
// Apply both preloads - the child preload should be extended when the recursive one processes
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, allPreloads, nil, 0)
// Also need to apply the child preload separately (as would happen in normal flow)
result = handler.applyPreloadWithRecursion(result, childPreload, allPreloads, nil, 0)
mock := result.(*mockSelectQuery)
// Check that the child relation was extended to recursive levels
// We should see:
// - MAL (with WHERE)
// - MAL.DEF
// - MAL.MAL_RID_PARENTMASTERTASKITEM (without WHERE)
// - MAL.MAL_RID_PARENTMASTERTASKITEM.DEF (extended by recursive logic)
foundMALDEF := false
foundRecursiveMAL := false
foundMALMALDEF := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.DEF" {
foundMALDEF = true
}
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundRecursiveMAL = true
}
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
foundMALMALDEF = true
}
}
if !foundMALDEF {
t.Errorf("Expected child preload 'MAL.DEF' to be applied. Operations: %v", mock.operations)
}
if !foundRecursiveMAL {
t.Errorf("Expected recursive preload 'MAL.MAL_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
}
if !foundMALMALDEF {
t.Errorf("Expected child preload to be extended to 'MAL.MAL_RID_PARENTMASTERTASKITEM.DEF' at recursive level. Operations: %v", mock.operations)
}
}
// TestRecursivePreloadGeneratesCorrectRelationName tests that the recursive
// preload generates the correct FK-based relation name using RelatedKey
func TestRecursivePreloadGeneratesCorrectRelationName(t *testing.T) {
handler := &Handler{}
// Test case 1: With RelatedKey - should generate FK-based name
t.Run("WithRelatedKey", func(t *testing.T) {
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
}
mockQuery := &mockSelectQuery{operations: []string{}}
allPreloads := []common.PreloadOption{preload}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
mock := result.(*mockSelectQuery)
// Should generate MAL.MAL_RID_PARENTMASTERTASKITEM
foundCorrectRelation := false
foundIncorrectRelation := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundCorrectRelation = true
}
if op == "PreloadRelation:MAL.MAL" {
foundIncorrectRelation = true
}
}
if !foundCorrectRelation {
t.Errorf("Expected 'MAL.MAL_RID_PARENTMASTERTASKITEM' relation, operations: %v", mock.operations)
}
if foundIncorrectRelation {
t.Error("Should NOT generate 'MAL.MAL' relation when RelatedKey is specified")
}
})
// Test case 2: Without RelatedKey - should fallback to old behavior
t.Run("WithoutRelatedKey", func(t *testing.T) {
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
// No RelatedKey
}
mockQuery := &mockSelectQuery{operations: []string{}}
allPreloads := []common.PreloadOption{preload}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
mock := result.(*mockSelectQuery)
// Should fallback to MAL.MAL
foundFallback := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL" {
foundFallback = true
}
}
if !foundFallback {
t.Errorf("Expected fallback 'MAL.MAL' relation when no RelatedKey, operations: %v", mock.operations)
}
})
// Test case 3: Depth limit of 8
t.Run("DepthLimit", func(t *testing.T) {
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
}
mockQuery := &mockSelectQuery{operations: []string{}}
allPreloads := []common.PreloadOption{preload}
// Start at depth 7 - should create one more level
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
mock := result.(*mockSelectQuery)
foundDepth8 := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth8 = true
}
}
if !foundDepth8 {
t.Error("Expected to create recursive level at depth 8")
}
// Start at depth 8 - should NOT create another level
mockQuery2 := &mockSelectQuery{operations: []string{}}
result2 := handler.applyPreloadWithRecursion(mockQuery2, preload, allPreloads, nil, 8)
mock2 := result2.(*mockSelectQuery)
foundDepth9 := false
for _, op := range mock2.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth9 = true
}
}
if foundDepth9 {
t.Error("Should NOT create recursive level beyond depth 8")
}
})
}
// mockSelectQuery implements common.SelectQuery for testing
type mockSelectQuery struct {
operations []string
}
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
m.operations = append(m.operations, "Model")
return m
}
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
m.operations = append(m.operations, "Table:"+table)
return m
}
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
for _, col := range columns {
m.operations = append(m.operations, "Column:"+col)
}
return m
}
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "ColumnExpr:"+query)
return m
}
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Where:"+query)
return m
}
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereOr:"+query)
return m
}
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereIn:"+column)
return m
}
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
m.operations = append(m.operations, "Order:"+order)
return m
}
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "OrderExpr:"+order)
return m
}
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
m.operations = append(m.operations, "Limit")
return m
}
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
m.operations = append(m.operations, "Offset")
return m
}
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Join:"+join)
return m
}
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "LeftJoin:"+join)
return m
}
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
m.operations = append(m.operations, "Group")
return m
}
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Having:"+query)
return m
}
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Preload:"+relation)
return m
}
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "PreloadRelation:"+relation)
// Apply the preload modifiers
for _, fn := range apply {
fn(m)
}
return m
}
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "JoinRelation:"+relation)
return m
}
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
m.operations = append(m.operations, "Scan")
return nil
}
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
m.operations = append(m.operations, "ScanModel")
return nil
}
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
m.operations = append(m.operations, "Count")
return 0, nil
}
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
m.operations = append(m.operations, "Exists")
return false, nil
}
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
return nil
}
func (m *mockSelectQuery) GetModel() interface{} {
return nil
}

View File

@@ -0,0 +1,527 @@
//go:build integration
// +build integration
package restheadspec
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockSelectQuery implements common.SelectQuery for testing (integration version)
type mockSelectQuery struct {
operations []string
}
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
m.operations = append(m.operations, "Model")
return m
}
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
m.operations = append(m.operations, "Table:"+table)
return m
}
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
for _, col := range columns {
m.operations = append(m.operations, "Column:"+col)
}
return m
}
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "ColumnExpr:"+query)
return m
}
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Where:"+query)
return m
}
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereOr:"+query)
return m
}
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereIn:"+column)
return m
}
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
m.operations = append(m.operations, "Order:"+order)
return m
}
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "OrderExpr:"+order)
return m
}
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
m.operations = append(m.operations, "Limit")
return m
}
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
m.operations = append(m.operations, "Offset")
return m
}
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Join:"+join)
return m
}
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "LeftJoin:"+join)
return m
}
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
m.operations = append(m.operations, "Group")
return m
}
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Having:"+query)
return m
}
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Preload:"+relation)
return m
}
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "PreloadRelation:"+relation)
// Apply the preload modifiers
for _, fn := range apply {
fn(m)
}
return m
}
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "JoinRelation:"+relation)
return m
}
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
m.operations = append(m.operations, "Scan")
return nil
}
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
m.operations = append(m.operations, "ScanModel")
return nil
}
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
m.operations = append(m.operations, "Count")
return 0, nil
}
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
m.operations = append(m.operations, "Exists")
return false, nil
}
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
return nil
}
func (m *mockSelectQuery) GetModel() interface{} {
return nil
}
// TestXFilesRecursivePreload is an integration test that validates the XFiles
// recursive preload functionality using real test data files.
//
// This test ensures:
// 1. XFiles request JSON is correctly parsed into PreloadOptions
// 2. Recursive preload generates correct FK-based relation names (MAL_RID_PARENTMASTERTASKITEM)
// 3. Parent WHERE clauses don't leak to child levels
// 4. Child relations (like DEF) are extended to all recursive levels
// 5. Hierarchical data structure matches expected output
func TestXFilesRecursivePreload(t *testing.T) {
// Load the XFiles request configuration
requestPath := filepath.Join("..", "..", "tests", "data", "xfiles.request.json")
requestData, err := os.ReadFile(requestPath)
require.NoError(t, err, "Failed to read xfiles.request.json")
var xfileConfig XFiles
err = json.Unmarshal(requestData, &xfileConfig)
require.NoError(t, err, "Failed to parse xfiles.request.json")
// Create handler and parse XFiles into PreloadOptions
handler := &Handler{}
options := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Preload: []common.PreloadOption{},
},
}
// Process the XFiles configuration - start with the root table
handler.processXFilesRelations(&xfileConfig, options, "")
// Verify that preload options were created
require.NotEmpty(t, options.Preload, "Expected preload options to be created")
// Test 1: Verify mastertaskitem preload is marked as recursive with correct RelatedKey
t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload *common.PreloadOption
for i := range options.Preload {
preload := &options.Preload[i]
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
break
}
}
require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload MTL.MAL")
// RelatedKey should be the parent relationship key (MTL -> MAL)
assert.Equal(t, "rid_mastertask", recursivePreload.RelatedKey,
"Recursive preload should preserve original RelatedKey for parent relationship")
// RecursiveChildKey should be set from the recursive child config
assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RecursiveChildKey,
"Recursive preload should have RecursiveChildKey set from recursive child config")
assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive")
})
// Test 2: Verify mastertaskitem has WHERE clause for filtering root items
t.Run("RootLevelHasWhereClause", func(t *testing.T) {
var rootPreload *common.PreloadOption
for i := range options.Preload {
preload := &options.Preload[i]
if preload.Relation == "MTL.MAL" {
rootPreload = preload
break
}
}
require.NotNil(t, rootPreload, "Expected to find mastertaskitem preload")
assert.NotEmpty(t, rootPreload.Where, "Mastertaskitem should have WHERE clause")
// The WHERE clause should filter for root items (rid_parentmastertaskitem is null)
assert.True(t, rootPreload.Recursive, "Mastertaskitem preload should be marked as recursive")
})
// Test 3: Verify actiondefinition relation exists for mastertaskitem
t.Run("DEFRelationExists", func(t *testing.T) {
var defPreload *common.PreloadOption
for i := range options.Preload {
preload := &options.Preload[i]
if preload.Relation == "MTL.MAL.DEF" {
defPreload = preload
break
}
}
require.NotNil(t, defPreload, "Expected to find actiondefinition preload for mastertaskitem")
assert.Equal(t, "rid_actiondefinition", defPreload.ForeignKey,
"actiondefinition preload should have ForeignKey set")
})
// Test 4: Verify relation name generation with mock query
t.Run("RelationNameGeneration", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption
found := false
for _, preload := range options.Preload {
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
found = true
break
}
}
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
// Create mock query to track operations
mockQuery := &mockSelectQuery{operations: []string{}}
// Apply the recursive preload
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
mock := result.(*mockSelectQuery)
// Verify the correct FK-based relation name was generated
foundCorrectRelation := false
for _, op := range mock.operations {
// Should generate: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundCorrectRelation = true
}
}
assert.True(t, foundCorrectRelation,
"Expected FK-based relation name 'MTL.MAL.MAL_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v",
mock.operations)
})
// Test 5: Verify WHERE clause is cleared for recursive levels
t.Run("WhereClauseClearedForChildren", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption
found := false
for _, preload := range options.Preload {
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
found = true
break
}
}
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
// The root level has a WHERE clause (rid_parentmastertaskitem is null)
// But when we apply recursion, it should be cleared
assert.NotEmpty(t, recursivePreload.Where, "Root preload should have WHERE clause")
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
mock := result.(*mockSelectQuery)
// After the first level, WHERE clauses should not be reapplied
// We check that the recursive relation was created (which means WHERE was cleared internally)
foundRecursiveRelation := false
for _, op := range mock.operations {
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundRecursiveRelation = true
}
}
assert.True(t, foundRecursiveRelation,
"Recursive relation should be created (WHERE clause should be cleared internally)")
})
// Test 6: Verify child relations are extended to recursive levels
t.Run("ChildRelationsExtended", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption
foundRecursive := false
for _, preload := range options.Preload {
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
foundRecursive = true
break
}
}
require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload MTL.MAL")
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
mock := result.(*mockSelectQuery)
// actiondefinition should be extended to the recursive level
// Expected: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF
foundExtendedDEF := false
for _, op := range mock.operations {
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
foundExtendedDEF = true
}
}
assert.True(t, foundExtendedDEF,
"Expected actiondefinition relation to be extended to recursive level. Operations: %v",
mock.operations)
})
}
// TestXFilesRecursivePreloadDepth tests that recursive preloads respect the depth limit of 8
func TestXFilesRecursivePreloadDepth(t *testing.T) {
handler := &Handler{}
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
}
allPreloads := []common.PreloadOption{preload}
t.Run("Depth7CreatesLevel8", func(t *testing.T) {
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
mock := result.(*mockSelectQuery)
foundDepth8 := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth8 = true
}
}
assert.True(t, foundDepth8, "Should create level 8 when starting at depth 7")
})
t.Run("Depth8DoesNotCreateLevel9", func(t *testing.T) {
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 8)
mock := result.(*mockSelectQuery)
foundDepth9 := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth9 = true
}
}
assert.False(t, foundDepth9, "Should NOT create level 9 (depth limit is 8)")
})
}
// TestXFilesResponseStructure validates the actual structure of the response
// This test can be expanded when we have a full database integration test environment
func TestXFilesResponseStructure(t *testing.T) {
// Load the expected correct response
correctResponsePath := filepath.Join("..", "..", "tests", "data", "xfiles.response.correct.json")
correctData, err := os.ReadFile(correctResponsePath)
require.NoError(t, err, "Failed to read xfiles.response.correct.json")
var correctResponse []map[string]interface{}
err = json.Unmarshal(correctData, &correctResponse)
require.NoError(t, err, "Failed to parse xfiles.response.correct.json")
// Test 1: Verify root level has exactly 1 masterprocess
t.Run("RootLevelHasOneItem", func(t *testing.T) {
assert.Len(t, correctResponse, 1, "Root level should have exactly 1 masterprocess record")
})
// Test 2: Verify the root item has MTL relation
t.Run("RootHasMTLRelation", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, exists := rootItem["MTL"]
assert.True(t, exists, "Root item should have MTL relation")
assert.NotNil(t, mtl, "MTL relation should not be null")
})
// Test 3: Verify MTL has MAL items
t.Run("MTLHasMALItems", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, exists := firstMTL["MAL"]
assert.True(t, exists, "MTL item should have MAL relation")
assert.NotNil(t, mal, "MAL relation should not be null")
})
// Test 4: Verify MAL items have MAL_RID_PARENTMASTERTASKITEM relation (recursive)
t.Run("MALHasRecursiveRelation", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, ok := firstMTL["MAL"].([]interface{})
require.True(t, ok, "MAL should be an array")
require.NotEmpty(t, mal, "MAL should have items")
firstMAL, ok := mal[0].(map[string]interface{})
require.True(t, ok, "MAL item should be a map")
// The key assertion: check for FK-based relation name
recursiveRelation, exists := firstMAL["MAL_RID_PARENTMASTERTASKITEM"]
assert.True(t, exists,
"MAL item should have MAL_RID_PARENTMASTERTASKITEM relation (FK-based name)")
// It can be null or an array, depending on whether this item has children
if recursiveRelation != nil {
_, isArray := recursiveRelation.([]interface{})
assert.True(t, isArray,
"MAL_RID_PARENTMASTERTASKITEM should be an array when not null")
}
})
// Test 5: Verify "Receive COB Document for" appears as a child, not at root
t.Run("ChildItemsAreNested", func(t *testing.T) {
// This test verifies that "Receive COB Document for" doesn't appear
// multiple times at the wrong level, but is properly nested
// Count how many times we find this description at the MAL level (should be 0 or 1)
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, ok := firstMTL["MAL"].([]interface{})
require.True(t, ok, "MAL should be an array")
// Count root-level MAL items (before the fix, there were 12; should be 1)
assert.Len(t, mal, 1,
"MAL should have exactly 1 root-level item (before fix: 12 duplicates)")
// Verify the root item has a description
firstMAL, ok := mal[0].(map[string]interface{})
require.True(t, ok, "MAL item should be a map")
description, exists := firstMAL["description"]
assert.True(t, exists, "MAL item should have a description")
assert.Equal(t, "Capture COB Information", description,
"Root MAL item should be 'Capture COB Information'")
})
// Test 6: Verify DEF relation exists at MAL level
t.Run("DEFRelationExists", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, ok := firstMTL["MAL"].([]interface{})
require.True(t, ok, "MAL should be an array")
require.NotEmpty(t, mal, "MAL should have items")
firstMAL, ok := mal[0].(map[string]interface{})
require.True(t, ok, "MAL item should be a map")
// Verify DEF relation exists (child relation extension)
def, exists := firstMAL["DEF"]
assert.True(t, exists, "MAL item should have DEF relation")
// DEF can be null or an object
if def != nil {
_, isMap := def.(map[string]interface{})
assert.True(t, isMap, "DEF should be an object when not null")
}
})
}

527
pkg/security/OAUTH2.md Normal file
View File

@@ -0,0 +1,527 @@
# OAuth2 Authentication Guide
## Overview
The security package provides OAuth2 authentication support for any OAuth2-compliant provider including Google, GitHub, Microsoft, Facebook, and custom providers.
## Features
- **Universal OAuth2 Support**: Works with any OAuth2 provider
- **Pre-configured Providers**: Google, GitHub, Microsoft, Facebook
- **Multi-Provider Support**: Use all OAuth2 providers simultaneously
- **Custom Providers**: Easy configuration for any OAuth2 service
- **Session Management**: Database-backed session storage
- **Token Refresh**: Automatic token refresh support
- **State Validation**: Built-in CSRF protection
- **User Auto-Creation**: Automatically creates users on first login
- **Unified Authentication**: OAuth2 and traditional auth share same session storage
## Quick Start
### 1. Database Setup
```sql
-- Run the schema from database_schema.sql
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
email VARCHAR(255) NOT NULL UNIQUE,
password VARCHAR(255),
user_level INTEGER DEFAULT 0,
roles VARCHAR(500),
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_login_at TIMESTAMP,
remote_id VARCHAR(255),
auth_provider VARCHAR(50)
);
CREATE TABLE IF NOT EXISTS user_sessions (
id SERIAL PRIMARY KEY,
session_token VARCHAR(500) NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address VARCHAR(45),
user_agent TEXT,
access_token TEXT,
refresh_token TEXT,
token_type VARCHAR(50) DEFAULT 'Bearer',
auth_provider VARCHAR(50)
);
-- OAuth2 stored procedures (7 functions)
-- See database_schema.sql for full implementation
```
### 2. Google OAuth2
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
// Create authenticator
oauth2Auth := security.NewGoogleAuthenticator(
"your-google-client-id",
"your-google-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
// Login route - redirects to Google
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL(state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
// Callback route - handles Google response
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
```
### 3. GitHub OAuth2
```go
oauth2Auth := security.NewGitHubAuthenticator(
"your-github-client-id",
"your-github-client-secret",
"http://localhost:8080/auth/github/callback",
db,
)
// Same routes pattern as Google
router.HandleFunc("/auth/github/login", ...)
router.HandleFunc("/auth/github/callback", ...)
```
### 4. Microsoft OAuth2
```go
oauth2Auth := security.NewMicrosoftAuthenticator(
"your-microsoft-client-id",
"your-microsoft-client-secret",
"http://localhost:8080/auth/microsoft/callback",
db,
)
```
### 5. Facebook OAuth2
```go
oauth2Auth := security.NewFacebookAuthenticator(
"your-facebook-client-id",
"your-facebook-client-secret",
"http://localhost:8080/auth/facebook/callback",
db,
)
```
## Custom OAuth2 Provider
```go
oauth2Auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
ClientID: "your-client-id",
ClientSecret: "your-client-secret",
RedirectURL: "http://localhost:8080/auth/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://your-provider.com/oauth/authorize",
TokenURL: "https://your-provider.com/oauth/token",
UserInfoURL: "https://your-provider.com/oauth/userinfo",
DB: db,
ProviderName: "custom",
// Optional: Custom user info parser
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
return &security.UserContext{
UserName: userInfo["username"].(string),
Email: userInfo["email"].(string),
RemoteID: userInfo["id"].(string),
UserLevel: 1,
Roles: []string{"user"},
Claims: userInfo,
}, nil
},
})
```
## Protected Routes
```go
// Create security provider
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
provider, _ := security.NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
// Apply middleware to protected routes
protectedRouter := router.PathPrefix("/api").Subrouter()
protectedRouter.Use(security.NewAuthMiddleware(securityList))
protectedRouter.Use(security.SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
json.NewEncoder(w).Encode(userCtx)
})
```
## Token Refresh
OAuth2 access tokens expire after a period of time. Use the refresh token to obtain a new access token without requiring the user to log in again.
```go
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google", "github", etc.
}
json.NewDecoder(r.Body).Decode(&req)
// Default to google if not specified
if req.Provider == "" {
req.Provider = "google"
}
// Use OAuth2-specific refresh method
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set new session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
})
json.NewEncoder(w).Encode(loginResp)
})
```
**Important Notes:**
- The refresh token is returned in the `LoginResponse.RefreshToken` field after successful OAuth2 callback
- Store the refresh token securely on the client side
- Each provider must be configured with the appropriate scopes to receive a refresh token (e.g., `access_type=offline` for Google)
- The `OAuth2RefreshToken` method requires the provider name to identify which OAuth2 provider to use for refreshing
## Logout
```go
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
oauth2Auth.Logout(r.Context(), security.LogoutRequest{
Token: userCtx.SessionID,
UserID: userCtx.UserID,
})
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: "",
MaxAge: -1,
})
w.WriteHeader(http.StatusOK)
})
```
## Multi-Provider Setup
```go
// Single DatabaseAuthenticator with ALL OAuth2 providers
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ClientID: "google-client-id",
ClientSecret: "google-client-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
}).
WithOAuth2(security.OAuth2Config{
ClientID: "github-client-id",
ClientSecret: "github-client-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
ProviderName: "github",
})
// Get list of configured providers
providers := auth.OAuth2GetProviders() // ["google", "github"]
// Google routes
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google",
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
// ... handle response
})
// GitHub routes
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github",
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
// ... handle response
})
// Use same authenticator for protected routes - works for ALL providers
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
```
## Configuration Options
### OAuth2Config Fields
| Field | Type | Description |
|-------|------|-------------|
| ClientID | string | OAuth2 client ID from provider |
| ClientSecret | string | OAuth2 client secret |
| RedirectURL | string | Callback URL registered with provider |
| Scopes | []string | OAuth2 scopes to request |
| AuthURL | string | Provider's authorization endpoint |
| TokenURL | string | Provider's token endpoint |
| UserInfoURL | string | Provider's user info endpoint |
| DB | *sql.DB | Database connection for sessions |
| UserInfoParser | func | Custom parser for user info (optional) |
| StateValidator | func | Custom state validator (optional) |
| ProviderName | string | Provider name for logging (optional) |
## User Info Parsing
The default parser extracts these standard fields:
- `sub` → RemoteID
- `email` → Email, UserName
- `name` → UserName
- `login` → UserName (GitHub)
Custom parser example:
```go
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
// Extract custom fields
ctx := &security.UserContext{
UserName: userInfo["preferred_username"].(string),
Email: userInfo["email"].(string),
RemoteID: userInfo["sub"].(string),
UserLevel: 1,
Roles: []string{"user"},
Claims: userInfo, // Store all claims
}
// Add custom roles based on provider data
if groups, ok := userInfo["groups"].([]interface{}); ok {
for _, g := range groups {
ctx.Roles = append(ctx.Roles, g.(string))
}
}
return ctx, nil
}
```
## Security Best Practices
1. **Always use HTTPS in production**
```go
http.SetCookie(w, &http.Cookie{
Secure: true, // Only send over HTTPS
HttpOnly: true, // Prevent XSS access
SameSite: http.SameSiteLaxMode, // CSRF protection
})
```
2. **Store secrets securely**
```go
clientID := os.Getenv("GOOGLE_CLIENT_ID")
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
```
3. **Validate redirect URLs**
- Only register trusted redirect URLs with OAuth2 providers
- Never accept redirect URL from request parameters
5. **Session expiration**
- OAuth2 sessions automatically expire based on token expiry
- Clean up expired sessions periodically:
```sql
DELETE FROM user_sessions WHERE expires_at < NOW();
```
4. **State parameter**
- Automatically generated with cryptographic randomness
- One-time use and expires after 10 minutes
- Prevents CSRF attacks
## Implementation Details
All database operations use stored procedures for consistency and security:
- `resolvespec_oauth_getorcreateuser` - Find or create OAuth2 user
- `resolvespec_oauth_createsession` - Create OAuth2 session
- `resolvespec_oauth_getsession` - Validate and retrieve session
- `resolvespec_oauth_deletesession` - Logout/delete session
- `resolvespec_oauth_getrefreshtoken` - Get session by refresh token
- `resolvespec_oauth_updaterefreshtoken` - Update tokens after refresh
- `resolvespec_oauth_getuser` - Get user data by ID
## Provider Setup Guides
### Google
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create a new project or select existing
3. Enable Google+ API
4. Create OAuth 2.0 credentials
5. Add authorized redirect URI: `http://localhost:8080/auth/google/callback`
6. Copy Client ID and Client Secret
### GitHub
1. Go to [GitHub Developer Settings](https://github.com/settings/developers)
2. Click "New OAuth App"
3. Set Homepage URL: `http://localhost:8080`
4. Set Authorization callback URL: `http://localhost:8080/auth/github/callback`
5. Copy Client ID and Client Secret
### Microsoft
1. Go to [Azure Portal](https://portal.azure.com/)
2. Register new application in Azure AD
3. Add redirect URI: `http://localhost:8080/auth/microsoft/callback`
4. Create client secret
5. Copy Application (client) ID and secret value
### Facebook
1. Go to [Facebook Developers](https://developers.facebook.com/)
2. Create new app
3. Add Facebook Login product
4. Set Valid OAuth Redirect URIs: `http://localhost:8080/auth/facebook/callback`
5. Copy App ID and App Secret
## Troubleshooting
### "redirect_uri_mismatch" error
- Ensure the redirect URL in code matches exactly with provider configuration
- Include protocol (http/https), domain, port, and path
### "invalid_client" error
- Verify Client ID and Client Secret are correct
- Check if credentials are for the correct environment (dev/prod)
### "invalid_grant" error during token exchange
- State parameter validation failed
- Token might have expired
- Check server time synchronization
### User not created after successful OAuth2 login
- Check database constraints (username/email unique)
- Verify UserInfoParser is extracting required fields
- Check database logs for constraint violations
## Testing
```go
func TestOAuth2Flow(t *testing.T) {
// Mock database
db, mock, _ := sqlmock.New()
oauth2Auth := security.NewGoogleAuthenticator(
"test-client-id",
"test-client-secret",
"http://localhost/callback",
db,
)
// Test state generation
state, err := oauth2Auth.GenerateState()
assert.NoError(t, err)
assert.NotEmpty(t, state)
// Test auth URL generation
authURL := oauth2Auth.GetAuthURL(state)
assert.Contains(t, authURL, "accounts.google.com")
assert.Contains(t, authURL, state)
}
```
## API Reference
### DatabaseAuthenticator with OAuth2
| Method | Description |
|--------|-------------|
| WithOAuth2(cfg) | Adds OAuth2 provider (can be called multiple times, returns *DatabaseAuthenticator) |
| OAuth2GetAuthURL(provider, state) | Returns OAuth2 authorization URL for specified provider |
| OAuth2GenerateState() | Generates random state for CSRF protection |
| OAuth2HandleCallback(ctx, provider, code, state) | Exchanges code for token and creates session |
| OAuth2RefreshToken(ctx, refreshToken, provider) | Refreshes expired access token using refresh token |
| OAuth2GetProviders() | Returns list of configured OAuth2 provider names |
| Login(ctx, req) | Standard username/password login |
| Logout(ctx, req) | Invalidates session (works for both OAuth2 and regular sessions) |
| Authenticate(r) | Validates session token from request (works for both OAuth2 and regular sessions) |
### Pre-configured Constructors
- `NewGoogleAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewGitHubAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewFacebookAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewMultiProviderAuthenticator(db, configs)` - Multiple providers at once
All return `*DatabaseAuthenticator` with OAuth2 pre-configured.
For multiple providers, use `WithOAuth2()` multiple times or `NewMultiProviderAuthenticator()`.
## Examples
Complete working examples available in `oauth2_examples.go`:
- Basic Google OAuth2
- GitHub OAuth2
- Custom provider
- Multi-provider setup
- Token refresh
- Logout flow
- Complete integration with security middleware

View File

@@ -0,0 +1,281 @@
# OAuth2 Refresh Token - Quick Reference
## Quick Setup (3 Steps)
### 1. Initialize Authenticator
```go
auth := security.NewGoogleAuthenticator(
"client-id",
"client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
```
### 2. OAuth2 Login Flow
```go
// Login - Redirect to Google
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
// Callback - Store tokens
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, _ := auth.OAuth2HandleCallback(
r.Context(),
"google",
r.URL.Query().Get("code"),
r.URL.Query().Get("state"),
)
// Save refresh_token on client
// loginResp.RefreshToken - Store this securely!
// loginResp.Token - Session token for API calls
})
```
### 3. Refresh Endpoint
```go
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
json.NewDecoder(r.Body).Decode(&req)
// Refresh token
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
if err != nil {
http.Error(w, err.Error(), 401)
return
}
json.NewEncoder(w).Encode(loginResp)
})
```
---
## Multi-Provider Example
```go
// Configure multiple providers
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ProviderName: "google",
ClientID: "google-client-id",
ClientSecret: "google-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
}).
WithOAuth2(security.OAuth2Config{
ProviderName: "github",
ClientID: "github-client-id",
ClientSecret: "github-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
})
// Refresh with provider selection
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google" or "github"
}
json.NewDecoder(r.Body).Decode(&req)
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), 401)
return
}
json.NewEncoder(w).Encode(loginResp)
})
```
---
## Client-Side JavaScript
```javascript
// Automatic token refresh on 401
async function apiCall(url) {
let response = await fetch(url, {
headers: {
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
}
});
// Token expired - refresh it
if (response.status === 401) {
await refreshToken();
// Retry request with new token
response = await fetch(url, {
headers: {
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
}
});
}
return response.json();
}
async function refreshToken() {
const response = await fetch('/auth/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
refresh_token: localStorage.getItem('refresh_token'),
provider: localStorage.getItem('provider')
})
});
if (response.ok) {
const data = await response.json();
localStorage.setItem('access_token', data.token);
localStorage.setItem('refresh_token', data.refresh_token);
} else {
// Refresh failed - redirect to login
window.location.href = '/login';
}
}
```
---
## API Methods
| Method | Parameters | Returns |
|--------|-----------|---------|
| `OAuth2RefreshToken` | `ctx, refreshToken, provider` | `*LoginResponse, error` |
| `OAuth2HandleCallback` | `ctx, provider, code, state` | `*LoginResponse, error` |
| `OAuth2GetAuthURL` | `provider, state` | `string, error` |
| `OAuth2GenerateState` | none | `string, error` |
| `OAuth2GetProviders` | none | `[]string` |
---
## LoginResponse Structure
```go
type LoginResponse struct {
Token string // New session token for API calls
RefreshToken string // Refresh token (store securely)
User *UserContext // User information
ExpiresIn int64 // Seconds until token expires
}
```
---
## Database Stored Procedures
- `resolvespec_oauth_getrefreshtoken(refresh_token)` - Get session by refresh token
- `resolvespec_oauth_updaterefreshtoken(update_data)` - Update tokens after refresh
- `resolvespec_oauth_getuser(user_id)` - Get user data
All procedures return: `{p_success bool, p_error text, p_data jsonb}`
---
## Common Errors
| Error | Cause | Solution |
|-------|-------|----------|
| `invalid or expired refresh token` | Token revoked/expired | Re-authenticate user |
| `OAuth2 provider 'xxx' not found` | Provider not configured | Add with `WithOAuth2()` |
| `failed to refresh token with provider` | Provider rejected request | Check credentials, re-auth user |
---
## Security Checklist
- [ ] Use HTTPS for all OAuth2 endpoints
- [ ] Store refresh tokens securely (HttpOnly cookies or encrypted storage)
- [ ] Set cookie flags: `HttpOnly`, `Secure`, `SameSite=Strict`
- [ ] Implement rate limiting on refresh endpoint
- [ ] Log refresh attempts for audit
- [ ] Rotate tokens on refresh
- [ ] Revoke old sessions after successful refresh
---
## Testing
```bash
# 1. Login and get refresh token
curl http://localhost:8080/auth/google/login
# Follow OAuth2 flow, get refresh_token from callback response
# 2. Refresh token
curl -X POST http://localhost:8080/auth/refresh \
-H "Content-Type: application/json" \
-d '{"refresh_token":"ya29.xxx","provider":"google"}'
# 3. Use new token
curl http://localhost:8080/api/protected \
-H "Authorization: Bearer sess_abc123..."
```
---
## Pre-configured Providers
```go
// Google
auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
// GitHub
auth := security.NewGitHubAuthenticator(clientID, secret, redirectURL, db)
// Microsoft
auth := security.NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)
// Facebook
auth := security.NewFacebookAuthenticator(clientID, secret, redirectURL, db)
// All providers at once
auth := security.NewMultiProviderAuthenticator(db, map[string]security.OAuth2Config{
"google": {...},
"github": {...},
})
```
---
## Provider-Specific Notes
### Google
- Add `access_type=offline` to get refresh token
- Add `prompt=consent` to force consent screen
```go
authURL += "&access_type=offline&prompt=consent"
```
### GitHub
- Refresh tokens not always provided
- May need to request `offline_access` scope
### Microsoft
- Use `offline_access` scope for refresh token
### Facebook
- Tokens expire after 60 days by default
- Check app settings for token expiration policy
---
## Complete Example
See `/pkg/security/oauth2_examples.go` line 250 for full working example.
For detailed documentation see `/pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md`.

View File

@@ -0,0 +1,495 @@
# OAuth2 Refresh Token Implementation
## Overview
OAuth2 refresh token functionality is **fully implemented** in the ResolveSpec security package. This allows refreshing expired access tokens without requiring users to re-authenticate.
## Implementation Status: ✅ COMPLETE
### Components Implemented
1. **✅ Database Schema** - Tables and stored procedures
2. **✅ Go Methods** - OAuth2RefreshToken implementation
3. **✅ Thread Safety** - Mutex protection for provider map
4. **✅ Examples** - Working code examples
5. **✅ Documentation** - Complete API reference
---
## 1. Database Schema
### Tables Modified
```sql
-- user_sessions table with OAuth2 token fields
CREATE TABLE IF NOT EXISTS user_sessions (
id SERIAL PRIMARY KEY,
session_token VARCHAR(500) NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address VARCHAR(45),
user_agent TEXT,
access_token TEXT, -- OAuth2 access token
refresh_token TEXT, -- OAuth2 refresh token
token_type VARCHAR(50), -- "Bearer", etc.
auth_provider VARCHAR(50) -- "google", "github", etc.
);
```
### Stored Procedures
**`resolvespec_oauth_getrefreshtoken(p_refresh_token)`**
- Gets OAuth2 session data by refresh token
- Returns: `{user_id, access_token, token_type, expiry}`
- Location: `database_schema.sql:714`
**`resolvespec_oauth_updaterefreshtoken(p_update_data)`**
- Updates session with new tokens after refresh
- Input: `{user_id, old_refresh_token, new_session_token, new_access_token, new_refresh_token, expires_at}`
- Location: `database_schema.sql:752`
**`resolvespec_oauth_getuser(p_user_id)`**
- Gets user data by ID for building UserContext
- Location: `database_schema.sql:791`
---
## 2. Go Implementation
### Method Signature
```go
func (a *DatabaseAuthenticator) OAuth2RefreshToken(
ctx context.Context,
refreshToken string,
providerName string,
) (*LoginResponse, error)
```
**Location:** `pkg/security/oauth2_methods.go:375`
### Implementation Flow
```
1. Validate provider exists
├─ getOAuth2Provider(providerName) with RLock
└─ Return error if provider not configured
2. Get session from database
├─ Call resolvespec_oauth_getrefreshtoken(refreshToken)
└─ Parse session data {user_id, access_token, token_type, expiry}
3. Refresh token with OAuth2 provider
├─ Create oauth2.Token from stored data
├─ Use provider.config.TokenSource(ctx, oldToken)
└─ Call tokenSource.Token() to get new token
4. Generate new session token
└─ Use OAuth2GenerateState() for secure random token
5. Update database
├─ Call resolvespec_oauth_updaterefreshtoken()
└─ Store new session_token, access_token, refresh_token
6. Get user data
├─ Call resolvespec_oauth_getuser(user_id)
└─ Build UserContext
7. Return LoginResponse
└─ {Token, RefreshToken, User, ExpiresIn}
```
### Thread Safety
**Mutex Protection:** All access to `oauth2Providers` map is protected with `sync.RWMutex`
```go
type DatabaseAuthenticator struct {
oauth2Providers map[string]*OAuth2Provider
oauth2ProvidersMutex sync.RWMutex // Thread-safe access
}
// Read operations use RLock
func (a *DatabaseAuthenticator) getOAuth2Provider(name string) {
a.oauth2ProvidersMutex.RLock()
defer a.oauth2ProvidersMutex.RUnlock()
// ... access map
}
// Write operations use Lock
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) {
a.oauth2ProvidersMutex.Lock()
defer a.oauth2ProvidersMutex.Unlock()
// ... modify map
}
```
---
## 3. Usage Examples
### Single Provider (Google)
```go
package main
import (
"database/sql"
"encoding/json"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
"github.com/gorilla/mux"
)
func main() {
db, _ := sql.Open("postgres", "connection-string")
// Create Google OAuth2 authenticator
auth := security.NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
// Token refresh endpoint
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
json.NewDecoder(r.Body).Decode(&req)
// Refresh token (provider name defaults to "google")
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set new session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
})
json.NewEncoder(w).Encode(loginResp)
})
http.ListenAndServe(":8080", router)
}
```
### Multi-Provider Setup
```go
// Single authenticator with multiple OAuth2 providers
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ClientID: "google-client-id",
ClientSecret: "google-client-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
}).
WithOAuth2(security.OAuth2Config{
ClientID: "github-client-id",
ClientSecret: "github-client-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
ProviderName: "github",
})
// Refresh endpoint with provider selection
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google" or "github"
}
json.NewDecoder(r.Body).Decode(&req)
// Refresh with specific provider
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
json.NewEncoder(w).Encode(loginResp)
})
```
### Client-Side Usage
```javascript
// JavaScript client example
async function refreshAccessToken() {
const refreshToken = localStorage.getItem('refresh_token');
const provider = localStorage.getItem('auth_provider'); // "google", "github", etc.
const response = await fetch('/auth/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
refresh_token: refreshToken,
provider: provider
})
});
if (response.ok) {
const data = await response.json();
// Store new tokens
localStorage.setItem('access_token', data.token);
localStorage.setItem('refresh_token', data.refresh_token);
console.log('Token refreshed successfully');
return data.token;
} else {
// Refresh failed - redirect to login
window.location.href = '/login';
}
}
// Automatically refresh token when API returns 401
async function apiCall(endpoint) {
let response = await fetch(endpoint, {
headers: {
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
}
});
if (response.status === 401) {
// Token expired - try refresh
const newToken = await refreshAccessToken();
// Retry with new token
response = await fetch(endpoint, {
headers: {
'Authorization': 'Bearer ' + newToken
}
});
}
return response.json();
}
```
---
## 4. API Reference
### DatabaseAuthenticator Methods
| Method | Signature | Description |
|--------|-----------|-------------|
| `OAuth2RefreshToken` | `(ctx, refreshToken, provider) (*LoginResponse, error)` | Refreshes expired OAuth2 access token |
| `WithOAuth2` | `(cfg OAuth2Config) *DatabaseAuthenticator` | Adds OAuth2 provider (chainable) |
| `OAuth2GetAuthURL` | `(provider, state) (string, error)` | Gets authorization URL |
| `OAuth2HandleCallback` | `(ctx, provider, code, state) (*LoginResponse, error)` | Handles OAuth2 callback |
| `OAuth2GenerateState` | `() (string, error)` | Generates CSRF state token |
| `OAuth2GetProviders` | `() []string` | Lists configured providers |
### LoginResponse Structure
```go
type LoginResponse struct {
Token string // New session token
RefreshToken string // New refresh token (may be same as input)
User *UserContext // User information
ExpiresIn int64 // Seconds until expiration
}
type UserContext struct {
UserID int // Database user ID
UserName string // Username
Email string // Email address
UserLevel int // Permission level
SessionID string // Session token
RemoteID string // OAuth2 provider user ID
Roles []string // User roles
Claims map[string]any // Additional claims
}
```
---
## 5. Important Notes
### Provider Configuration
**For Google:** Add `access_type=offline` to get refresh token on first login:
```go
auth := security.NewGoogleAuthenticator(clientID, clientSecret, redirectURL, db)
// When generating auth URL, add access_type parameter
authURL, _ := auth.OAuth2GetAuthURL("google", state)
authURL += "&access_type=offline&prompt=consent"
```
**For GitHub:** Refresh tokens are not always provided. Check provider documentation.
### Token Storage
- Store refresh tokens securely on client (localStorage, secure cookie, etc.)
- Never log refresh tokens
- Refresh tokens are long-lived (days/months depending on provider)
- Access tokens are short-lived (minutes/hours)
### Error Handling
Common errors:
- `"invalid or expired refresh token"` - Token expired or revoked
- `"OAuth2 provider 'xxx' not found"` - Provider not configured
- `"failed to refresh token with provider"` - Provider rejected refresh request
### Security Best Practices
1. **Always use HTTPS** for token transmission
2. **Store refresh tokens securely** on client
3. **Set appropriate cookie flags**: `HttpOnly`, `Secure`, `SameSite`
4. **Implement token rotation** - issue new refresh token on each refresh
5. **Revoke old tokens** after successful refresh
6. **Rate limit** refresh endpoints
7. **Log refresh attempts** for audit trail
---
## 6. Testing
### Manual Test Flow
1. **Initial Login:**
```bash
curl http://localhost:8080/auth/google/login
# Follow redirect to Google
# Returns to callback with LoginResponse containing refresh_token
```
2. **Wait for Token Expiry (or manually expire in DB)**
3. **Refresh Token:**
```bash
curl -X POST http://localhost:8080/auth/refresh \
-H "Content-Type: application/json" \
-d '{
"refresh_token": "ya29.a0AfH6SMB...",
"provider": "google"
}'
# Response:
{
"token": "sess_abc123...",
"refresh_token": "ya29.a0AfH6SMB...",
"user": {
"user_id": 1,
"user_name": "john_doe",
"email": "john@example.com",
"session_id": "sess_abc123..."
},
"expires_in": 3600
}
```
4. **Use New Token:**
```bash
curl http://localhost:8080/api/protected \
-H "Authorization: Bearer sess_abc123..."
```
### Database Verification
```sql
-- Check session with refresh token
SELECT session_token, user_id, expires_at, refresh_token, auth_provider
FROM user_sessions
WHERE refresh_token = 'ya29.a0AfH6SMB...';
-- Verify token was updated after refresh
SELECT session_token, access_token, refresh_token,
expires_at, last_activity_at
FROM user_sessions
WHERE user_id = 1
ORDER BY created_at DESC
LIMIT 1;
```
---
## 7. Troubleshooting
### "Refresh token not found or expired"
**Cause:** Refresh token doesn't exist in database or session expired
**Solution:**
- Check if initial OAuth2 login stored refresh token
- Verify provider returns refresh token (some require `access_type=offline`)
- Check session hasn't been deleted from database
### "Failed to refresh token with provider"
**Cause:** OAuth2 provider rejected the refresh request
**Possible reasons:**
- Refresh token was revoked by user
- OAuth2 app credentials changed
- Network connectivity issues
- Provider rate limiting
**Solution:**
- Re-authenticate user (full OAuth2 flow)
- Check provider dashboard for app status
- Verify client credentials are correct
### "OAuth2 provider 'xxx' not found"
**Cause:** Provider not registered with `WithOAuth2()`
**Solution:**
```go
// Make sure provider is configured
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ProviderName: "google", // This name must match refresh call
// ... other config
})
// Then use same name in refresh
auth.OAuth2RefreshToken(ctx, token, "google") // Must match ProviderName
```
---
## 8. Complete Working Example
See `pkg/security/oauth2_examples.go:250` for full working example with token refresh.
---
## Summary
OAuth2 refresh token functionality is **production-ready** with:
- ✅ Complete database schema with stored procedures
- ✅ Thread-safe Go implementation with mutex protection
- ✅ Multi-provider support (Google, GitHub, Microsoft, Facebook, custom)
- ✅ Comprehensive error handling
- ✅ Working code examples
- ✅ Full API documentation
- ✅ Security best practices implemented
**No additional implementation needed - feature is complete and functional.**

View File

@@ -0,0 +1,208 @@
# Passkey Authentication Quick Reference
## Overview
Passkey authentication (WebAuthn/FIDO2) is now integrated into the DatabaseAuthenticator. This provides passwordless authentication using biometrics, security keys, or device credentials.
## Setup
### Database Schema
Run the passkey SQL schema (in database_schema.sql):
- Creates `user_passkey_credentials` table
- Adds stored procedures for passkey operations
### Go Code
```go
// Create passkey provider
passkeyProvider := security.NewDatabasePasskeyProvider(db,
security.DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
RPOrigin: "https://example.com",
Timeout: 60000,
})
// Create authenticator with passkey support
auth := security.NewDatabaseAuthenticatorWithOptions(db,
security.DatabaseAuthenticatorOptions{
PasskeyProvider: passkeyProvider,
})
// Or add passkey to existing authenticator
auth = security.NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
```
## Registration Flow
### Backend - Step 1: Begin Registration
```go
options, err := auth.BeginPasskeyRegistration(ctx,
security.PasskeyBeginRegistrationRequest{
UserID: 1,
Username: "alice",
DisplayName: "Alice Smith",
})
// Send options to client as JSON
```
### Frontend - Step 2: Create Credential
```javascript
// Convert options from server
options.challenge = base64ToArrayBuffer(options.challenge);
options.user.id = base64ToArrayBuffer(options.user.id);
// Create credential
const credential = await navigator.credentials.create({
publicKey: options
});
// Send credential back to server
```
### Backend - Step 3: Complete Registration
```go
credential, err := auth.CompletePasskeyRegistration(ctx,
security.PasskeyRegisterRequest{
UserID: 1,
Response: clientResponse,
ExpectedChallenge: storedChallenge,
CredentialName: "My iPhone",
})
```
## Authentication Flow
### Backend - Step 1: Begin Authentication
```go
options, err := auth.BeginPasskeyAuthentication(ctx,
security.PasskeyBeginAuthenticationRequest{
Username: "alice", // Optional for resident key
})
// Send options to client as JSON
```
### Frontend - Step 2: Get Credential
```javascript
// Convert options from server
options.challenge = base64ToArrayBuffer(options.challenge);
// Get credential
const credential = await navigator.credentials.get({
publicKey: options
});
// Send assertion back to server
```
### Backend - Step 3: Complete Authentication
```go
loginResponse, err := auth.LoginWithPasskey(ctx,
security.PasskeyLoginRequest{
Response: clientAssertion,
ExpectedChallenge: storedChallenge,
Claims: map[string]any{
"ip_address": "192.168.1.1",
"user_agent": "Mozilla/5.0...",
},
})
// Returns session token and user info
```
## Credential Management
### List Credentials
```go
credentials, err := auth.GetPasskeyCredentials(ctx, userID)
```
### Update Credential Name
```go
err := auth.UpdatePasskeyCredentialName(ctx, userID, credentialID, "New Name")
```
### Delete Credential
```go
err := auth.DeletePasskeyCredential(ctx, userID, credentialID)
```
## HTTP Endpoints Example
### POST /api/passkey/register/begin
Request: `{user_id, username, display_name}`
Response: PasskeyRegistrationOptions
### POST /api/passkey/register/complete
Request: `{user_id, response, credential_name}`
Response: PasskeyCredential
### POST /api/passkey/login/begin
Request: `{username}` (optional)
Response: PasskeyAuthenticationOptions
### POST /api/passkey/login/complete
Request: `{response}`
Response: LoginResponse with session token
### GET /api/passkey/credentials
Response: Array of PasskeyCredential
### DELETE /api/passkey/credentials/{id}
Request: `{credential_id}`
Response: 204 No Content
## Database Stored Procedures
- `resolvespec_passkey_store_credential` - Store new credential
- `resolvespec_passkey_get_credential` - Get credential by ID
- `resolvespec_passkey_get_user_credentials` - Get all user credentials
- `resolvespec_passkey_update_counter` - Update sign counter (clone detection)
- `resolvespec_passkey_delete_credential` - Delete credential
- `resolvespec_passkey_update_name` - Update credential name
- `resolvespec_passkey_get_credentials_by_username` - Get credentials for login
## Security Features
- **Clone Detection**: Sign counter validation detects credential cloning
- **Attestation Support**: Stores attestation type (none, indirect, direct)
- **Transport Options**: Tracks authenticator transports (usb, nfc, ble, internal)
- **Backup State**: Tracks if credential is backed up/synced
- **User Verification**: Supports preferred/required user verification
## Important Notes
1. **WebAuthn Library**: Current implementation is simplified. For production, use a proper WebAuthn library like `github.com/go-webauthn/webauthn` for full verification.
2. **Challenge Storage**: Store challenges securely in session/cache. Never expose challenges to client beyond initial request.
3. **HTTPS Required**: Passkeys only work over HTTPS (except localhost).
4. **Browser Support**: Check browser compatibility for WebAuthn API.
5. **Relying Party ID**: Must match your domain exactly.
## Client-Side Helper Functions
```javascript
function base64ToArrayBuffer(base64) {
const binary = atob(base64);
const bytes = new Uint8Array(binary.length);
for (let i = 0; i < binary.length; i++) {
bytes[i] = binary.charCodeAt(i);
}
return bytes.buffer;
}
function arrayBufferToBase64(buffer) {
const bytes = new Uint8Array(buffer);
let binary = '';
for (let i = 0; i < bytes.length; i++) {
binary += String.fromCharCode(bytes[i]);
}
return btoa(binary);
}
```
## Testing
Run tests: `go test -v ./pkg/security -run Passkey`
All passkey functionality includes comprehensive tests using sqlmock.

View File

@@ -7,15 +7,16 @@
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended) auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
// OR: auth := security.NewJWTAuthenticator("secret-key", db) // OR: auth := security.NewJWTAuthenticator("secret-key", db)
// OR: auth := security.NewHeaderAuthenticator() // OR: auth := security.NewHeaderAuthenticator()
// OR: auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db) // OAuth2
colSec := security.NewDatabaseColumnSecurityProvider(db) colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db) rowSec := security.NewDatabaseRowSecurityProvider(db)
// Step 2: Combine providers // Step 2: Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
// Step 3: Setup and apply middleware // Step 3: Setup and apply middleware
securityList := security.SetupSecurityProvider(handler, provider) securityList, _ := security.SetupSecurityProvider(handler, provider)
router.Use(security.NewAuthMiddleware(securityList)) router.Use(security.NewAuthMiddleware(securityList))
router.Use(security.SetSecurityMiddleware(securityList)) router.Use(security.SetSecurityMiddleware(securityList))
``` ```
@@ -30,6 +31,7 @@ router.Use(security.SetSecurityMiddleware(securityList))
```go ```go
// DatabaseAuthenticator uses these stored procedures: // DatabaseAuthenticator uses these stored procedures:
resolvespec_login(jsonb) // Login with credentials resolvespec_login(jsonb) // Login with credentials
resolvespec_register(jsonb) // Register new user
resolvespec_logout(jsonb) // Invalidate session resolvespec_logout(jsonb) // Invalidate session
resolvespec_session(text, text) // Validate session token resolvespec_session(text, text) // Validate session token
resolvespec_session_update(text, jsonb) // Update activity timestamp resolvespec_session_update(text, jsonb) // Update activity timestamp
@@ -502,10 +504,31 @@ func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema,
--- ---
## Login/Logout Endpoints ## Login/Logout/Register Endpoints
```go ```go
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) { func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
// Register
router.HandleFunc("/auth/register", func(w http.ResponseWriter, r *http.Request) {
var req security.RegisterRequest
json.NewDecoder(r.Body).Decode(&req)
// Check if provider supports registration
registrable, ok := securityList.Provider().(security.Registrable)
if !ok {
http.Error(w, "Registration not supported", http.StatusNotImplemented)
return
}
resp, err := registrable.Register(r.Context(), req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(resp)
}).Methods("POST")
// Login // Login
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) { router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
var req security.LoginRequest var req security.LoginRequest
@@ -707,6 +730,7 @@ meta, ok := security.GetUserMeta(ctx)
| File | Description | | File | Description |
|------|-------------| |------|-------------|
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide | | `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
| `OAUTH2.md` | **OAuth2 Guide** - Google, GitHub, Microsoft, Facebook, custom providers |
| `examples.go` | Working provider implementations to copy | | `examples.go` | Working provider implementations to copy |
| `setup_example.go` | 6 complete integration examples | | `setup_example.go` | 6 complete integration examples |
| `README.md` | Architecture overview and migration guide | | `README.md` | Architecture overview and migration guide |

View File

@@ -6,6 +6,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
-**Interface-Based** - Type-safe providers instead of callbacks -**Interface-Based** - Type-safe providers instead of callbacks
-**Login/Logout Support** - Built-in authentication lifecycle -**Login/Logout Support** - Built-in authentication lifecycle
-**Two-Factor Authentication (2FA)** - Optional TOTP support for enhanced security
-**Composable** - Mix and match different providers -**Composable** - Mix and match different providers
-**No Global State** - Each handler has its own security configuration -**No Global State** - Each handler has its own security configuration
-**Testable** - Easy to mock and test -**Testable** - Easy to mock and test
@@ -212,6 +213,23 @@ auth := security.NewJWTAuthenticator("secret-key", db)
// Note: Requires JWT library installation for token signing/verification // Note: Requires JWT library installation for token signing/verification
``` ```
**TwoFactorAuthenticator** - Wraps any authenticator with TOTP 2FA:
```go
baseAuth := security.NewDatabaseAuthenticator(db)
// Use in-memory provider (for testing)
tfaProvider := security.NewMemoryTwoFactorProvider(nil)
// Or use database provider (for production)
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
// Requires: users table with totp fields, user_totp_backup_codes table
// Requires: resolvespec_totp_* stored procedures (see totp_database_schema.sql)
auth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
// Supports: TOTP codes, backup codes, QR code generation
// Compatible with Google Authenticator, Microsoft Authenticator, Authy, etc.
```
### Column Security Providers ### Column Security Providers
**DatabaseColumnSecurityProvider** - Loads rules from database: **DatabaseColumnSecurityProvider** - Loads rules from database:
@@ -334,7 +352,182 @@ func handleRefresh(securityList *security.SecurityList) http.HandlerFunc {
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusUnauthorized)
return return
} }
```
## Two-Factor Authentication (2FA)
### Overview
- **Optional per-user** - Enable/disable 2FA individually
- **TOTP standard** - Compatible with Google Authenticator, Microsoft Authenticator, Authy, 1Password, etc.
- **Configurable** - SHA1/SHA256/SHA512, 6/8 digits, custom time periods
- **Backup codes** - One-time recovery codes with secure hashing
- **Clock skew** - Handles time differences between client/server
### Setup
```go
// 1. Wrap existing authenticator with 2FA support
baseAuth := security.NewDatabaseAuthenticator(db)
tfaProvider := security.NewMemoryTwoFactorProvider(nil) // Use custom DB implementation in production
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
// 2. Use as normal authenticator
provider := security.NewCompositeSecurityProvider(tfaAuth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
```
### Enable 2FA for User
```go
// 1. Initiate 2FA setup
secret, err := tfaAuth.Setup2FA(userID, "MyApp", "user@example.com")
// Returns: secret.Secret, secret.QRCodeURL, secret.BackupCodes
// 2. User scans QR code with authenticator app
// Display secret.QRCodeURL as QR code image
// 3. User enters verification code from app
code := "123456" // From authenticator app
err = tfaAuth.Enable2FA(userID, secret.Secret, code)
// 2FA is now enabled for this user
// 4. Store backup codes securely and show to user once
// Display: secret.BackupCodes (10 codes)
```
### Login Flow with 2FA
```go
// 1. User provides credentials
req := security.LoginRequest{
Username: "user@example.com",
Password: "password",
}
resp, err := tfaAuth.Login(ctx, req)
// 2. Check if 2FA required
if resp.Requires2FA {
// Prompt user for 2FA code
code := getUserInput() // From authenticator app or backup code
// 3. Login again with 2FA code
req.TwoFactorCode = code
resp, err = tfaAuth.Login(ctx, req)
// 4. Success - token is returned
token := resp.Token
}
```
### Manage 2FA
```go
// Disable 2FA
err := tfaAuth.Disable2FA(userID)
// Regenerate backup codes
newCodes, err := tfaAuth.RegenerateBackupCodes(userID, 10)
// Check status
has2FA, err := tfaProvider.Get2FAStatus(userID)
```
### Custom 2FA Storage
**Option 1: Use DatabaseTwoFactorProvider (Recommended)**
```go
// Uses PostgreSQL stored procedures for all operations
db := setupDatabase()
// Run migrations from totp_database_schema.sql
// - Add totp_secret, totp_enabled, totp_enabled_at to users table
// - Create user_totp_backup_codes table
// - Create resolvespec_totp_* stored procedures
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
```
**Option 2: Implement Custom Provider**
Implement `TwoFactorAuthProvider` for custom storage:
```go
type DBTwoFactorProvider struct {
db *gorm.DB
}
func (p *DBTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
// Store secret and hashed backup codes in database
return p.db.Exec("UPDATE users SET totp_secret = ?, backup_codes = ? WHERE id = ?",
secret, hashCodes(backupCodes), userID).Error
}
func (p *DBTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var secret string
err := p.db.Raw("SELECT totp_secret FROM users WHERE id = ?", userID).Scan(&secret).Error
return secret, err
}
// Implement remaining methods: Generate2FASecret, Validate2FACode, Disable2FA,
// Get2FAStatus, GenerateBackupCodes, ValidateBackupCode
```
### Configuration
```go
config := &security.TwoFactorConfig{
Algorithm: "SHA256", // SHA1, SHA256, SHA512
Digits: 8, // 6 or 8
Period: 30, // Seconds per code
SkewWindow: 2, // Accept codes ±2 periods
}
totp := security.NewTOTPGenerator(config)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, config)
```
### API Response Structure
```go
// LoginResponse with 2FA
type LoginResponse struct {
Token string `json:"token"`
Requires2FA bool `json:"requires_2fa"`
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"`
User *UserContext `json:"user"`
}
// TwoFactorSecret for setup
type TwoFactorSecret struct {
Secret string `json:"secret"` // Base32 encoded
QRCodeURL string `json:"qr_code_url"` // otpauth://totp/...
BackupCodes []string `json:"backup_codes"` // 10 recovery codes
}
// UserContext includes 2FA status
type UserContext struct {
UserID int `json:"user_id"`
TwoFactorEnabled bool `json:"two_factor_enabled"`
// ... other fields
}
```
### Security Best Practices
- **Store secrets encrypted** - Never store TOTP secrets in plain text
- **Hash backup codes** - Use SHA-256 before storing
- **Rate limit** - Limit 2FA verification attempts
- **Require password** - Always verify password before disabling 2FA
- **Show backup codes once** - Display only during setup/regeneration
- **Log 2FA events** - Track enable/disable/failed attempts
- **Mark codes as used** - Backup codes are single-use only
json.NewEncoder(w).Encode(resp) json.NewEncoder(w).Encode(resp)
} else { } else {
http.Error(w, "Refresh not supported", http.StatusNotImplemented) http.Error(w, "Refresh not supported", http.StatusNotImplemented)

File diff suppressed because it is too large Load Diff

View File

@@ -7,33 +7,48 @@ import (
// UserContext holds authenticated user information // UserContext holds authenticated user information
type UserContext struct { type UserContext struct {
UserID int `json:"user_id"` UserID int `json:"user_id"`
UserName string `json:"user_name"` UserName string `json:"user_name"`
UserLevel int `json:"user_level"` UserLevel int `json:"user_level"`
SessionID string `json:"session_id"` SessionID string `json:"session_id"`
SessionRID int64 `json:"session_rid"` SessionRID int64 `json:"session_rid"`
RemoteID string `json:"remote_id"` RemoteID string `json:"remote_id"`
Roles []string `json:"roles"` Roles []string `json:"roles"`
Email string `json:"email"` Email string `json:"email"`
Claims map[string]any `json:"claims"` Claims map[string]any `json:"claims"`
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
} }
// LoginRequest contains credentials for login // LoginRequest contains credentials for login
type LoginRequest struct { type LoginRequest struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Claims map[string]any `json:"claims"` // Additional login data TwoFactorCode string `json:"two_factor_code,omitempty"` // TOTP or backup code
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context Claims map[string]any `json:"claims"` // Additional login data
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
}
// RegisterRequest contains information for new user registration
type RegisterRequest struct {
Username string `json:"username"`
Password string `json:"password"`
Email string `json:"email"`
UserLevel int `json:"user_level"`
Roles []string `json:"roles"`
Claims map[string]any `json:"claims"` // Additional registration data
Meta map[string]any `json:"meta"` // Additional metadata
} }
// LoginResponse contains the result of a login attempt // LoginResponse contains the result of a login attempt
type LoginResponse struct { type LoginResponse struct {
Token string `json:"token"` Token string `json:"token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
User *UserContext `json:"user"` User *UserContext `json:"user"`
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context Requires2FA bool `json:"requires_2fa"` // True if 2FA code is required
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"` // Present when setting up 2FA
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
} }
// LogoutRequest contains information for logout // LogoutRequest contains information for logout
@@ -55,6 +70,12 @@ type Authenticator interface {
Authenticate(r *http.Request) (*UserContext, error) Authenticate(r *http.Request) (*UserContext, error)
} }
// Registrable allows providers to support user registration
type Registrable interface {
// Register creates a new user account
Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error)
}
// ColumnSecurityProvider handles column-level security (masking/hiding) // ColumnSecurityProvider handles column-level security (masking/hiding)
type ColumnSecurityProvider interface { type ColumnSecurityProvider interface {
// GetColumnSecurity loads column security rules for a user and entity // GetColumnSecurity loads column security rules for a user and entity

View File

@@ -0,0 +1,615 @@
package security
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"github.com/gorilla/mux"
)
// Example: OAuth2 Authentication with Google
func ExampleOAuth2Google() {
db, _ := sql.Open("postgres", "connection-string")
// Create OAuth2 authenticator for Google
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
// Login endpoint - redirects to Google
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
// Callback endpoint - handles Google response
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "google", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
// Return user info as JSON
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: OAuth2 Authentication with GitHub
func ExampleOAuth2GitHub() {
db, _ := sql.Open("postgres", "connection-string")
oauth2Auth := NewGitHubAuthenticator(
"your-github-client-id",
"your-github-client-secret",
"http://localhost:8080/auth/github/callback",
db,
)
router := mux.NewRouter()
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: Custom OAuth2 Provider
func ExampleOAuth2Custom() {
db, _ := sql.Open("postgres", "connection-string")
// Custom OAuth2 provider configuration
oauth2Auth := NewDatabaseAuthenticator(db).WithOAuth2(OAuth2Config{
ClientID: "your-client-id",
ClientSecret: "your-client-secret",
RedirectURL: "http://localhost:8080/auth/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://your-provider.com/oauth/authorize",
TokenURL: "https://your-provider.com/oauth/token",
UserInfoURL: "https://your-provider.com/oauth/userinfo",
ProviderName: "custom-provider",
// Custom user info parser
UserInfoParser: func(userInfo map[string]any) (*UserContext, error) {
// Extract custom fields from your provider
return &UserContext{
UserName: userInfo["username"].(string),
Email: userInfo["email"].(string),
RemoteID: userInfo["id"].(string),
UserLevel: 1,
Roles: []string{"user"},
Claims: userInfo,
}, nil
},
})
router := mux.NewRouter()
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("custom-provider", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "custom-provider", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: Multi-Provider OAuth2 with Security Integration
func ExampleOAuth2MultiProvider() {
db, _ := sql.Open("postgres", "connection-string")
// Create OAuth2 authenticators for multiple providers
googleAuth := NewGoogleAuthenticator(
"google-client-id",
"google-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
githubAuth := NewGitHubAuthenticator(
"github-client-id",
"github-client-secret",
"http://localhost:8080/auth/github/callback",
db,
)
// Create column and row security providers
colSec := NewDatabaseColumnSecurityProvider(db)
rowSec := NewDatabaseRowSecurityProvider(db)
router := mux.NewRouter()
// Google OAuth2 routes
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := googleAuth.OAuth2GenerateState()
authURL, _ := googleAuth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := googleAuth.OAuth2HandleCallback(r.Context(), "google", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
// GitHub OAuth2 routes
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := githubAuth.OAuth2GenerateState()
authURL, _ := githubAuth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := githubAuth.OAuth2HandleCallback(r.Context(), "github", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
// Use Google auth for protected routes (or GitHub - both work)
provider, _ := NewCompositeSecurityProvider(googleAuth, colSec, rowSec)
securityList, _ := NewSecurityList(provider)
// Protected route with authentication
protectedRouter := router.PathPrefix("/api").Subrouter()
protectedRouter.Use(NewAuthMiddleware(securityList))
protectedRouter.Use(SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = json.NewEncoder(w).Encode(userCtx)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: OAuth2 with Token Refresh
func ExampleOAuth2TokenRefresh() {
db, _ := sql.Open("postgres", "connection-string")
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
// Refresh token endpoint
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google", "github", etc.
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
// Default to google if not specified
if req.Provider == "" {
req.Provider = "google"
}
// Use OAuth2-specific refresh method
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set new session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: OAuth2 Logout
func ExampleOAuth2Logout() {
db, _ := sql.Open("postgres", "connection-string")
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if token == "" {
cookie, err := r.Cookie("session_token")
if err == nil {
token = cookie.Value
}
}
if token != "" {
// Get user ID from session
userCtx, err := oauth2Auth.Authenticate(r)
if err == nil {
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
Token: token,
UserID: userCtx.UserID,
})
}
}
// Clear cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
})
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Logged out successfully"))
})
_ = http.ListenAndServe(":8080", router)
}
// Example: Complete OAuth2 Integration with Database Setup
func ExampleOAuth2Complete() {
db, _ := sql.Open("postgres", "connection-string")
// Create tables (run once)
setupOAuth2Tables(db)
// Create OAuth2 authenticator
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
// Create security providers
colSec := NewDatabaseColumnSecurityProvider(db)
rowSec := NewDatabaseRowSecurityProvider(db)
provider, _ := NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
securityList, _ := NewSecurityList(provider)
router := mux.NewRouter()
// Public routes
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("Welcome! <a href='/auth/google/login'>Login with Google</a>"))
})
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
// Protected routes
protectedRouter := router.PathPrefix("/").Subrouter()
protectedRouter.Use(NewAuthMiddleware(securityList))
protectedRouter.Use(SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/dashboard", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_, _ = fmt.Fprintf(w, "Welcome, %s! Your email: %s", userCtx.UserName, userCtx.Email)
})
protectedRouter.HandleFunc("/api/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = json.NewEncoder(w).Encode(userCtx)
})
protectedRouter.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
Token: userCtx.SessionID,
UserID: userCtx.UserID,
})
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
})
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
})
_ = http.ListenAndServe(":8080", router)
}
func setupOAuth2Tables(db *sql.DB) {
// Create tables from database_schema.sql
// This is a helper function - in production, use migrations
ctx := context.Background()
// Create users table if not exists
_, _ = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
email VARCHAR(255) NOT NULL UNIQUE,
password VARCHAR(255),
user_level INTEGER DEFAULT 0,
roles VARCHAR(500),
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_login_at TIMESTAMP,
remote_id VARCHAR(255),
auth_provider VARCHAR(50)
)
`)
// Create user_sessions table (used for both regular and OAuth2 sessions)
_, _ = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_sessions (
id SERIAL PRIMARY KEY,
session_token VARCHAR(500) NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address VARCHAR(45),
user_agent TEXT,
access_token TEXT,
refresh_token TEXT,
token_type VARCHAR(50) DEFAULT 'Bearer',
auth_provider VARCHAR(50)
)
`)
}
// Example: All OAuth2 Providers at Once
func ExampleOAuth2AllProviders() {
db, _ := sql.Open("postgres", "connection-string")
// Create authenticator with ALL OAuth2 providers
auth := NewDatabaseAuthenticator(db).
WithOAuth2(OAuth2Config{
ClientID: "google-client-id",
ClientSecret: "google-client-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
}).
WithOAuth2(OAuth2Config{
ClientID: "github-client-id",
ClientSecret: "github-client-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
ProviderName: "github",
}).
WithOAuth2(OAuth2Config{
ClientID: "microsoft-client-id",
ClientSecret: "microsoft-client-secret",
RedirectURL: "http://localhost:8080/auth/microsoft/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
ProviderName: "microsoft",
}).
WithOAuth2(OAuth2Config{
ClientID: "facebook-client-id",
ClientSecret: "facebook-client-secret",
RedirectURL: "http://localhost:8080/auth/facebook/callback",
Scopes: []string{"email"},
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
ProviderName: "facebook",
})
// Get list of configured providers
providers := auth.OAuth2GetProviders()
fmt.Printf("Configured OAuth2 providers: %v\n", providers)
router := mux.NewRouter()
// Google routes
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// GitHub routes
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// Microsoft routes
router.HandleFunc("/auth/microsoft/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("microsoft", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/microsoft/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "microsoft", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// Facebook routes
router.HandleFunc("/auth/facebook/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("facebook", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/facebook/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "facebook", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// Create security list for protected routes
colSec := NewDatabaseColumnSecurityProvider(db)
rowSec := NewDatabaseRowSecurityProvider(db)
provider, _ := NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := NewSecurityList(provider)
// Protected routes work for ALL OAuth2 providers + regular sessions
protectedRouter := router.PathPrefix("/api").Subrouter()
protectedRouter.Use(NewAuthMiddleware(securityList))
protectedRouter.Use(SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = json.NewEncoder(w).Encode(userCtx)
})
_ = http.ListenAndServe(":8080", router)
}

View File

@@ -0,0 +1,579 @@
package security
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"time"
"golang.org/x/oauth2"
)
// OAuth2Config contains configuration for OAuth2 authentication
type OAuth2Config struct {
ClientID string
ClientSecret string
RedirectURL string
Scopes []string
AuthURL string
TokenURL string
UserInfoURL string
ProviderName string
// Optional: Custom user info parser
// If not provided, will use standard claims (sub, email, name)
UserInfoParser func(userInfo map[string]any) (*UserContext, error)
}
// OAuth2Provider holds configuration and state for a single OAuth2 provider
type OAuth2Provider struct {
config *oauth2.Config
userInfoURL string
userInfoParser func(userInfo map[string]any) (*UserContext, error)
providerName string
states map[string]time.Time // state -> expiry time
statesMutex sync.RWMutex
}
// WithOAuth2 configures OAuth2 support for the DatabaseAuthenticator
// Can be called multiple times to add multiple OAuth2 providers
// Returns the same DatabaseAuthenticator instance for method chaining
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) *DatabaseAuthenticator {
if cfg.ProviderName == "" {
cfg.ProviderName = "oauth2"
}
if cfg.UserInfoParser == nil {
cfg.UserInfoParser = defaultOAuth2UserInfoParser
}
provider := &OAuth2Provider{
config: &oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
RedirectURL: cfg.RedirectURL,
Scopes: cfg.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: cfg.AuthURL,
TokenURL: cfg.TokenURL,
},
},
userInfoURL: cfg.UserInfoURL,
userInfoParser: cfg.UserInfoParser,
providerName: cfg.ProviderName,
states: make(map[string]time.Time),
}
// Initialize providers map if needed
a.oauth2ProvidersMutex.Lock()
if a.oauth2Providers == nil {
a.oauth2Providers = make(map[string]*OAuth2Provider)
}
// Register provider
a.oauth2Providers[cfg.ProviderName] = provider
a.oauth2ProvidersMutex.Unlock()
// Start state cleanup goroutine for this provider
go provider.cleanupStates()
return a
}
// OAuth2GetAuthURL returns the OAuth2 authorization URL for redirecting users
func (a *DatabaseAuthenticator) OAuth2GetAuthURL(providerName, state string) (string, error) {
provider, err := a.getOAuth2Provider(providerName)
if err != nil {
return "", err
}
// Store state for validation
provider.statesMutex.Lock()
provider.states[state] = time.Now().Add(10 * time.Minute)
provider.statesMutex.Unlock()
return provider.config.AuthCodeURL(state), nil
}
// OAuth2GenerateState generates a random state string for CSRF protection
func (a *DatabaseAuthenticator) OAuth2GenerateState() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// OAuth2HandleCallback handles the OAuth2 callback and exchanges code for token
func (a *DatabaseAuthenticator) OAuth2HandleCallback(ctx context.Context, providerName, code, state string) (*LoginResponse, error) {
provider, err := a.getOAuth2Provider(providerName)
if err != nil {
return nil, err
}
// Validate state
if !provider.validateState(state) {
return nil, fmt.Errorf("invalid state parameter")
}
// Exchange code for token
token, err := provider.config.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
// Fetch user info
client := provider.config.Client(ctx, token)
resp, err := client.Get(provider.userInfoURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch user info: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read user info: %w", err)
}
var userInfo map[string]any
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("failed to parse user info: %w", err)
}
// Parse user info
userCtx, err := provider.userInfoParser(userInfo)
if err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
// Get or create user in database
userID, err := a.oauth2GetOrCreateUser(ctx, userCtx, providerName)
if err != nil {
return nil, fmt.Errorf("failed to get or create user: %w", err)
}
userCtx.UserID = userID
// Create session token
sessionToken, err := a.OAuth2GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate session token: %w", err)
}
expiresAt := time.Now().Add(24 * time.Hour)
if token.Expiry.After(time.Now()) {
expiresAt = token.Expiry
}
// Store session in database
err = a.oauth2CreateSession(ctx, sessionToken, userCtx.UserID, token, expiresAt, providerName)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
userCtx.SessionID = sessionToken
return &LoginResponse{
Token: sessionToken,
RefreshToken: token.RefreshToken,
User: userCtx,
ExpiresIn: int64(time.Until(expiresAt).Seconds()),
}, nil
}
// OAuth2GetProviders returns list of configured OAuth2 provider names
func (a *DatabaseAuthenticator) OAuth2GetProviders() []string {
a.oauth2ProvidersMutex.RLock()
defer a.oauth2ProvidersMutex.RUnlock()
if a.oauth2Providers == nil {
return nil
}
providers := make([]string, 0, len(a.oauth2Providers))
for name := range a.oauth2Providers {
providers = append(providers, name)
}
return providers
}
// getOAuth2Provider retrieves a registered OAuth2 provider by name
func (a *DatabaseAuthenticator) getOAuth2Provider(providerName string) (*OAuth2Provider, error) {
a.oauth2ProvidersMutex.RLock()
defer a.oauth2ProvidersMutex.RUnlock()
if a.oauth2Providers == nil {
return nil, fmt.Errorf("OAuth2 not configured - call WithOAuth2() first")
}
provider, ok := a.oauth2Providers[providerName]
if !ok {
// Build provider list without calling OAuth2GetProviders to avoid recursion
providerNames := make([]string, 0, len(a.oauth2Providers))
for name := range a.oauth2Providers {
providerNames = append(providerNames, name)
}
return nil, fmt.Errorf("OAuth2 provider '%s' not found - available providers: %v", providerName, providerNames)
}
return provider, nil
}
// oauth2GetOrCreateUser finds or creates a user based on OAuth2 info using stored procedure
func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userCtx *UserContext, providerName string) (int, error) {
userData := map[string]interface{}{
"username": userCtx.UserName,
"email": userCtx.Email,
"remote_id": userCtx.RemoteID,
"user_level": userCtx.UserLevel,
"roles": userCtx.Roles,
"auth_provider": providerName,
}
userJSON, err := json.Marshal(userData)
if err != nil {
return 0, fmt.Errorf("failed to marshal user data: %w", err)
}
var success bool
var errMsg *string
var userID *int
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error, p_user_id
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
`, userJSON).Scan(&success, &errMsg, &userID)
if err != nil {
return 0, fmt.Errorf("failed to get or create user: %w", err)
}
if !success {
if errMsg != nil {
return 0, fmt.Errorf("%s", *errMsg)
}
return 0, fmt.Errorf("failed to get or create user")
}
if userID == nil {
return 0, fmt.Errorf("user ID not returned")
}
return *userID, nil
}
// oauth2CreateSession creates a new OAuth2 session using stored procedure
func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, sessionToken string, userID int, token *oauth2.Token, expiresAt time.Time, providerName string) error {
sessionData := map[string]interface{}{
"session_token": sessionToken,
"user_id": userID,
"access_token": token.AccessToken,
"refresh_token": token.RefreshToken,
"token_type": token.TokenType,
"expires_at": expiresAt,
"auth_provider": providerName,
}
sessionJSON, err := json.Marshal(sessionData)
if err != nil {
return fmt.Errorf("failed to marshal session data: %w", err)
}
var success bool
var errMsg *string
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error
FROM resolvespec_oauth_createsession($1::jsonb)
`, sessionJSON).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
}
if !success {
if errMsg != nil {
return fmt.Errorf("%s", *errMsg)
}
return fmt.Errorf("failed to create session")
}
return nil
}
// validateState validates state using in-memory storage
func (p *OAuth2Provider) validateState(state string) bool {
p.statesMutex.Lock()
defer p.statesMutex.Unlock()
expiry, ok := p.states[state]
if !ok {
return false
}
if time.Now().After(expiry) {
delete(p.states, state)
return false
}
delete(p.states, state) // One-time use
return true
}
// cleanupStates removes expired states periodically
func (p *OAuth2Provider) cleanupStates() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
p.statesMutex.Lock()
now := time.Now()
for state, expiry := range p.states {
if now.After(expiry) {
delete(p.states, state)
}
}
p.statesMutex.Unlock()
}
}
// defaultOAuth2UserInfoParser parses standard OAuth2 user info claims
func defaultOAuth2UserInfoParser(userInfo map[string]any) (*UserContext, error) {
ctx := &UserContext{
Claims: userInfo,
Roles: []string{"user"},
}
// Extract standard claims
if sub, ok := userInfo["sub"].(string); ok {
ctx.RemoteID = sub
}
if email, ok := userInfo["email"].(string); ok {
ctx.Email = email
// Use email as username if name not available
ctx.UserName = strings.Split(email, "@")[0]
}
if name, ok := userInfo["name"].(string); ok {
ctx.UserName = name
}
if login, ok := userInfo["login"].(string); ok {
ctx.UserName = login // GitHub uses "login"
}
if ctx.UserName == "" {
return nil, fmt.Errorf("could not extract username from user info")
}
return ctx, nil
}
// OAuth2RefreshToken refreshes an expired OAuth2 access token using the refresh token
// Takes the refresh token and returns a new LoginResponse with updated tokens
func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshToken, providerName string) (*LoginResponse, error) {
provider, err := a.getOAuth2Provider(providerName)
if err != nil {
return nil, err
}
// Get session by refresh token from database
var success bool
var errMsg *string
var sessionData []byte
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getrefreshtoken($1)
`, refreshToken).Scan(&success, &errMsg, &sessionData)
if err != nil {
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("invalid or expired refresh token")
}
// Parse session data
var session struct {
UserID int `json:"user_id"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Expiry time.Time `json:"expiry"`
}
if err := json.Unmarshal(sessionData, &session); err != nil {
return nil, fmt.Errorf("failed to parse session data: %w", err)
}
// Create oauth2.Token from stored data
oldToken := &oauth2.Token{
AccessToken: session.AccessToken,
TokenType: session.TokenType,
RefreshToken: refreshToken,
Expiry: session.Expiry,
}
// Use OAuth2 provider to refresh the token
tokenSource := provider.config.TokenSource(ctx, oldToken)
newToken, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to refresh token with provider: %w", err)
}
// Generate new session token
newSessionToken, err := a.OAuth2GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate new session token: %w", err)
}
// Update session in database with new tokens
updateData := map[string]interface{}{
"user_id": session.UserID,
"old_refresh_token": refreshToken,
"new_session_token": newSessionToken,
"new_access_token": newToken.AccessToken,
"new_refresh_token": newToken.RefreshToken,
"expires_at": newToken.Expiry,
}
updateJSON, err := json.Marshal(updateData)
if err != nil {
return nil, fmt.Errorf("failed to marshal update data: %w", err)
}
var updateSuccess bool
var updateErrMsg *string
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
if err != nil {
return nil, fmt.Errorf("failed to update session: %w", err)
}
if !updateSuccess {
if updateErrMsg != nil {
return nil, fmt.Errorf("%s", *updateErrMsg)
}
return nil, fmt.Errorf("failed to update session")
}
// Get user data
var userSuccess bool
var userErrMsg *string
var userData []byte
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getuser($1)
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)
}
if !userSuccess {
if userErrMsg != nil {
return nil, fmt.Errorf("%s", *userErrMsg)
}
return nil, fmt.Errorf("failed to get user data")
}
// Parse user context
var userCtx UserContext
if err := json.Unmarshal(userData, &userCtx); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
userCtx.SessionID = newSessionToken
return &LoginResponse{
Token: newSessionToken,
RefreshToken: newToken.RefreshToken,
User: &userCtx,
ExpiresIn: int64(time.Until(newToken.Expiry).Seconds()),
}, nil
}
// Pre-configured OAuth2 factory methods
// NewGoogleAuthenticator creates a DatabaseAuthenticator configured for Google OAuth2
func NewGoogleAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
auth := NewDatabaseAuthenticator(db)
return auth.WithOAuth2(OAuth2Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
})
}
// NewGitHubAuthenticator creates a DatabaseAuthenticator configured for GitHub OAuth2
func NewGitHubAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
auth := NewDatabaseAuthenticator(db)
return auth.WithOAuth2(OAuth2Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
ProviderName: "github",
})
}
// NewMicrosoftAuthenticator creates a DatabaseAuthenticator configured for Microsoft OAuth2
func NewMicrosoftAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
auth := NewDatabaseAuthenticator(db)
return auth.WithOAuth2(OAuth2Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
ProviderName: "microsoft",
})
}
// NewFacebookAuthenticator creates a DatabaseAuthenticator configured for Facebook OAuth2
func NewFacebookAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
auth := NewDatabaseAuthenticator(db)
return auth.WithOAuth2(OAuth2Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{"email"},
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
ProviderName: "facebook",
})
}
// NewMultiProviderAuthenticator creates a DatabaseAuthenticator with all major OAuth2 providers configured
func NewMultiProviderAuthenticator(db *sql.DB, configs map[string]OAuth2Config) *DatabaseAuthenticator {
auth := NewDatabaseAuthenticator(db)
//nolint:gocritic // OAuth2Config is copied but kept for API simplicity
for _, cfg := range configs {
auth.WithOAuth2(cfg)
}
return auth
}

185
pkg/security/passkey.go Normal file
View File

@@ -0,0 +1,185 @@
package security
import (
"context"
"encoding/json"
"time"
)
// PasskeyCredential represents a stored WebAuthn/FIDO2 credential
type PasskeyCredential struct {
ID string `json:"id"`
UserID int `json:"user_id"`
CredentialID []byte `json:"credential_id"` // Raw credential ID from authenticator
PublicKey []byte `json:"public_key"` // COSE public key
AttestationType string `json:"attestation_type"` // none, indirect, direct
AAGUID []byte `json:"aaguid"` // Authenticator AAGUID
SignCount uint32 `json:"sign_count"` // Signature counter
CloneWarning bool `json:"clone_warning"` // True if cloning detected
Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
BackupEligible bool `json:"backup_eligible"` // Credential can be backed up
BackupState bool `json:"backup_state"` // Credential is currently backed up
Name string `json:"name,omitempty"` // User-friendly name
CreatedAt time.Time `json:"created_at"`
LastUsedAt time.Time `json:"last_used_at"`
}
// PasskeyRegistrationOptions contains options for beginning passkey registration
type PasskeyRegistrationOptions struct {
Challenge []byte `json:"challenge"`
RelyingParty PasskeyRelyingParty `json:"rp"`
User PasskeyUser `json:"user"`
PubKeyCredParams []PasskeyCredentialParam `json:"pubKeyCredParams"`
Timeout int64 `json:"timeout,omitempty"` // Milliseconds
ExcludeCredentials []PasskeyCredentialDescriptor `json:"excludeCredentials,omitempty"`
AuthenticatorSelection *PasskeyAuthenticatorSelection `json:"authenticatorSelection,omitempty"`
Attestation string `json:"attestation,omitempty"` // none, indirect, direct, enterprise
Extensions map[string]any `json:"extensions,omitempty"`
}
// PasskeyAuthenticationOptions contains options for beginning passkey authentication
type PasskeyAuthenticationOptions struct {
Challenge []byte `json:"challenge"`
Timeout int64 `json:"timeout,omitempty"`
RelyingPartyID string `json:"rpId,omitempty"`
AllowCredentials []PasskeyCredentialDescriptor `json:"allowCredentials,omitempty"`
UserVerification string `json:"userVerification,omitempty"` // required, preferred, discouraged
Extensions map[string]any `json:"extensions,omitempty"`
}
// PasskeyRelyingParty identifies the relying party
type PasskeyRelyingParty struct {
ID string `json:"id"` // Domain (e.g., "example.com")
Name string `json:"name"` // Display name
}
// PasskeyUser identifies the user
type PasskeyUser struct {
ID []byte `json:"id"` // User handle (unique, persistent)
Name string `json:"name"` // Username
DisplayName string `json:"displayName"` // Display name
}
// PasskeyCredentialParam specifies supported public key algorithm
type PasskeyCredentialParam struct {
Type string `json:"type"` // "public-key"
Alg int `json:"alg"` // COSE algorithm identifier (e.g., -7 for ES256, -257 for RS256)
}
// PasskeyCredentialDescriptor describes a credential
type PasskeyCredentialDescriptor struct {
Type string `json:"type"` // "public-key"
ID []byte `json:"id"` // Credential ID
Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
}
// PasskeyAuthenticatorSelection specifies authenticator requirements
type PasskeyAuthenticatorSelection struct {
AuthenticatorAttachment string `json:"authenticatorAttachment,omitempty"` // platform, cross-platform
RequireResidentKey bool `json:"requireResidentKey,omitempty"`
ResidentKey string `json:"residentKey,omitempty"` // discouraged, preferred, required
UserVerification string `json:"userVerification,omitempty"` // required, preferred, discouraged
}
// PasskeyRegistrationResponse contains the client's registration response
type PasskeyRegistrationResponse struct {
ID string `json:"id"` // Base64URL encoded credential ID
RawID []byte `json:"rawId"` // Raw credential ID
Type string `json:"type"` // "public-key"
Response PasskeyAuthenticatorAttestationResponse `json:"response"`
ClientExtensionResults map[string]any `json:"clientExtensionResults,omitempty"`
Transports []string `json:"transports,omitempty"`
}
// PasskeyAuthenticatorAttestationResponse contains attestation data
type PasskeyAuthenticatorAttestationResponse struct {
ClientDataJSON []byte `json:"clientDataJSON"`
AttestationObject []byte `json:"attestationObject"`
Transports []string `json:"transports,omitempty"`
}
// PasskeyAuthenticationResponse contains the client's authentication response
type PasskeyAuthenticationResponse struct {
ID string `json:"id"` // Base64URL encoded credential ID
RawID []byte `json:"rawId"` // Raw credential ID
Type string `json:"type"` // "public-key"
Response PasskeyAuthenticatorAssertionResponse `json:"response"`
ClientExtensionResults map[string]any `json:"clientExtensionResults,omitempty"`
}
// PasskeyAuthenticatorAssertionResponse contains assertion data
type PasskeyAuthenticatorAssertionResponse struct {
ClientDataJSON []byte `json:"clientDataJSON"`
AuthenticatorData []byte `json:"authenticatorData"`
Signature []byte `json:"signature"`
UserHandle []byte `json:"userHandle,omitempty"`
}
// PasskeyProvider handles passkey registration and authentication
type PasskeyProvider interface {
// BeginRegistration creates registration options for a new passkey
BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error)
// CompleteRegistration verifies and stores a new passkey credential
CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error)
// BeginAuthentication creates authentication options for passkey login
BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error)
// CompleteAuthentication verifies a passkey assertion and returns the user
CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error)
// GetCredentials returns all passkey credentials for a user
GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error)
// DeleteCredential removes a passkey credential
DeleteCredential(ctx context.Context, userID int, credentialID string) error
// UpdateCredentialName updates the friendly name of a credential
UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error
}
// PasskeyLoginRequest contains passkey authentication data
type PasskeyLoginRequest struct {
Response PasskeyAuthenticationResponse `json:"response"`
ExpectedChallenge []byte `json:"expected_challenge"`
Claims map[string]any `json:"claims"` // Additional login data
}
// PasskeyRegisterRequest contains passkey registration data
type PasskeyRegisterRequest struct {
UserID int `json:"user_id"`
Response PasskeyRegistrationResponse `json:"response"`
ExpectedChallenge []byte `json:"expected_challenge"`
CredentialName string `json:"credential_name,omitempty"`
}
// PasskeyBeginRegistrationRequest contains options for starting passkey registration
type PasskeyBeginRegistrationRequest struct {
UserID int `json:"user_id"`
Username string `json:"username"`
DisplayName string `json:"display_name"`
}
// PasskeyBeginAuthenticationRequest contains options for starting passkey authentication
type PasskeyBeginAuthenticationRequest struct {
Username string `json:"username,omitempty"` // Optional for resident key flow
}
// ParsePasskeyRegistrationResponse parses a JSON passkey registration response
func ParsePasskeyRegistrationResponse(data []byte) (*PasskeyRegistrationResponse, error) {
var response PasskeyRegistrationResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}
return &response, nil
}
// ParsePasskeyAuthenticationResponse parses a JSON passkey authentication response
func ParsePasskeyAuthenticationResponse(data []byte) (*PasskeyAuthenticationResponse, error) {
var response PasskeyAuthenticationResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil, err
}
return &response, nil
}

View File

@@ -0,0 +1,432 @@
package security
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
)
// PasskeyAuthenticationExample demonstrates passkey (WebAuthn/FIDO2) authentication
func PasskeyAuthenticationExample() {
// Setup database connection
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
// Create passkey provider
passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com", // Your domain
RPName: "Example Application", // Display name
RPOrigin: "https://example.com", // Expected origin
Timeout: 60000, // 60 seconds
})
// Create authenticator with passkey support
// Option 1: Pass during creation
_ = NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
PasskeyProvider: passkeyProvider,
})
// Option 2: Use WithPasskey method
auth := NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
ctx := context.Background()
// === REGISTRATION FLOW ===
// Step 1: Begin registration
regOptions, _ := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
UserID: 1,
Username: "alice",
DisplayName: "Alice Smith",
})
// Send regOptions to client as JSON
// Client will call navigator.credentials.create() with these options
_ = regOptions
// Step 2: Complete registration (after client returns credential)
// This would come from the client's navigator.credentials.create() response
clientResponse := PasskeyRegistrationResponse{
ID: "base64-credential-id",
RawID: []byte("raw-credential-id"),
Type: "public-key",
Response: PasskeyAuthenticatorAttestationResponse{
ClientDataJSON: []byte("..."),
AttestationObject: []byte("..."),
},
Transports: []string{"internal"},
}
credential, _ := auth.CompletePasskeyRegistration(ctx, PasskeyRegisterRequest{
UserID: 1,
Response: clientResponse,
ExpectedChallenge: regOptions.Challenge,
CredentialName: "My iPhone",
})
fmt.Printf("Registered credential: %s\n", credential.ID)
// === AUTHENTICATION FLOW ===
// Step 1: Begin authentication
authOptions, _ := auth.BeginPasskeyAuthentication(ctx, PasskeyBeginAuthenticationRequest{
Username: "alice", // Optional - omit for resident key flow
})
// Send authOptions to client as JSON
// Client will call navigator.credentials.get() with these options
_ = authOptions
// Step 2: Complete authentication (after client returns assertion)
// This would come from the client's navigator.credentials.get() response
clientAssertion := PasskeyAuthenticationResponse{
ID: "base64-credential-id",
RawID: []byte("raw-credential-id"),
Type: "public-key",
Response: PasskeyAuthenticatorAssertionResponse{
ClientDataJSON: []byte("..."),
AuthenticatorData: []byte("..."),
Signature: []byte("..."),
},
}
loginResponse, _ := auth.LoginWithPasskey(ctx, PasskeyLoginRequest{
Response: clientAssertion,
ExpectedChallenge: authOptions.Challenge,
Claims: map[string]any{
"ip_address": "192.168.1.1",
"user_agent": "Mozilla/5.0...",
},
})
fmt.Printf("Logged in user: %s with token: %s\n",
loginResponse.User.UserName, loginResponse.Token)
// === CREDENTIAL MANAGEMENT ===
// Get all credentials for a user
credentials, _ := auth.GetPasskeyCredentials(ctx, 1)
for i := range credentials {
fmt.Printf("Credential: %s (created: %s, last used: %s)\n",
credentials[i].Name, credentials[i].CreatedAt, credentials[i].LastUsedAt)
}
// Update credential name
_ = auth.UpdatePasskeyCredentialName(ctx, 1, credential.ID, "My New iPhone")
// Delete credential
_ = auth.DeletePasskeyCredential(ctx, 1, credential.ID)
}
// PasskeyHTTPHandlersExample shows HTTP handlers for passkey authentication
func PasskeyHTTPHandlersExample(auth *DatabaseAuthenticator) {
// Store challenges in session/cache in production
challenges := make(map[string][]byte)
// Begin registration endpoint
http.HandleFunc("/api/passkey/register/begin", func(w http.ResponseWriter, r *http.Request) {
var req struct {
UserID int `json:"user_id"`
Username string `json:"username"`
DisplayName string `json:"display_name"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
options, err := auth.BeginPasskeyRegistration(r.Context(), PasskeyBeginRegistrationRequest{
UserID: req.UserID,
Username: req.Username,
DisplayName: req.DisplayName,
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Store challenge for verification (use session ID as key in production)
sessionID := "session-123"
challenges[sessionID] = options.Challenge
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(options)
})
// Complete registration endpoint
http.HandleFunc("/api/passkey/register/complete", func(w http.ResponseWriter, r *http.Request) {
var req struct {
UserID int `json:"user_id"`
Response PasskeyRegistrationResponse `json:"response"`
CredentialName string `json:"credential_name"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
// Get stored challenge (from session in production)
sessionID := "session-123"
challenge := challenges[sessionID]
delete(challenges, sessionID)
credential, err := auth.CompletePasskeyRegistration(r.Context(), PasskeyRegisterRequest{
UserID: req.UserID,
Response: req.Response,
ExpectedChallenge: challenge,
CredentialName: req.CredentialName,
})
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(credential)
})
// Begin authentication endpoint
http.HandleFunc("/api/passkey/login/begin", func(w http.ResponseWriter, r *http.Request) {
var req struct {
Username string `json:"username"` // Optional
}
_ = json.NewDecoder(r.Body).Decode(&req)
options, err := auth.BeginPasskeyAuthentication(r.Context(), PasskeyBeginAuthenticationRequest{
Username: req.Username,
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Store challenge for verification (use session ID as key in production)
sessionID := "session-456"
challenges[sessionID] = options.Challenge
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(options)
})
// Complete authentication endpoint
http.HandleFunc("/api/passkey/login/complete", func(w http.ResponseWriter, r *http.Request) {
var req struct {
Response PasskeyAuthenticationResponse `json:"response"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
// Get stored challenge (from session in production)
sessionID := "session-456"
challenge := challenges[sessionID]
delete(challenges, sessionID)
loginResponse, err := auth.LoginWithPasskey(r.Context(), PasskeyLoginRequest{
Response: req.Response,
ExpectedChallenge: challenge,
Claims: map[string]any{
"ip_address": r.RemoteAddr,
"user_agent": r.UserAgent(),
},
})
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResponse.Token,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(loginResponse)
})
// List credentials endpoint
http.HandleFunc("/api/passkey/credentials", func(w http.ResponseWriter, r *http.Request) {
// Get user from authenticated session
userCtx, err := auth.Authenticate(r)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
credentials, err := auth.GetPasskeyCredentials(r.Context(), userCtx.UserID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(credentials)
})
// Delete credential endpoint
http.HandleFunc("/api/passkey/credentials/delete", func(w http.ResponseWriter, r *http.Request) {
userCtx, err := auth.Authenticate(r)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
var req struct {
CredentialID string `json:"credential_id"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
err = auth.DeletePasskeyCredential(r.Context(), userCtx.UserID, req.CredentialID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
})
}
// PasskeyClientSideExample shows the client-side JavaScript code needed
func PasskeyClientSideExample() string {
return `
// === CLIENT-SIDE JAVASCRIPT FOR PASSKEY AUTHENTICATION ===
// Helper function to convert base64 to ArrayBuffer
function base64ToArrayBuffer(base64) {
const binary = atob(base64);
const bytes = new Uint8Array(binary.length);
for (let i = 0; i < binary.length; i++) {
bytes[i] = binary.charCodeAt(i);
}
return bytes.buffer;
}
// Helper function to convert ArrayBuffer to base64
function arrayBufferToBase64(buffer) {
const bytes = new Uint8Array(buffer);
let binary = '';
for (let i = 0; i < bytes.length; i++) {
binary += String.fromCharCode(bytes[i]);
}
return btoa(binary);
}
// === REGISTRATION ===
async function registerPasskey(userId, username, displayName) {
// Step 1: Get registration options from server
const optionsResponse = await fetch('/api/passkey/register/begin', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ user_id: userId, username, display_name: displayName })
});
const options = await optionsResponse.json();
// Convert base64 strings to ArrayBuffers
options.challenge = base64ToArrayBuffer(options.challenge);
options.user.id = base64ToArrayBuffer(options.user.id);
if (options.excludeCredentials) {
options.excludeCredentials = options.excludeCredentials.map(cred => ({
...cred,
id: base64ToArrayBuffer(cred.id)
}));
}
// Step 2: Create credential using WebAuthn API
const credential = await navigator.credentials.create({
publicKey: options
});
// Step 3: Send credential to server
const credentialResponse = {
id: credential.id,
rawId: arrayBufferToBase64(credential.rawId),
type: credential.type,
response: {
clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
attestationObject: arrayBufferToBase64(credential.response.attestationObject)
},
transports: credential.response.getTransports ? credential.response.getTransports() : []
};
const completeResponse = await fetch('/api/passkey/register/complete', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
user_id: userId,
response: credentialResponse,
credential_name: 'My Device'
})
});
return await completeResponse.json();
}
// === AUTHENTICATION ===
async function loginWithPasskey(username) {
// Step 1: Get authentication options from server
const optionsResponse = await fetch('/api/passkey/login/begin', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ username })
});
const options = await optionsResponse.json();
// Convert base64 strings to ArrayBuffers
options.challenge = base64ToArrayBuffer(options.challenge);
if (options.allowCredentials) {
options.allowCredentials = options.allowCredentials.map(cred => ({
...cred,
id: base64ToArrayBuffer(cred.id)
}));
}
// Step 2: Get credential using WebAuthn API
const credential = await navigator.credentials.get({
publicKey: options
});
// Step 3: Send assertion to server
const assertionResponse = {
id: credential.id,
rawId: arrayBufferToBase64(credential.rawId),
type: credential.type,
response: {
clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
authenticatorData: arrayBufferToBase64(credential.response.authenticatorData),
signature: arrayBufferToBase64(credential.response.signature),
userHandle: credential.response.userHandle ? arrayBufferToBase64(credential.response.userHandle) : null
}
};
const loginResponse = await fetch('/api/passkey/login/complete', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ response: assertionResponse })
});
return await loginResponse.json();
}
// === USAGE ===
// Register a new passkey
document.getElementById('register-btn').addEventListener('click', async () => {
try {
const result = await registerPasskey(1, 'alice', 'Alice Smith');
console.log('Passkey registered:', result);
} catch (error) {
console.error('Registration failed:', error);
}
});
// Login with passkey
document.getElementById('login-btn').addEventListener('click', async () => {
try {
const result = await loginWithPasskey('alice');
console.log('Logged in:', result);
} catch (error) {
console.error('Login failed:', error);
}
});
`
}

View File

@@ -0,0 +1,405 @@
package security
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"time"
)
// DatabasePasskeyProvider implements PasskeyProvider using database storage
type DatabasePasskeyProvider struct {
db *sql.DB
rpID string // Relying Party ID (domain)
rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000)
}
// DatabasePasskeyProviderOptions configures the passkey provider
type DatabasePasskeyProviderOptions struct {
// RPID is the Relying Party ID (typically your domain, e.g., "example.com")
RPID string
// RPName is the display name for your relying party
RPName string
// RPOrigin is the expected origin (e.g., "https://example.com")
RPOrigin string
// Timeout is the timeout for operations in milliseconds (default: 60000)
Timeout int64
}
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions) *DatabasePasskeyProvider {
if opts.Timeout == 0 {
opts.Timeout = 60000 // 60 seconds default
}
return &DatabasePasskeyProvider{
db: db,
rpID: opts.RPID,
rpName: opts.RPName,
rpOrigin: opts.RPOrigin,
timeout: opts.Timeout,
}
}
// BeginRegistration creates registration options for a new passkey
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
// Generate challenge
challenge := make([]byte, 32)
if _, err := rand.Read(challenge); err != nil {
return nil, fmt.Errorf("failed to generate challenge: %w", err)
}
// Get existing credentials to exclude
credentials, err := p.GetCredentials(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to get existing credentials: %w", err)
}
excludeCredentials := make([]PasskeyCredentialDescriptor, 0, len(credentials))
for i := range credentials {
excludeCredentials = append(excludeCredentials, PasskeyCredentialDescriptor{
Type: "public-key",
ID: credentials[i].CredentialID,
Transports: credentials[i].Transports,
})
}
// Create user handle (persistent user ID)
userHandle := []byte(fmt.Sprintf("user_%d", userID))
return &PasskeyRegistrationOptions{
Challenge: challenge,
RelyingParty: PasskeyRelyingParty{
ID: p.rpID,
Name: p.rpName,
},
User: PasskeyUser{
ID: userHandle,
Name: username,
DisplayName: displayName,
},
PubKeyCredParams: []PasskeyCredentialParam{
{Type: "public-key", Alg: -7}, // ES256 (ECDSA with SHA-256)
{Type: "public-key", Alg: -257}, // RS256 (RSASSA-PKCS1-v1_5 with SHA-256)
},
Timeout: p.timeout,
ExcludeCredentials: excludeCredentials,
AuthenticatorSelection: &PasskeyAuthenticatorSelection{
RequireResidentKey: false,
ResidentKey: "preferred",
UserVerification: "preferred",
},
Attestation: "none",
}, nil
}
// CompleteRegistration verifies and stores a new passkey credential
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
// like github.com/go-webauthn/webauthn to properly verify attestation and parse credentials.
func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error) {
// TODO: Implement full WebAuthn verification
// 1. Verify clientDataJSON contains correct challenge and origin
// 2. Parse and verify attestationObject
// 3. Extract public key and credential ID
// 4. Verify attestation signature (if not "none")
// For now, this is a placeholder that stores the credential data
// In production, you MUST use a proper WebAuthn library
credData := map[string]any{
"user_id": userID,
"credential_id": base64.StdEncoding.EncodeToString(response.RawID),
"public_key": base64.StdEncoding.EncodeToString(response.Response.AttestationObject),
"attestation_type": "none",
"sign_count": 0,
"transports": response.Transports,
"backup_eligible": false,
"backup_state": false,
"name": "Passkey",
}
credJSON, err := json.Marshal(credData)
if err != nil {
return nil, fmt.Errorf("failed to marshal credential data: %w", err)
}
var success bool
var errorMsg sql.NullString
var credentialID sql.NullInt64
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("failed to store credential")
}
return &PasskeyCredential{
ID: fmt.Sprintf("%d", credentialID.Int64),
UserID: userID,
CredentialID: response.RawID,
PublicKey: response.Response.AttestationObject,
AttestationType: "none",
Transports: response.Transports,
CreatedAt: time.Now(),
LastUsedAt: time.Now(),
}, nil
}
// BeginAuthentication creates authentication options for passkey login
func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error) {
// Generate challenge
challenge := make([]byte, 32)
if _, err := rand.Read(challenge); err != nil {
return nil, fmt.Errorf("failed to generate challenge: %w", err)
}
// If username is provided, get user's credentials
var allowCredentials []PasskeyCredentialDescriptor
if username != "" {
var success bool
var errorMsg sql.NullString
var userID sql.NullInt64
var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("failed to get credentials")
}
// Parse credentials
var creds []struct {
ID string `json:"credential_id"`
Transports []string `json:"transports"`
}
if err := json.Unmarshal([]byte(credentialsJSON.String), &creds); err != nil {
return nil, fmt.Errorf("failed to parse credentials: %w", err)
}
allowCredentials = make([]PasskeyCredentialDescriptor, 0, len(creds))
for _, cred := range creds {
credID, err := base64.StdEncoding.DecodeString(cred.ID)
if err != nil {
continue
}
allowCredentials = append(allowCredentials, PasskeyCredentialDescriptor{
Type: "public-key",
ID: credID,
Transports: cred.Transports,
})
}
}
return &PasskeyAuthenticationOptions{
Challenge: challenge,
Timeout: p.timeout,
RelyingPartyID: p.rpID,
AllowCredentials: allowCredentials,
UserVerification: "preferred",
}, nil
}
// CompleteAuthentication verifies a passkey assertion and returns the user ID
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
// like github.com/go-webauthn/webauthn to properly verify the assertion signature.
func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error) {
// TODO: Implement full WebAuthn verification
// 1. Verify clientDataJSON contains correct challenge and origin
// 2. Verify authenticatorData
// 3. Verify signature using stored public key
// 4. Update sign counter and check for cloning
// Get credential from database
var success bool
var errorMsg sql.NullString
var credentialJSON sql.NullString
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err)
}
if !success {
if errorMsg.Valid {
return 0, fmt.Errorf("%s", errorMsg.String)
}
return 0, fmt.Errorf("credential not found")
}
// Parse credential
var cred struct {
UserID int `json:"user_id"`
SignCount uint32 `json:"sign_count"`
}
if err := json.Unmarshal([]byte(credentialJSON.String), &cred); err != nil {
return 0, fmt.Errorf("failed to parse credential: %w", err)
}
// TODO: Verify signature here
// For now, we'll just update the counter as a placeholder
// Update counter (in production, this should be done after successful verification)
newCounter := cred.SignCount + 1
var updateSuccess bool
var updateError sql.NullString
var cloneWarning sql.NullBool
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err)
}
if cloneWarning.Valid && cloneWarning.Bool {
return 0, fmt.Errorf("credential cloning detected")
}
return cred.UserID, nil
}
// GetCredentials returns all passkey credentials for a user
func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
var success bool
var errorMsg sql.NullString
var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("failed to get credentials")
}
// Parse credentials
var rawCreds []struct {
ID int `json:"id"`
UserID int `json:"user_id"`
CredentialID string `json:"credential_id"`
PublicKey string `json:"public_key"`
AttestationType string `json:"attestation_type"`
AAGUID string `json:"aaguid"`
SignCount uint32 `json:"sign_count"`
CloneWarning bool `json:"clone_warning"`
Transports []string `json:"transports"`
BackupEligible bool `json:"backup_eligible"`
BackupState bool `json:"backup_state"`
Name string `json:"name"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt time.Time `json:"last_used_at"`
}
if err := json.Unmarshal([]byte(credentialsJSON.String), &rawCreds); err != nil {
return nil, fmt.Errorf("failed to parse credentials: %w", err)
}
credentials := make([]PasskeyCredential, 0, len(rawCreds))
for i := range rawCreds {
raw := rawCreds[i]
credID, err := base64.StdEncoding.DecodeString(raw.CredentialID)
if err != nil {
continue
}
pubKey, err := base64.StdEncoding.DecodeString(raw.PublicKey)
if err != nil {
continue
}
aaguid, _ := base64.StdEncoding.DecodeString(raw.AAGUID)
credentials = append(credentials, PasskeyCredential{
ID: fmt.Sprintf("%d", raw.ID),
UserID: raw.UserID,
CredentialID: credID,
PublicKey: pubKey,
AttestationType: raw.AttestationType,
AAGUID: aaguid,
SignCount: raw.SignCount,
CloneWarning: raw.CloneWarning,
Transports: raw.Transports,
BackupEligible: raw.BackupEligible,
BackupState: raw.BackupState,
Name: raw.Name,
CreatedAt: raw.CreatedAt,
LastUsedAt: raw.LastUsedAt,
})
}
return credentials, nil
}
// DeleteCredential removes a passkey credential
func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID int, credentialID string) error {
credID, err := base64.StdEncoding.DecodeString(credentialID)
if err != nil {
return fmt.Errorf("invalid credential ID: %w", err)
}
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to delete credential: %w", err)
}
if !success {
if errorMsg.Valid {
return fmt.Errorf("%s", errorMsg.String)
}
return fmt.Errorf("failed to delete credential")
}
return nil
}
// UpdateCredentialName updates the friendly name of a credential
func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
credID, err := base64.StdEncoding.DecodeString(credentialID)
if err != nil {
return fmt.Errorf("invalid credential ID: %w", err)
}
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to update credential name: %w", err)
}
if !success {
if errorMsg.Valid {
return fmt.Errorf("%s", errorMsg.String)
}
return fmt.Errorf("failed to update credential name")
}
return nil
}

View File

@@ -0,0 +1,330 @@
package security
import (
"context"
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
func TestDatabasePasskeyProvider_BeginRegistration(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create mock db: %v", err)
}
defer db.Close()
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
RPOrigin: "https://example.com",
})
ctx := context.Background()
// Mock get credentials query
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
AddRow(true, nil, "[]")
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
WithArgs(1).
WillReturnRows(rows)
opts, err := provider.BeginRegistration(ctx, 1, "testuser", "Test User")
if err != nil {
t.Fatalf("BeginRegistration failed: %v", err)
}
if opts.RelyingParty.ID != "example.com" {
t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingParty.ID)
}
if opts.User.Name != "testuser" {
t.Errorf("expected username 'testuser', got '%s'", opts.User.Name)
}
if len(opts.Challenge) != 32 {
t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
}
if len(opts.PubKeyCredParams) != 2 {
t.Errorf("expected 2 credential params, got %d", len(opts.PubKeyCredParams))
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDatabasePasskeyProvider_BeginAuthentication(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create mock db: %v", err)
}
defer db.Close()
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
RPOrigin: "https://example.com",
})
ctx := context.Background()
// Mock get credentials by username query
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user_id", "p_credentials"}).
AddRow(true, nil, 1, `[{"credential_id":"YWJjZGVm","transports":["internal"]}]`)
mock.ExpectQuery(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username`).
WithArgs("testuser").
WillReturnRows(rows)
opts, err := provider.BeginAuthentication(ctx, "testuser")
if err != nil {
t.Fatalf("BeginAuthentication failed: %v", err)
}
if opts.RelyingPartyID != "example.com" {
t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingPartyID)
}
if len(opts.Challenge) != 32 {
t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
}
if len(opts.AllowCredentials) != 1 {
t.Errorf("expected 1 allowed credential, got %d", len(opts.AllowCredentials))
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDatabasePasskeyProvider_GetCredentials(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create mock db: %v", err)
}
defer db.Close()
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
})
ctx := context.Background()
credentialsJSON := `[{
"id": 1,
"user_id": 1,
"credential_id": "YWJjZGVmMTIzNDU2",
"public_key": "cHVibGlja2V5",
"attestation_type": "none",
"aaguid": "",
"sign_count": 5,
"clone_warning": false,
"transports": ["internal"],
"backup_eligible": true,
"backup_state": false,
"name": "My Phone",
"created_at": "2026-01-01T00:00:00Z",
"last_used_at": "2026-01-31T00:00:00Z"
}]`
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
AddRow(true, nil, credentialsJSON)
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
WithArgs(1).
WillReturnRows(rows)
credentials, err := provider.GetCredentials(ctx, 1)
if err != nil {
t.Fatalf("GetCredentials failed: %v", err)
}
if len(credentials) != 1 {
t.Fatalf("expected 1 credential, got %d", len(credentials))
}
cred := credentials[0]
if cred.UserID != 1 {
t.Errorf("expected user ID 1, got %d", cred.UserID)
}
if cred.Name != "My Phone" {
t.Errorf("expected name 'My Phone', got '%s'", cred.Name)
}
if cred.SignCount != 5 {
t.Errorf("expected sign count 5, got %d", cred.SignCount)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDatabasePasskeyProvider_DeleteCredential(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create mock db: %v", err)
}
defer db.Close()
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
})
ctx := context.Background()
rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
AddRow(true, nil)
mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_delete_credential`).
WithArgs(1, sqlmock.AnyArg()).
WillReturnRows(rows)
err = provider.DeleteCredential(ctx, 1, "YWJjZGVmMTIzNDU2")
if err != nil {
t.Errorf("DeleteCredential failed: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDatabasePasskeyProvider_UpdateCredentialName(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create mock db: %v", err)
}
defer db.Close()
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
})
ctx := context.Background()
rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
AddRow(true, nil)
mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_update_name`).
WithArgs(1, sqlmock.AnyArg(), "New Name").
WillReturnRows(rows)
err = provider.UpdateCredentialName(ctx, 1, "YWJjZGVmMTIzNDU2", "New Name")
if err != nil {
t.Errorf("UpdateCredentialName failed: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDatabaseAuthenticator_PasskeyMethods(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create mock db: %v", err)
}
defer db.Close()
passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
})
auth := NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
PasskeyProvider: passkeyProvider,
})
ctx := context.Background()
t.Run("BeginPasskeyRegistration", func(t *testing.T) {
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
AddRow(true, nil, "[]")
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
WithArgs(1).
WillReturnRows(rows)
opts, err := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
UserID: 1,
Username: "testuser",
DisplayName: "Test User",
})
if err != nil {
t.Errorf("BeginPasskeyRegistration failed: %v", err)
}
if opts == nil {
t.Error("expected options, got nil")
}
})
t.Run("GetPasskeyCredentials", func(t *testing.T) {
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
AddRow(true, nil, "[]")
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
WithArgs(1).
WillReturnRows(rows)
credentials, err := auth.GetPasskeyCredentials(ctx, 1)
if err != nil {
t.Errorf("GetPasskeyCredentials failed: %v", err)
}
if credentials == nil {
t.Error("expected credentials slice, got nil")
}
})
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDatabaseAuthenticator_WithoutPasskey(t *testing.T) {
db, _, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create mock db: %v", err)
}
defer db.Close()
auth := NewDatabaseAuthenticator(db)
ctx := context.Background()
_, err = auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
UserID: 1,
Username: "testuser",
DisplayName: "Test User",
})
if err == nil {
t.Error("expected error when passkey provider not configured, got nil")
}
expectedMsg := "passkey provider not configured"
if err.Error() != expectedMsg {
t.Errorf("expected error '%s', got '%s'", expectedMsg, err.Error())
}
}
func TestPasskeyProvider_NilDB(t *testing.T) {
// This test verifies that the provider can be created with nil DB
// but operations will fail. In production, always provide a valid DB.
var db *sql.DB
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
})
if provider == nil {
t.Error("expected provider to be created even with nil DB")
}
// Verify that the provider has the correct configuration
if provider.rpID != "example.com" {
t.Errorf("expected RP ID 'example.com', got '%s'", provider.rpID)
}
}

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/cache" "github.com/bitechdev/ResolveSpec/pkg/cache"
@@ -60,10 +61,19 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session, // Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
// resolvespec_session_update, resolvespec_refresh_token // resolvespec_session_update, resolvespec_refresh_token
// See database_schema.sql for procedure definitions // See database_schema.sql for procedure definitions
// Also supports multiple OAuth2 providers configured with WithOAuth2()
// Also supports passkey authentication configured with WithPasskey()
type DatabaseAuthenticator struct { type DatabaseAuthenticator struct {
db *sql.DB db *sql.DB
cache *cache.Cache cache *cache.Cache
cacheTTL time.Duration cacheTTL time.Duration
// OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider
oauth2ProvidersMutex sync.RWMutex
// Passkey provider (optional)
passkeyProvider PasskeyProvider
} }
// DatabaseAuthenticatorOptions configures the database authenticator // DatabaseAuthenticatorOptions configures the database authenticator
@@ -73,6 +83,8 @@ type DatabaseAuthenticatorOptions struct {
CacheTTL time.Duration CacheTTL time.Duration
// Cache is an optional cache instance. If nil, uses the default cache // Cache is an optional cache instance. If nil, uses the default cache
Cache *cache.Cache Cache *cache.Cache
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
PasskeyProvider PasskeyProvider
} }
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -92,9 +104,10 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
} }
return &DatabaseAuthenticator{ return &DatabaseAuthenticator{
db: db, db: db,
cache: cacheInstance, cache: cacheInstance,
cacheTTL: opts.CacheTTL, cacheTTL: opts.CacheTTL,
passkeyProvider: opts.PasskeyProvider,
} }
} }
@@ -132,6 +145,41 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
return &response, nil return &response, nil
} }
// Register implements Registrable interface
func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error) {
// Convert RegisterRequest to JSON
reqJSON, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal register request: %w", err)
}
// Call resolvespec_register stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return nil, fmt.Errorf("register query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("registration failed")
}
// Parse response
var response LoginResponse
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
return nil, fmt.Errorf("failed to parse register response: %w", err)
}
return &response, nil
}
func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Convert LogoutRequest to JSON // Convert LogoutRequest to JSON
reqJSON, err := json.Marshal(req) reqJSON, err := json.Marshal(req)
@@ -654,3 +702,135 @@ func generateRandomString(length int) string {
// } // }
// return "" // return ""
// } // }
// Passkey authentication methods
// ==============================
// WithPasskey configures the DatabaseAuthenticator with a passkey provider
func (a *DatabaseAuthenticator) WithPasskey(provider PasskeyProvider) *DatabaseAuthenticator {
a.passkeyProvider = provider
return a
}
// BeginPasskeyRegistration initiates passkey registration for a user
func (a *DatabaseAuthenticator) BeginPasskeyRegistration(ctx context.Context, req PasskeyBeginRegistrationRequest) (*PasskeyRegistrationOptions, error) {
if a.passkeyProvider == nil {
return nil, fmt.Errorf("passkey provider not configured")
}
return a.passkeyProvider.BeginRegistration(ctx, req.UserID, req.Username, req.DisplayName)
}
// CompletePasskeyRegistration completes passkey registration
func (a *DatabaseAuthenticator) CompletePasskeyRegistration(ctx context.Context, req PasskeyRegisterRequest) (*PasskeyCredential, error) {
if a.passkeyProvider == nil {
return nil, fmt.Errorf("passkey provider not configured")
}
cred, err := a.passkeyProvider.CompleteRegistration(ctx, req.UserID, req.Response, req.ExpectedChallenge)
if err != nil {
return nil, err
}
// Update credential name if provided
if req.CredentialName != "" && cred.ID != "" {
_ = a.passkeyProvider.UpdateCredentialName(ctx, req.UserID, cred.ID, req.CredentialName)
}
return cred, nil
}
// BeginPasskeyAuthentication initiates passkey authentication
func (a *DatabaseAuthenticator) BeginPasskeyAuthentication(ctx context.Context, req PasskeyBeginAuthenticationRequest) (*PasskeyAuthenticationOptions, error) {
if a.passkeyProvider == nil {
return nil, fmt.Errorf("passkey provider not configured")
}
return a.passkeyProvider.BeginAuthentication(ctx, req.Username)
}
// LoginWithPasskey authenticates a user using a passkey and creates a session
func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req PasskeyLoginRequest) (*LoginResponse, error) {
if a.passkeyProvider == nil {
return nil, fmt.Errorf("passkey provider not configured")
}
// Verify passkey assertion
userID, err := a.passkeyProvider.CompleteAuthentication(ctx, req.Response, req.ExpectedChallenge)
if err != nil {
return nil, fmt.Errorf("passkey authentication failed: %w", err)
}
// Get user data from database
var username, email, roles string
var userLevel int
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)
}
// Generate session token
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Extract IP and user agent from claims
ipAddress := ""
userAgent := ""
if req.Claims != nil {
if ip, ok := req.Claims["ip_address"].(string); ok {
ipAddress = ip
}
if ua, ok := req.Claims["user_agent"].(string); ok {
userAgent = ua
}
}
// Create session
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES ($1, $2, $3, $4, $5, now())`
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
// Update last login
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
// Return login response
return &LoginResponse{
Token: sessionToken,
User: &UserContext{
UserID: userID,
UserName: username,
Email: email,
UserLevel: userLevel,
SessionID: sessionToken,
Roles: parseRoles(roles),
},
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
}
// GetPasskeyCredentials returns all passkey credentials for a user
func (a *DatabaseAuthenticator) GetPasskeyCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
if a.passkeyProvider == nil {
return nil, fmt.Errorf("passkey provider not configured")
}
return a.passkeyProvider.GetCredentials(ctx, userID)
}
// DeletePasskeyCredential removes a passkey credential
func (a *DatabaseAuthenticator) DeletePasskeyCredential(ctx context.Context, userID int, credentialID string) error {
if a.passkeyProvider == nil {
return fmt.Errorf("passkey provider not configured")
}
return a.passkeyProvider.DeleteCredential(ctx, userID, credentialID)
}
// UpdatePasskeyCredentialName updates the friendly name of a credential
func (a *DatabaseAuthenticator) UpdatePasskeyCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
if a.passkeyProvider == nil {
return fmt.Errorf("passkey provider not configured")
}
return a.passkeyProvider.UpdateCredentialName(ctx, userID, credentialID, name)
}

View File

@@ -635,6 +635,94 @@ func TestDatabaseAuthenticator(t *testing.T) {
t.Errorf("unfulfilled expectations: %v", err) t.Errorf("unfulfilled expectations: %v", err)
} }
}) })
t.Run("successful registration", func(t *testing.T) {
ctx := context.Background()
req := RegisterRequest{
Username: "newuser",
Password: "password123",
Email: "newuser@example.com",
UserLevel: 1,
Roles: []string{"user"},
}
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
AddRow(true, nil, `{"token":"abc123","user":{"user_id":1,"user_name":"newuser","email":"newuser@example.com"},"expires_in":86400}`)
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(rows)
resp, err := auth.Register(ctx, req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if resp.Token != "abc123" {
t.Errorf("expected token abc123, got %s", resp.Token)
}
if resp.User.UserName != "newuser" {
t.Errorf("expected username newuser, got %s", resp.User.UserName)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("registration with duplicate username", func(t *testing.T) {
ctx := context.Background()
req := RegisterRequest{
Username: "existinguser",
Password: "password123",
Email: "new@example.com",
UserLevel: 1,
Roles: []string{"user"},
}
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
AddRow(false, "Username already exists", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(rows)
_, err := auth.Register(ctx, req)
if err == nil {
t.Fatal("expected error for duplicate username")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("registration with duplicate email", func(t *testing.T) {
ctx := context.Background()
req := RegisterRequest{
Username: "newuser2",
Password: "password123",
Email: "existing@example.com",
UserLevel: 1,
Roles: []string{"user"},
}
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
AddRow(false, "Email already exists", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(rows)
_, err := auth.Register(ctx, req)
if err == nil {
t.Fatal("expected error for duplicate email")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
} }
// Test DatabaseAuthenticator RefreshToken // Test DatabaseAuthenticator RefreshToken

188
pkg/security/totp.go Normal file
View File

@@ -0,0 +1,188 @@
package security
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"math"
"net/url"
"strings"
"time"
)
// TwoFactorAuthProvider defines interface for 2FA operations
type TwoFactorAuthProvider interface {
// Generate2FASecret creates a new secret for a user
Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error)
// Validate2FACode verifies a TOTP code
Validate2FACode(secret string, code string) (bool, error)
// Enable2FA activates 2FA for a user (store secret in your database)
Enable2FA(userID int, secret string, backupCodes []string) error
// Disable2FA deactivates 2FA for a user
Disable2FA(userID int) error
// Get2FAStatus checks if user has 2FA enabled
Get2FAStatus(userID int) (bool, error)
// Get2FASecret retrieves the user's 2FA secret
Get2FASecret(userID int) (string, error)
// GenerateBackupCodes creates backup codes for 2FA
GenerateBackupCodes(userID int, count int) ([]string, error)
// ValidateBackupCode checks and consumes a backup code
ValidateBackupCode(userID int, code string) (bool, error)
}
// TwoFactorSecret contains 2FA setup information
type TwoFactorSecret struct {
Secret string `json:"secret"` // Base32 encoded secret
QRCodeURL string `json:"qr_code_url"` // URL for QR code generation
BackupCodes []string `json:"backup_codes"` // One-time backup codes
Issuer string `json:"issuer"` // Application name
AccountName string `json:"account_name"` // User identifier (email/username)
}
// TwoFactorConfig holds TOTP configuration
type TwoFactorConfig struct {
Algorithm string // SHA1, SHA256, SHA512
Digits int // Number of digits in code (6 or 8)
Period int // Time step in seconds (default 30)
SkewWindow int // Number of time steps to check before/after (default 1)
}
// DefaultTwoFactorConfig returns standard TOTP configuration
func DefaultTwoFactorConfig() *TwoFactorConfig {
return &TwoFactorConfig{
Algorithm: "SHA1",
Digits: 6,
Period: 30,
SkewWindow: 1,
}
}
// TOTPGenerator handles TOTP code generation and validation
type TOTPGenerator struct {
config *TwoFactorConfig
}
// NewTOTPGenerator creates a new TOTP generator with config
func NewTOTPGenerator(config *TwoFactorConfig) *TOTPGenerator {
if config == nil {
config = DefaultTwoFactorConfig()
}
return &TOTPGenerator{
config: config,
}
}
// GenerateSecret creates a random base32-encoded secret
func (t *TOTPGenerator) GenerateSecret() (string, error) {
secret := make([]byte, 20)
_, err := rand.Read(secret)
if err != nil {
return "", fmt.Errorf("failed to generate random secret: %w", err)
}
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secret), nil
}
// GenerateQRCodeURL creates a URL for QR code generation
func (t *TOTPGenerator) GenerateQRCodeURL(secret, issuer, accountName string) string {
params := url.Values{}
params.Set("secret", secret)
params.Set("issuer", issuer)
params.Set("algorithm", t.config.Algorithm)
params.Set("digits", fmt.Sprintf("%d", t.config.Digits))
params.Set("period", fmt.Sprintf("%d", t.config.Period))
label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, accountName))
return fmt.Sprintf("otpauth://totp/%s?%s", label, params.Encode())
}
// GenerateCode creates a TOTP code for a given time
func (t *TOTPGenerator) GenerateCode(secret string, timestamp time.Time) (string, error) {
// Decode secret
key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
if err != nil {
return "", fmt.Errorf("invalid secret: %w", err)
}
// Calculate counter (time steps since Unix epoch)
counter := uint64(timestamp.Unix()) / uint64(t.config.Period)
// Generate HMAC
h := t.getHashFunc()
mac := hmac.New(h, key)
// Convert counter to 8-byte array
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, counter)
mac.Write(buf)
sum := mac.Sum(nil)
// Dynamic truncation
offset := sum[len(sum)-1] & 0x0f
truncated := binary.BigEndian.Uint32(sum[offset:]) & 0x7fffffff
// Generate code with specified digits
code := truncated % uint32(math.Pow10(t.config.Digits))
format := fmt.Sprintf("%%0%dd", t.config.Digits)
return fmt.Sprintf(format, code), nil
}
// ValidateCode checks if a code is valid for the secret
func (t *TOTPGenerator) ValidateCode(secret, code string) (bool, error) {
now := time.Now()
// Check current time and skew window
for i := -t.config.SkewWindow; i <= t.config.SkewWindow; i++ {
timestamp := now.Add(time.Duration(i*t.config.Period) * time.Second)
expected, err := t.GenerateCode(secret, timestamp)
if err != nil {
return false, err
}
if code == expected {
return true, nil
}
}
return false, nil
}
// getHashFunc returns the hash function based on algorithm
func (t *TOTPGenerator) getHashFunc() func() hash.Hash {
switch strings.ToUpper(t.config.Algorithm) {
case "SHA256":
return sha256.New
case "SHA512":
return sha512.New
default:
return sha1.New
}
}
// GenerateBackupCodes creates random backup codes
func GenerateBackupCodes(count int) ([]string, error) {
codes := make([]string, count)
for i := 0; i < count; i++ {
code := make([]byte, 4)
_, err := rand.Read(code)
if err != nil {
return nil, fmt.Errorf("failed to generate backup code: %w", err)
}
codes[i] = fmt.Sprintf("%08X", binary.BigEndian.Uint32(code))
}
return codes, nil
}

View File

@@ -0,0 +1,399 @@
package security_test
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
var ErrInvalidCredentials = errors.New("invalid credentials")
// MockAuthenticator is a simple authenticator for testing 2FA
type MockAuthenticator struct {
users map[string]*security.UserContext
}
func NewMockAuthenticator() *MockAuthenticator {
return &MockAuthenticator{
users: map[string]*security.UserContext{
"testuser": {
UserID: 1,
UserName: "testuser",
Email: "test@example.com",
},
},
}
}
func (m *MockAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
user, exists := m.users[req.Username]
if !exists || req.Password != "password" {
return nil, ErrInvalidCredentials
}
return &security.LoginResponse{
Token: "mock-token",
RefreshToken: "mock-refresh-token",
User: user,
ExpiresIn: 3600,
}, nil
}
func (m *MockAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
return nil
}
func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
return m.users["testuser"], nil
}
func TestTwoFactorAuthenticator_Setup(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup 2FA
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
if err != nil {
t.Fatalf("Setup2FA() error = %v", err)
}
if secret.Secret == "" {
t.Error("Setup2FA() returned empty secret")
}
if secret.QRCodeURL == "" {
t.Error("Setup2FA() returned empty QR code URL")
}
if len(secret.BackupCodes) == 0 {
t.Error("Setup2FA() returned no backup codes")
}
if secret.Issuer != "TestApp" {
t.Errorf("Setup2FA() Issuer = %s, want TestApp", secret.Issuer)
}
if secret.AccountName != "test@example.com" {
t.Errorf("Setup2FA() AccountName = %s, want test@example.com", secret.AccountName)
}
}
func TestTwoFactorAuthenticator_Enable2FA(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup 2FA
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
if err != nil {
t.Fatalf("Setup2FA() error = %v", err)
}
// Generate valid code
totp := security.NewTOTPGenerator(nil)
code, err := totp.GenerateCode(secret.Secret, time.Now())
if err != nil {
t.Fatalf("GenerateCode() error = %v", err)
}
// Enable 2FA with valid code
err = tfaAuth.Enable2FA(1, secret.Secret, code)
if err != nil {
t.Errorf("Enable2FA() error = %v", err)
}
// Verify 2FA is enabled
status, err := provider.Get2FAStatus(1)
if err != nil {
t.Fatalf("Get2FAStatus() error = %v", err)
}
if !status {
t.Error("Enable2FA() did not enable 2FA")
}
}
func TestTwoFactorAuthenticator_Enable2FA_InvalidCode(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup 2FA
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
if err != nil {
t.Fatalf("Setup2FA() error = %v", err)
}
// Try to enable with invalid code
err = tfaAuth.Enable2FA(1, secret.Secret, "000000")
if err == nil {
t.Error("Enable2FA() should fail with invalid code")
}
// Verify 2FA is not enabled
status, _ := provider.Get2FAStatus(1)
if status {
t.Error("Enable2FA() should not enable 2FA with invalid code")
}
}
func TestTwoFactorAuthenticator_Login_Without2FA(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
req := security.LoginRequest{
Username: "testuser",
Password: "password",
}
resp, err := tfaAuth.Login(context.Background(), req)
if err != nil {
t.Fatalf("Login() error = %v", err)
}
if resp.Requires2FA {
t.Error("Login() should not require 2FA when not enabled")
}
if resp.Token == "" {
t.Error("Login() should return token when 2FA not required")
}
}
func TestTwoFactorAuthenticator_Login_With2FA_NoCode(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup and enable 2FA
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
totp := security.NewTOTPGenerator(nil)
code, _ := totp.GenerateCode(secret.Secret, time.Now())
tfaAuth.Enable2FA(1, secret.Secret, code)
// Try to login without 2FA code
req := security.LoginRequest{
Username: "testuser",
Password: "password",
}
resp, err := tfaAuth.Login(context.Background(), req)
if err != nil {
t.Fatalf("Login() error = %v", err)
}
if !resp.Requires2FA {
t.Error("Login() should require 2FA when enabled")
}
if resp.Token != "" {
t.Error("Login() should not return token when 2FA required but not provided")
}
}
func TestTwoFactorAuthenticator_Login_With2FA_ValidCode(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup and enable 2FA
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
totp := security.NewTOTPGenerator(nil)
code, _ := totp.GenerateCode(secret.Secret, time.Now())
tfaAuth.Enable2FA(1, secret.Secret, code)
// Generate new valid code for login
newCode, _ := totp.GenerateCode(secret.Secret, time.Now())
// Login with 2FA code
req := security.LoginRequest{
Username: "testuser",
Password: "password",
TwoFactorCode: newCode,
}
resp, err := tfaAuth.Login(context.Background(), req)
if err != nil {
t.Fatalf("Login() error = %v", err)
}
if resp.Requires2FA {
t.Error("Login() should not require 2FA when valid code provided")
}
if resp.Token == "" {
t.Error("Login() should return token when 2FA validated")
}
if !resp.User.TwoFactorEnabled {
t.Error("Login() should set TwoFactorEnabled on user")
}
}
func TestTwoFactorAuthenticator_Login_With2FA_InvalidCode(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup and enable 2FA
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
totp := security.NewTOTPGenerator(nil)
code, _ := totp.GenerateCode(secret.Secret, time.Now())
tfaAuth.Enable2FA(1, secret.Secret, code)
// Try to login with invalid code
req := security.LoginRequest{
Username: "testuser",
Password: "password",
TwoFactorCode: "000000",
}
_, err := tfaAuth.Login(context.Background(), req)
if err == nil {
t.Error("Login() should fail with invalid 2FA code")
}
}
func TestTwoFactorAuthenticator_Login_WithBackupCode(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup and enable 2FA
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
totp := security.NewTOTPGenerator(nil)
code, _ := totp.GenerateCode(secret.Secret, time.Now())
tfaAuth.Enable2FA(1, secret.Secret, code)
// Get backup codes
backupCodes, _ := tfaAuth.RegenerateBackupCodes(1, 10)
// Login with backup code
req := security.LoginRequest{
Username: "testuser",
Password: "password",
TwoFactorCode: backupCodes[0],
}
resp, err := tfaAuth.Login(context.Background(), req)
if err != nil {
t.Fatalf("Login() with backup code error = %v", err)
}
if resp.Token == "" {
t.Error("Login() should return token when backup code validated")
}
// Try to use same backup code again
req2 := security.LoginRequest{
Username: "testuser",
Password: "password",
TwoFactorCode: backupCodes[0],
}
_, err = tfaAuth.Login(context.Background(), req2)
if err == nil {
t.Error("Login() should fail when reusing backup code")
}
}
func TestTwoFactorAuthenticator_Disable2FA(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup and enable 2FA
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
totp := security.NewTOTPGenerator(nil)
code, _ := totp.GenerateCode(secret.Secret, time.Now())
tfaAuth.Enable2FA(1, secret.Secret, code)
// Disable 2FA
err := tfaAuth.Disable2FA(1)
if err != nil {
t.Errorf("Disable2FA() error = %v", err)
}
// Verify 2FA is disabled
status, _ := provider.Get2FAStatus(1)
if status {
t.Error("Disable2FA() did not disable 2FA")
}
// Login should not require 2FA
req := security.LoginRequest{
Username: "testuser",
Password: "password",
}
resp, err := tfaAuth.Login(context.Background(), req)
if err != nil {
t.Fatalf("Login() error = %v", err)
}
if resp.Requires2FA {
t.Error("Login() should not require 2FA after disabling")
}
}
func TestTwoFactorAuthenticator_RegenerateBackupCodes(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
// Setup and enable 2FA
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
totp := security.NewTOTPGenerator(nil)
code, _ := totp.GenerateCode(secret.Secret, time.Now())
tfaAuth.Enable2FA(1, secret.Secret, code)
// Get initial backup codes
codes1, err := tfaAuth.RegenerateBackupCodes(1, 10)
if err != nil {
t.Fatalf("RegenerateBackupCodes() error = %v", err)
}
if len(codes1) != 10 {
t.Errorf("RegenerateBackupCodes() returned %d codes, want 10", len(codes1))
}
// Regenerate backup codes
codes2, err := tfaAuth.RegenerateBackupCodes(1, 10)
if err != nil {
t.Fatalf("RegenerateBackupCodes() error = %v", err)
}
// Old codes should not work
req := security.LoginRequest{
Username: "testuser",
Password: "password",
TwoFactorCode: codes1[0],
}
_, err = tfaAuth.Login(context.Background(), req)
if err == nil {
t.Error("Login() should fail with old backup code after regeneration")
}
// New codes should work
req2 := security.LoginRequest{
Username: "testuser",
Password: "password",
TwoFactorCode: codes2[0],
}
resp, err := tfaAuth.Login(context.Background(), req2)
if err != nil {
t.Fatalf("Login() with new backup code error = %v", err)
}
if resp.Token == "" {
t.Error("Login() should return token with new backup code")
}
}

View File

@@ -0,0 +1,134 @@
package security
import (
"context"
"fmt"
"net/http"
)
// TwoFactorAuthenticator wraps an Authenticator and adds 2FA support
type TwoFactorAuthenticator struct {
baseAuth Authenticator
totp *TOTPGenerator
provider TwoFactorAuthProvider
}
// NewTwoFactorAuthenticator creates a new 2FA-enabled authenticator
func NewTwoFactorAuthenticator(baseAuth Authenticator, provider TwoFactorAuthProvider, config *TwoFactorConfig) *TwoFactorAuthenticator {
if config == nil {
config = DefaultTwoFactorConfig()
}
return &TwoFactorAuthenticator{
baseAuth: baseAuth,
totp: NewTOTPGenerator(config),
provider: provider,
}
}
// Login authenticates with 2FA support
func (t *TwoFactorAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// First, perform standard authentication
resp, err := t.baseAuth.Login(ctx, req)
if err != nil {
return nil, err
}
// Check if user has 2FA enabled
if resp.User == nil {
return resp, nil
}
has2FA, err := t.provider.Get2FAStatus(resp.User.UserID)
if err != nil {
return nil, fmt.Errorf("failed to check 2FA status: %w", err)
}
if !has2FA {
// User doesn't have 2FA enabled, return normal response
return resp, nil
}
// User has 2FA enabled
if req.TwoFactorCode == "" {
// No 2FA code provided, require it
resp.Requires2FA = true
resp.Token = "" // Don't return token until 2FA is verified
resp.RefreshToken = ""
return resp, nil
}
// Validate 2FA code
secret, err := t.provider.Get2FASecret(resp.User.UserID)
if err != nil {
return nil, fmt.Errorf("failed to get 2FA secret: %w", err)
}
// Try TOTP code first
valid, err := t.totp.ValidateCode(secret, req.TwoFactorCode)
if err != nil {
return nil, fmt.Errorf("failed to validate 2FA code: %w", err)
}
if !valid {
// Try backup code
valid, err = t.provider.ValidateBackupCode(resp.User.UserID, req.TwoFactorCode)
if err != nil {
return nil, fmt.Errorf("failed to validate backup code: %w", err)
}
}
if !valid {
return nil, fmt.Errorf("invalid 2FA code")
}
// 2FA verified, return full response with token
resp.User.TwoFactorEnabled = true
return resp, nil
}
// Logout delegates to base authenticator
func (t *TwoFactorAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
return t.baseAuth.Logout(ctx, req)
}
// Authenticate delegates to base authenticator
func (t *TwoFactorAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
return t.baseAuth.Authenticate(r)
}
// Setup2FA initiates 2FA setup for a user
func (t *TwoFactorAuthenticator) Setup2FA(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
return t.provider.Generate2FASecret(userID, issuer, accountName)
}
// Enable2FA completes 2FA setup after user confirms with a valid code
func (t *TwoFactorAuthenticator) Enable2FA(userID int, secret, verificationCode string) error {
// Verify the code before enabling
valid, err := t.totp.ValidateCode(secret, verificationCode)
if err != nil {
return fmt.Errorf("failed to validate code: %w", err)
}
if !valid {
return fmt.Errorf("invalid verification code")
}
// Generate backup codes
backupCodes, err := t.provider.GenerateBackupCodes(userID, 10)
if err != nil {
return fmt.Errorf("failed to generate backup codes: %w", err)
}
// Enable 2FA
return t.provider.Enable2FA(userID, secret, backupCodes)
}
// Disable2FA removes 2FA from a user account
func (t *TwoFactorAuthenticator) Disable2FA(userID int) error {
return t.provider.Disable2FA(userID)
}
// RegenerateBackupCodes creates new backup codes for a user
func (t *TwoFactorAuthenticator) RegenerateBackupCodes(userID int, count int) ([]string, error) {
return t.provider.GenerateBackupCodes(userID, count)
}

View File

@@ -0,0 +1,229 @@
package security
import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"fmt"
)
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
// See totp_database_schema.sql for procedure definitions
type DatabaseTwoFactorProvider struct {
db *sql.DB
totpGen *TOTPGenerator
}
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
if config == nil {
config = DefaultTwoFactorConfig()
}
return &DatabaseTwoFactorProvider{
db: db,
totpGen: NewTOTPGenerator(config),
}
}
// Generate2FASecret creates a new secret for a user
func (p *DatabaseTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
secret, err := p.totpGen.GenerateSecret()
if err != nil {
return nil, fmt.Errorf("failed to generate secret: %w", err)
}
qrURL := p.totpGen.GenerateQRCodeURL(secret, issuer, accountName)
backupCodes, err := GenerateBackupCodes(10)
if err != nil {
return nil, fmt.Errorf("failed to generate backup codes: %w", err)
}
return &TwoFactorSecret{
Secret: secret,
QRCodeURL: qrURL,
BackupCodes: backupCodes,
Issuer: issuer,
AccountName: accountName,
}, nil
}
// Validate2FACode verifies a TOTP code
func (p *DatabaseTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
return p.totpGen.ValidateCode(secret, code)
}
// Enable2FA activates 2FA for a user
func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
// Hash backup codes for secure storage
hashedCodes := make([]string, len(backupCodes))
for i, code := range backupCodes {
hash := sha256.Sum256([]byte(code))
hashedCodes[i] = hex.EncodeToString(hash[:])
}
// Convert to JSON array
codesJSON, err := json.Marshal(hashedCodes)
if err != nil {
return fmt.Errorf("failed to marshal backup codes: %w", err)
}
// Call stored procedure
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("enable 2FA query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return fmt.Errorf("%s", errorMsg.String)
}
return fmt.Errorf("failed to enable 2FA")
}
return nil
}
// Disable2FA deactivates 2FA for a user
func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("disable 2FA query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return fmt.Errorf("%s", errorMsg.String)
}
return fmt.Errorf("failed to disable 2FA")
}
return nil
}
// Get2FAStatus checks if user has 2FA enabled
func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
var success bool
var errorMsg sql.NullString
var enabled bool
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
if err != nil {
return false, fmt.Errorf("get 2FA status query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return false, fmt.Errorf("%s", errorMsg.String)
}
return false, fmt.Errorf("failed to get 2FA status")
}
return enabled, nil
}
// Get2FASecret retrieves the user's 2FA secret
func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var success bool
var errorMsg sql.NullString
var secret sql.NullString
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
if err != nil {
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return "", fmt.Errorf("%s", errorMsg.String)
}
return "", fmt.Errorf("failed to get 2FA secret")
}
if !secret.Valid {
return "", fmt.Errorf("2FA secret not found")
}
return secret.String, nil
}
// GenerateBackupCodes creates backup codes for 2FA
func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
codes, err := GenerateBackupCodes(count)
if err != nil {
return nil, fmt.Errorf("failed to generate backup codes: %w", err)
}
// Hash backup codes for storage
hashedCodes := make([]string, len(codes))
for i, code := range codes {
hash := sha256.Sum256([]byte(code))
hashedCodes[i] = hex.EncodeToString(hash[:])
}
// Convert to JSON array
codesJSON, err := json.Marshal(hashedCodes)
if err != nil {
return nil, fmt.Errorf("failed to marshal backup codes: %w", err)
}
// Call stored procedure
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil {
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("failed to regenerate backup codes")
}
// Return unhashed codes to user (only time they see them)
return codes, nil
}
// ValidateBackupCode checks and consumes a backup code
func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
// Hash the code
hash := sha256.Sum256([]byte(code))
codeHash := hex.EncodeToString(hash[:])
var success bool
var errorMsg sql.NullString
var valid bool
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
if err != nil {
return false, fmt.Errorf("validate backup code query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return false, fmt.Errorf("%s", errorMsg.String)
}
return false, nil
}
return valid, nil
}

View File

@@ -0,0 +1,218 @@
package security_test
import (
"database/sql"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// Note: These tests require a PostgreSQL database with the schema from totp_database_schema.sql
// Set TEST_DATABASE_URL environment variable or skip tests
func setupTestDB(t *testing.T) *sql.DB {
// Skip if no test database configured
t.Skip("Database tests require TEST_DATABASE_URL environment variable")
return nil
}
func TestDatabaseTwoFactorProvider_Enable2FA(t *testing.T) {
db := setupTestDB(t)
if db == nil {
return
}
defer db.Close()
provider := security.NewDatabaseTwoFactorProvider(db, nil)
// Generate secret and backup codes
secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
if err != nil {
t.Fatalf("Generate2FASecret() error = %v", err)
}
// Enable 2FA
err = provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
if err != nil {
t.Errorf("Enable2FA() error = %v", err)
}
// Verify enabled
enabled, err := provider.Get2FAStatus(1)
if err != nil {
t.Fatalf("Get2FAStatus() error = %v", err)
}
if !enabled {
t.Error("Get2FAStatus() = false, want true")
}
}
func TestDatabaseTwoFactorProvider_Disable2FA(t *testing.T) {
db := setupTestDB(t)
if db == nil {
return
}
defer db.Close()
provider := security.NewDatabaseTwoFactorProvider(db, nil)
// Enable first
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
// Disable
err := provider.Disable2FA(1)
if err != nil {
t.Errorf("Disable2FA() error = %v", err)
}
// Verify disabled
enabled, err := provider.Get2FAStatus(1)
if err != nil {
t.Fatalf("Get2FAStatus() error = %v", err)
}
if enabled {
t.Error("Get2FAStatus() = true, want false")
}
}
func TestDatabaseTwoFactorProvider_GetSecret(t *testing.T) {
db := setupTestDB(t)
if db == nil {
return
}
defer db.Close()
provider := security.NewDatabaseTwoFactorProvider(db, nil)
// Enable 2FA
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
// Retrieve secret
retrieved, err := provider.Get2FASecret(1)
if err != nil {
t.Errorf("Get2FASecret() error = %v", err)
}
if retrieved != secret.Secret {
t.Errorf("Get2FASecret() = %v, want %v", retrieved, secret.Secret)
}
}
func TestDatabaseTwoFactorProvider_ValidateBackupCode(t *testing.T) {
db := setupTestDB(t)
if db == nil {
return
}
defer db.Close()
provider := security.NewDatabaseTwoFactorProvider(db, nil)
// Enable 2FA
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
// Validate backup code
valid, err := provider.ValidateBackupCode(1, secret.BackupCodes[0])
if err != nil {
t.Errorf("ValidateBackupCode() error = %v", err)
}
if !valid {
t.Error("ValidateBackupCode() = false, want true")
}
// Try to use same code again
valid, err = provider.ValidateBackupCode(1, secret.BackupCodes[0])
if err == nil {
t.Error("ValidateBackupCode() should error on reuse")
}
// Try invalid code
valid, err = provider.ValidateBackupCode(1, "INVALID")
if err != nil {
t.Errorf("ValidateBackupCode() error = %v", err)
}
if valid {
t.Error("ValidateBackupCode() = true for invalid code")
}
}
func TestDatabaseTwoFactorProvider_RegenerateBackupCodes(t *testing.T) {
db := setupTestDB(t)
if db == nil {
return
}
defer db.Close()
provider := security.NewDatabaseTwoFactorProvider(db, nil)
// Enable 2FA
secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
// Regenerate codes
newCodes, err := provider.GenerateBackupCodes(1, 10)
if err != nil {
t.Errorf("GenerateBackupCodes() error = %v", err)
}
if len(newCodes) != 10 {
t.Errorf("GenerateBackupCodes() returned %d codes, want 10", len(newCodes))
}
// Old codes should not work
valid, _ := provider.ValidateBackupCode(1, secret.BackupCodes[0])
if valid {
t.Error("Old backup code should not work after regeneration")
}
// New codes should work
valid, err = provider.ValidateBackupCode(1, newCodes[0])
if err != nil {
t.Errorf("ValidateBackupCode() error = %v", err)
}
if !valid {
t.Error("ValidateBackupCode() = false for new code")
}
}
func TestDatabaseTwoFactorProvider_Generate2FASecret(t *testing.T) {
db := setupTestDB(t)
if db == nil {
return
}
defer db.Close()
provider := security.NewDatabaseTwoFactorProvider(db, nil)
secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
if err != nil {
t.Fatalf("Generate2FASecret() error = %v", err)
}
if secret.Secret == "" {
t.Error("Generate2FASecret() returned empty secret")
}
if secret.QRCodeURL == "" {
t.Error("Generate2FASecret() returned empty QR code URL")
}
if len(secret.BackupCodes) != 10 {
t.Errorf("Generate2FASecret() returned %d backup codes, want 10", len(secret.BackupCodes))
}
if secret.Issuer != "TestApp" {
t.Errorf("Generate2FASecret() Issuer = %v, want TestApp", secret.Issuer)
}
if secret.AccountName != "test@example.com" {
t.Errorf("Generate2FASecret() AccountName = %v, want test@example.com", secret.AccountName)
}
}

View File

@@ -0,0 +1,156 @@
package security
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sync"
)
// MemoryTwoFactorProvider is an in-memory implementation of TwoFactorAuthProvider for testing/examples
type MemoryTwoFactorProvider struct {
mu sync.RWMutex
secrets map[int]string // userID -> secret
backupCodes map[int]map[string]bool // userID -> backup codes (code -> used)
totpGen *TOTPGenerator
}
// NewMemoryTwoFactorProvider creates a new in-memory 2FA provider
func NewMemoryTwoFactorProvider(config *TwoFactorConfig) *MemoryTwoFactorProvider {
if config == nil {
config = DefaultTwoFactorConfig()
}
return &MemoryTwoFactorProvider{
secrets: make(map[int]string),
backupCodes: make(map[int]map[string]bool),
totpGen: NewTOTPGenerator(config),
}
}
// Generate2FASecret creates a new secret for a user
func (m *MemoryTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
secret, err := m.totpGen.GenerateSecret()
if err != nil {
return nil, err
}
qrURL := m.totpGen.GenerateQRCodeURL(secret, issuer, accountName)
backupCodes, err := GenerateBackupCodes(10)
if err != nil {
return nil, err
}
return &TwoFactorSecret{
Secret: secret,
QRCodeURL: qrURL,
BackupCodes: backupCodes,
Issuer: issuer,
AccountName: accountName,
}, nil
}
// Validate2FACode verifies a TOTP code
func (m *MemoryTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
return m.totpGen.ValidateCode(secret, code)
}
// Enable2FA activates 2FA for a user
func (m *MemoryTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.secrets[userID] = secret
// Store backup codes
if m.backupCodes[userID] == nil {
m.backupCodes[userID] = make(map[string]bool)
}
for _, code := range backupCodes {
// Hash backup codes for security
hash := sha256.Sum256([]byte(code))
m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
}
return nil
}
// Disable2FA deactivates 2FA for a user
func (m *MemoryTwoFactorProvider) Disable2FA(userID int) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.secrets, userID)
delete(m.backupCodes, userID)
return nil
}
// Get2FAStatus checks if user has 2FA enabled
func (m *MemoryTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
m.mu.RLock()
defer m.mu.RUnlock()
_, exists := m.secrets[userID]
return exists, nil
}
// Get2FASecret retrieves the user's 2FA secret
func (m *MemoryTwoFactorProvider) Get2FASecret(userID int) (string, error) {
m.mu.RLock()
defer m.mu.RUnlock()
secret, exists := m.secrets[userID]
if !exists {
return "", fmt.Errorf("user does not have 2FA enabled")
}
return secret, nil
}
// GenerateBackupCodes creates backup codes for 2FA
func (m *MemoryTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
codes, err := GenerateBackupCodes(count)
if err != nil {
return nil, err
}
m.mu.Lock()
defer m.mu.Unlock()
// Clear old backup codes and store new ones
m.backupCodes[userID] = make(map[string]bool)
for _, code := range codes {
hash := sha256.Sum256([]byte(code))
m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
}
return codes, nil
}
// ValidateBackupCode checks and consumes a backup code
func (m *MemoryTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
userCodes, exists := m.backupCodes[userID]
if !exists {
return false, nil
}
// Hash the provided code
hash := sha256.Sum256([]byte(code))
hashStr := hex.EncodeToString(hash[:])
used, exists := userCodes[hashStr]
if !exists {
return false, nil
}
if used {
return false, fmt.Errorf("backup code already used")
}
// Mark as used
userCodes[hashStr] = true
return true, nil
}

292
pkg/security/totp_test.go Normal file
View File

@@ -0,0 +1,292 @@
package security
import (
"strings"
"testing"
"time"
)
func TestTOTPGenerator_GenerateSecret(t *testing.T) {
totp := NewTOTPGenerator(nil)
secret, err := totp.GenerateSecret()
if err != nil {
t.Fatalf("GenerateSecret() error = %v", err)
}
if secret == "" {
t.Error("GenerateSecret() returned empty secret")
}
// Secret should be base32 encoded
if len(secret) < 16 {
t.Error("GenerateSecret() returned secret that is too short")
}
}
func TestTOTPGenerator_GenerateQRCodeURL(t *testing.T) {
totp := NewTOTPGenerator(nil)
secret := "JBSWY3DPEHPK3PXP"
issuer := "TestApp"
accountName := "user@example.com"
url := totp.GenerateQRCodeURL(secret, issuer, accountName)
if !strings.HasPrefix(url, "otpauth://totp/") {
t.Errorf("GenerateQRCodeURL() = %v, want otpauth://totp/ prefix", url)
}
if !strings.Contains(url, "secret="+secret) {
t.Errorf("GenerateQRCodeURL() missing secret parameter")
}
if !strings.Contains(url, "issuer="+issuer) {
t.Errorf("GenerateQRCodeURL() missing issuer parameter")
}
}
func TestTOTPGenerator_GenerateCode(t *testing.T) {
config := &TwoFactorConfig{
Algorithm: "SHA1",
Digits: 6,
Period: 30,
SkewWindow: 1,
}
totp := NewTOTPGenerator(config)
secret := "JBSWY3DPEHPK3PXP"
// Test with known time
timestamp := time.Unix(1234567890, 0)
code, err := totp.GenerateCode(secret, timestamp)
if err != nil {
t.Fatalf("GenerateCode() error = %v", err)
}
if len(code) != 6 {
t.Errorf("GenerateCode() returned code with length %d, want 6", len(code))
}
// Code should be numeric
for _, c := range code {
if c < '0' || c > '9' {
t.Errorf("GenerateCode() returned non-numeric code: %s", code)
break
}
}
}
func TestTOTPGenerator_ValidateCode(t *testing.T) {
config := &TwoFactorConfig{
Algorithm: "SHA1",
Digits: 6,
Period: 30,
SkewWindow: 1,
}
totp := NewTOTPGenerator(config)
secret := "JBSWY3DPEHPK3PXP"
// Generate a code for current time
now := time.Now()
code, err := totp.GenerateCode(secret, now)
if err != nil {
t.Fatalf("GenerateCode() error = %v", err)
}
// Validate the code
valid, err := totp.ValidateCode(secret, code)
if err != nil {
t.Fatalf("ValidateCode() error = %v", err)
}
if !valid {
t.Error("ValidateCode() = false, want true for current code")
}
// Test with invalid code
valid, err = totp.ValidateCode(secret, "000000")
if err != nil {
t.Fatalf("ValidateCode() error = %v", err)
}
// This might occasionally pass if 000000 is the correct code, but very unlikely
if valid && code != "000000" {
t.Error("ValidateCode() = true for invalid code")
}
}
func TestTOTPGenerator_ValidateCode_WithSkew(t *testing.T) {
config := &TwoFactorConfig{
Algorithm: "SHA1",
Digits: 6,
Period: 30,
SkewWindow: 2, // Allow 2 periods before/after
}
totp := NewTOTPGenerator(config)
secret := "JBSWY3DPEHPK3PXP"
// Generate code for 1 period ago
past := time.Now().Add(-30 * time.Second)
code, err := totp.GenerateCode(secret, past)
if err != nil {
t.Fatalf("GenerateCode() error = %v", err)
}
// Should still validate with skew window
valid, err := totp.ValidateCode(secret, code)
if err != nil {
t.Fatalf("ValidateCode() error = %v", err)
}
if !valid {
t.Error("ValidateCode() = false, want true for code within skew window")
}
}
func TestTOTPGenerator_DifferentAlgorithms(t *testing.T) {
algorithms := []string{"SHA1", "SHA256", "SHA512"}
secret := "JBSWY3DPEHPK3PXP"
for _, algo := range algorithms {
t.Run(algo, func(t *testing.T) {
config := &TwoFactorConfig{
Algorithm: algo,
Digits: 6,
Period: 30,
SkewWindow: 1,
}
totp := NewTOTPGenerator(config)
code, err := totp.GenerateCode(secret, time.Now())
if err != nil {
t.Fatalf("GenerateCode() with %s error = %v", algo, err)
}
valid, err := totp.ValidateCode(secret, code)
if err != nil {
t.Fatalf("ValidateCode() with %s error = %v", algo, err)
}
if !valid {
t.Errorf("ValidateCode() with %s = false, want true", algo)
}
})
}
}
func TestTOTPGenerator_8Digits(t *testing.T) {
config := &TwoFactorConfig{
Algorithm: "SHA1",
Digits: 8,
Period: 30,
SkewWindow: 1,
}
totp := NewTOTPGenerator(config)
secret := "JBSWY3DPEHPK3PXP"
code, err := totp.GenerateCode(secret, time.Now())
if err != nil {
t.Fatalf("GenerateCode() error = %v", err)
}
if len(code) != 8 {
t.Errorf("GenerateCode() returned code with length %d, want 8", len(code))
}
valid, err := totp.ValidateCode(secret, code)
if err != nil {
t.Fatalf("ValidateCode() error = %v", err)
}
if !valid {
t.Error("ValidateCode() = false, want true for 8-digit code")
}
}
func TestGenerateBackupCodes(t *testing.T) {
count := 10
codes, err := GenerateBackupCodes(count)
if err != nil {
t.Fatalf("GenerateBackupCodes() error = %v", err)
}
if len(codes) != count {
t.Errorf("GenerateBackupCodes() returned %d codes, want %d", len(codes), count)
}
// Check uniqueness
seen := make(map[string]bool)
for _, code := range codes {
if seen[code] {
t.Errorf("GenerateBackupCodes() generated duplicate code: %s", code)
}
seen[code] = true
// Check format (8 hex characters)
if len(code) != 8 {
t.Errorf("GenerateBackupCodes() code length = %d, want 8", len(code))
}
}
}
func TestDefaultTwoFactorConfig(t *testing.T) {
config := DefaultTwoFactorConfig()
if config.Algorithm != "SHA1" {
t.Errorf("DefaultTwoFactorConfig() Algorithm = %s, want SHA1", config.Algorithm)
}
if config.Digits != 6 {
t.Errorf("DefaultTwoFactorConfig() Digits = %d, want 6", config.Digits)
}
if config.Period != 30 {
t.Errorf("DefaultTwoFactorConfig() Period = %d, want 30", config.Period)
}
if config.SkewWindow != 1 {
t.Errorf("DefaultTwoFactorConfig() SkewWindow = %d, want 1", config.SkewWindow)
}
}
func TestTOTPGenerator_InvalidSecret(t *testing.T) {
totp := NewTOTPGenerator(nil)
// Test with invalid base32 secret
_, err := totp.GenerateCode("INVALID!!!", time.Now())
if err == nil {
t.Error("GenerateCode() with invalid secret should return error")
}
_, err = totp.ValidateCode("INVALID!!!", "123456")
if err == nil {
t.Error("ValidateCode() with invalid secret should return error")
}
}
// Benchmark tests
func BenchmarkTOTPGenerator_GenerateCode(b *testing.B) {
totp := NewTOTPGenerator(nil)
secret := "JBSWY3DPEHPK3PXP"
now := time.Now()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = totp.GenerateCode(secret, now)
}
}
func BenchmarkTOTPGenerator_ValidateCode(b *testing.B) {
totp := NewTOTPGenerator(nil)
secret := "JBSWY3DPEHPK3PXP"
code, _ := totp.GenerateCode(secret, time.Now())
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = totp.ValidateCode(secret, code)
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -209,10 +210,14 @@ func (h *Handler) handleRead(conn *Connection, msg *Message, hookCtx *HookContex
var metadata map[string]interface{} var metadata map[string]interface{}
var err error var err error
if hookCtx.ID != "" { // Check if FetchRowNumber is specified (treat as single record read)
// Read single record by ID isFetchRowNumber := hookCtx.Options != nil && hookCtx.Options.FetchRowNumber != nil && *hookCtx.Options.FetchRowNumber != ""
if hookCtx.ID != "" || isFetchRowNumber {
// Read single record by ID or FetchRowNumber
data, err = h.readByID(hookCtx) data, err = h.readByID(hookCtx)
metadata = map[string]interface{}{"total": 1} metadata = map[string]interface{}{"total": 1}
// The row number is already set on the record itself via setRowNumbersOnRecords
} else { } else {
// Read multiple records // Read multiple records
data, metadata, err = h.readMultiple(hookCtx) data, metadata, err = h.readMultiple(hookCtx)
@@ -509,10 +514,29 @@ func (h *Handler) notifySubscribers(schema, entity string, operation OperationTy
// CRUD operation implementations // CRUD operation implementations
func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) { func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
// Handle FetchRowNumber before building query
var fetchedRowNumber *int64
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
if hookCtx.Options != nil && hookCtx.Options.FetchRowNumber != nil && *hookCtx.Options.FetchRowNumber != "" {
fetchRowNumberPKValue := *hookCtx.Options.FetchRowNumber
logger.Debug("[WebSocketSpec] FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
rowNum, err := h.FetchRowNumber(hookCtx.Context, hookCtx.TableName, pkName, fetchRowNumberPKValue, hookCtx.Options, hookCtx.Model)
if err != nil {
return nil, fmt.Errorf("failed to fetch row number: %w", err)
}
fetchedRowNumber = &rowNum
logger.Debug("[WebSocketSpec] FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
// Override ID with FetchRowNumber value
hookCtx.ID = fetchRowNumberPKValue
}
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName) query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Add ID filter // Add ID filter
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID) query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
// Apply columns // Apply columns
@@ -532,6 +556,12 @@ func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
return nil, fmt.Errorf("failed to read record: %w", err) return nil, fmt.Errorf("failed to read record: %w", err)
} }
// Set the fetched row number on the record if FetchRowNumber was used
if fetchedRowNumber != nil {
logger.Debug("[WebSocketSpec] FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
h.setRowNumbersOnRecords(hookCtx.ModelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
}
return hookCtx.ModelPtr, nil return hookCtx.ModelPtr, nil
} }
@@ -540,10 +570,8 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
// Apply options (simplified implementation) // Apply options (simplified implementation)
if hookCtx.Options != nil { if hookCtx.Options != nil {
// Apply filters // Apply filters with OR grouping support
for _, filter := range hookCtx.Options.Filters { query = h.applyFilters(query, hookCtx.Options.Filters)
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
}
// Apply sorting // Apply sorting
for _, sort := range hookCtx.Options.Sort { for _, sort := range hookCtx.Options.Sort {
@@ -578,6 +606,13 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
return nil, nil, fmt.Errorf("failed to read records: %w", err) return nil, nil, fmt.Errorf("failed to read records: %w", err)
} }
// Set row numbers on records if RowNumber field exists
offset := 0
if hookCtx.Options != nil && hookCtx.Options.Offset != nil {
offset = *hookCtx.Options.Offset
}
h.setRowNumbersOnRecords(hookCtx.ModelPtr, offset)
// Get count // Get count
metadata = make(map[string]interface{}) metadata = make(map[string]interface{})
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName) countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
@@ -656,11 +691,14 @@ func (h *Handler) delete(hookCtx *HookContext) error {
// Helper methods // Helper methods
func (h *Handler) getTableName(schema, entity string, model interface{}) string { func (h *Handler) getTableName(schema, entity string, model interface{}) string {
// Use entity as table name
tableName := entity tableName := entity
if schema != "" { if schema != "" {
tableName = schema + "." + tableName if h.db.DriverName() == "sqlite" {
tableName = schema + "_" + tableName
} else {
tableName = schema + "." + tableName
}
} }
return tableName return tableName
} }
@@ -680,6 +718,133 @@ func (h *Handler) getMetadata(schema, entity string, model interface{}) map[stri
} }
// getOperatorSQL converts filter operator to SQL operator // getOperatorSQL converts filter operator to SQL operator
// applyFilters applies all filters with proper grouping for OR logic
// Groups consecutive OR filters together to ensure proper query precedence
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
// Check if this starts an OR group (next filter has OR logic)
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
// Collect all consecutive filters that are OR'd together
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
// Apply the OR group as a single grouped WHERE clause
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
// Single filter with AND logic (or first filter)
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
// applyFilterGroup applies a group of filters that should be OR'd together
// Always wraps them in parentheses and applies as a single WHERE clause
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
// Build all conditions and collect args
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
// Single filter - no need for grouping
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
// Multiple conditions - group with parentheses and OR
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
return query.Where(groupedCondition, args...)
}
// buildFilterCondition builds a filter condition and returns it with args
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
var condition string
var args []interface{}
operatorSQL := h.getOperatorSQL(filter.Operator)
condition = fmt.Sprintf("%s %s ?", filter.Column, operatorSQL)
args = []interface{}{filter.Value}
return condition, args
}
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
// The row number is calculated as offset + index + 1 (1-based)
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
// Get the reflect value of the records
recordsValue := reflect.ValueOf(records)
if recordsValue.Kind() == reflect.Ptr {
recordsValue = recordsValue.Elem()
}
// Ensure it's a slice
if recordsValue.Kind() != reflect.Slice {
logger.Debug("[WebSocketSpec] setRowNumbersOnRecords: records is not a slice, skipping")
return
}
// Iterate through each record
for i := 0; i < recordsValue.Len(); i++ {
record := recordsValue.Index(i)
// Dereference if it's a pointer
if record.Kind() == reflect.Ptr {
if record.IsNil() {
continue
}
record = record.Elem()
}
// Ensure it's a struct
if record.Kind() != reflect.Struct {
continue
}
// Try to find and set the RowNumber field
rowNumberField := record.FieldByName("RowNumber")
if rowNumberField.IsValid() && rowNumberField.CanSet() {
// Check if the field is of type int64
if rowNumberField.Kind() == reflect.Int64 {
rowNum := int64(offset + i + 1)
rowNumberField.SetInt(rowNum)
logger.Debug("[WebSocketSpec] Set RowNumber=%d for record index %d", rowNum, i)
}
}
}
}
func (h *Handler) getOperatorSQL(operator string) string { func (h *Handler) getOperatorSQL(operator string) string {
switch operator { switch operator {
case "eq": case "eq":
@@ -705,6 +870,92 @@ func (h *Handler) getOperatorSQL(operator string) string {
} }
} }
// FetchRowNumber calculates the row number of a specific record based on sorting and filtering
// Returns the 1-based row number of the record with the given primary key value
func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName string, pkValue string, options *common.RequestOptions, model interface{}) (int64, error) {
defer func() {
if r := recover(); r != nil {
logger.Error("[WebSocketSpec] Panic during FetchRowNumber: %v", r)
}
}()
// Build the sort order SQL
sortSQL := ""
if options != nil && len(options.Sort) > 0 {
sortParts := make([]string, 0, len(options.Sort))
for _, sort := range options.Sort {
if sort.Column == "" {
continue
}
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
}
sortSQL = strings.Join(sortParts, ", ")
} else {
// Default sort by primary key
sortSQL = fmt.Sprintf("%s ASC", pkName)
}
// Build WHERE clause from filters
whereSQL := ""
var whereArgs []interface{}
if options != nil && len(options.Filters) > 0 {
var conditions []string
for _, filter := range options.Filters {
operatorSQL := h.getOperatorSQL(filter.Operator)
conditions = append(conditions, fmt.Sprintf("%s.%s %s ?", tableName, filter.Column, operatorSQL))
whereArgs = append(whereArgs, filter.Value)
}
if len(conditions) > 0 {
whereSQL = "WHERE " + strings.Join(conditions, " AND ")
}
}
// Build the final query with parameterized PK value
queryStr := fmt.Sprintf(`
SELECT search.rn
FROM (
SELECT %[1]s.%[2]s,
ROW_NUMBER() OVER(ORDER BY %[3]s) AS rn
FROM %[1]s
%[4]s
) search
WHERE search.%[2]s = ?
`,
tableName, // [1] - table name
pkName, // [2] - primary key column name
sortSQL, // [3] - sort order SQL
whereSQL, // [4] - WHERE clause
)
logger.Debug("[WebSocketSpec] FetchRowNumber query: %s, pkValue: %s", queryStr, pkValue)
// Append PK value to whereArgs
whereArgs = append(whereArgs, pkValue)
// Execute the raw query with parameterized PK value
var result []struct {
RN int64 `bun:"rn"`
}
err := h.db.Query(ctx, &result, queryStr, whereArgs...)
if err != nil {
return 0, fmt.Errorf("failed to fetch row number: %w", err)
}
if len(result) == 0 {
whereInfo := "none"
if whereSQL != "" {
whereInfo = whereSQL
}
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
}
return result[0].RN, nil
}
// Shutdown gracefully shuts down the handler // Shutdown gracefully shuts down the handler
func (h *Handler) Shutdown() { func (h *Handler) Shutdown() {
h.connManager.Shutdown() h.connManager.Shutdown()

View File

@@ -82,6 +82,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
return args.Get(0) return args.Get(0)
} }
func (m *MockDatabase) DriverName() string {
return "postgres"
}
// MockSelectQuery is a mock implementation of common.SelectQuery // MockSelectQuery is a mock implementation of common.SelectQuery
type MockSelectQuery struct { type MockSelectQuery struct {
mock.Mock mock.Mock

View File

@@ -1,5 +1,50 @@
# Python Implementation of the ResolveSpec API # ResolveSpec Python Client - TODO
# Server ## Client Implementation & Testing
# Client ### 1. ResolveSpec Client API
- [ ] Core API implementation (read, create, update, delete, get_metadata)
- [ ] Unit tests for API functions
- [ ] Integration tests with server
- [ ] Error handling and edge cases
### 2. HeaderSpec Client API
- [ ] Client API implementation
- [ ] Unit tests
- [ ] Integration tests with server
### 3. FunctionSpec Client API
- [ ] Client API implementation
- [ ] Unit tests
- [ ] Integration tests with server
### 4. WebSocketSpec Client API
- [ ] WebSocketClient class implementation (read, create, update, delete, meta, subscribe, unsubscribe)
- [ ] Unit tests for WebSocketClient
- [ ] Connection handling tests
- [ ] Subscription tests
- [ ] Integration tests with server
### 5. Testing Infrastructure
- [ ] Set up test framework (pytest)
- [ ] Configure test coverage reporting (pytest-cov)
- [ ] Add test utilities and fixtures
- [ ] Create test documentation
- [ ] Package and publish to PyPI
## Documentation
- [ ] API reference documentation
- [ ] Usage examples for each client API
- [ ] Installation guide
- [ ] Contributing guidelines
- [ ] README with quick start
---
**Last Updated:** 2026-02-07

114
todo.md
View File

@@ -2,36 +2,98 @@
This document tracks incomplete features and improvements for the ResolveSpec project. This document tracks incomplete features and improvements for the ResolveSpec project.
## In Progress
### Database Layer
- [x] SQLite schema translation (schema.table → schema_table)
- [x] Driver name normalization across adapters
- [x] Database Connection Manager (dbmanager) package
### Documentation ### Documentation
- Ensure all new features are documented in README.md
- Update examples to showcase new functionality
- Add migration notes if any breaking changes are introduced
- [x] Add dbmanager to README
- [x] Add WebSocketSpec to top-level intro
- [x] Add MQTTSpec to top-level intro
- [x] Remove migration sections from README
- [ ] Complete API reference documentation
- [ ] Add examples for all supported databases
### 8. ## Planned Features
1. **Test Coverage**: Increase from 20% to 70%+ ### ResolveSpec JS Client Implementation & Testing
- Add integration tests for CRUD operations
- Add unit tests for security providers 1. **ResolveSpec Client API (resolvespec-js)**
- Add concurrency tests for model registry - [x] Core API implementation (read, create, update, delete, getMetadata)
- [ ] Unit tests for API functions
- [ ] Integration tests with server
- [ ] Error handling and edge cases
2. **HeaderSpec Client API (resolvespec-js)**
- [ ] Client API implementation
- [ ] Unit tests
- [ ] Integration tests with server
3. **FunctionSpec Client API (resolvespec-js)**
- [ ] Client API implementation
- [ ] Unit tests
- [ ] Integration tests with server
4. **WebSocketSpec Client API (resolvespec-js)**
- [x] WebSocketClient class implementation (read, create, update, delete, meta, subscribe, unsubscribe)
- [ ] Unit tests for WebSocketClient
- [ ] Connection handling tests
- [ ] Subscription tests
- [ ] Integration tests with server
5. **resolvespec-js Testing Infrastructure**
- [ ] Set up test framework (Jest or Vitest)
- [ ] Configure test coverage reporting
- [ ] Add test utilities and mocks
- [ ] Create test documentation
### ResolveSpec Python Client Implementation & Testing
See [`resolvespec-python/todo.md`](./resolvespec-python/todo.md) for detailed Python client implementation tasks.
### Core Functionality
1. **Enhanced Preload Filtering**
- [ ] Column selection for nested preloads
- [ ] Advanced filtering conditions for relations
- [ ] Performance optimization for deep nesting
2. **Advanced Query Features**
- [ ] Custom SQL join support
- [ ] Computed column improvements
- [ ] Recursive query support
3. **Testing & Quality**
- [ ] Increase test coverage to 70%+
- [ ] Add integration tests for all ORMs
- [ ] Add concurrency tests for thread safety
- [ ] Performance benchmarks
### Infrastructure
- [ ] Improved error handling and reporting
- [ ] Enhanced logging capabilities
- [ ] Additional monitoring metrics
- [ ] Performance profiling tools
## Documentation Tasks
- [ ] Complete API reference
- [ ] Add troubleshooting guides
- [ ] Create architecture diagrams
- [ ] Expand database adapter documentation
## Known Issues
- [ ] Long preload alias names may exceed PostgreSQL identifier limit
- [ ] Some edge cases in computed column handling
--- ---
## Priority Ranking **Last Updated:** 2026-02-07
**Updated:** Added resolvespec-js client testing and implementation tasks
1. **High Priority**
- Column Selection and Filtering for Preloads (#1)
- Proper Condition Handling for Bun Preloads (#4)
2. **Medium Priority**
- Custom SQL Join Support (#3)
- Recursive JSON Cleaning (#2)
3. **Low Priority**
- Modernize Go Type Declarations (#5)
---
**Last Updated:** 2025-12-09