Compare commits

...

69 Commits

Author SHA1 Message Date
Hein
79a3912f93 fix(db): improve database connection handling and reconnection logic
* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
2026-04-09 09:19:28 +02:00
Hein
a9bf08f58b feat(security): implement keystore for user authentication keys
* Add ConfigKeyStore for in-memory key management
* Introduce DatabaseKeyStore for PostgreSQL-backed key storage
* Create KeyStoreAuthenticator for API key validation
* Define SQL procedures for key management in PostgreSQL
* Document keystore functionality and usage in KEYSTORE.md
2026-04-07 17:09:17 +02:00
Hein
405a04a192 feat(security): integrate security hooks for access control
Some checks failed
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m6s
Tests / Unit Tests (push) Successful in -30m22s
Tests / Integration Tests (push) Failing after -30m41s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m3s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m36s
Build , Vet Test, and Lint / Build (push) Successful in -29m58s
* Add security hooks for per-entity operation rules and row/column-level security.
* Implement annotation tool for storing and retrieving freeform annotations.
* Enhance handler to support model registration with access rules.
2026-04-07 15:53:12 +02:00
Hein
c1b16d363a feat(db): add DB method to sqlConnection and mongoConnection
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m22s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m59s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m11s
Build , Vet Test, and Lint / Build (push) Successful in -30m12s
Tests / Unit Tests (push) Successful in -30m49s
Tests / Integration Tests (push) Failing after -30m59s
2026-04-01 15:34:09 +02:00
Hein
568df8c6d6 feat(security): add configurable SQL procedure names
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m9s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -24m29s
Build , Vet Test, and Lint / Build (push) Successful in -30m5s
Build , Vet Test, and Lint / Lint Code (push) Failing after -28m58s
Tests / Integration Tests (push) Failing after -30m26s
Tests / Unit Tests (push) Successful in -28m7s
* Introduce SQLNames struct to define stored procedure names.
* Update DatabaseAuthenticator, JWTAuthenticator, and other providers to use SQLNames for procedure calls.
* Remove hardcoded procedure names for better flexibility and customization.
* Implement validation for SQL names to ensure they are valid identifiers.
* Add tests for SQLNames functionality and merging behavior.
2026-03-31 14:25:59 +02:00
Hein
aa362c77da fix(cursor): trim parentheses from sort column names 2026-03-27 15:07:10 +02:00
Hein
1641eaf278 feat(resolvemcp): enhance handler with configuration support
* Introduce Config struct for BaseURL and BasePath settings
* Update handler creation functions to accept configuration
* Modify SSEServer to use dynamic base URL detection
* Adjust route setup functions to utilize BasePath from config
2026-03-27 13:56:03 +02:00
Hein
200a03c225 feat(resolvemcp): add SSE server and bunrouter setup functions
* Introduce SSEServer method for creating an SSE server bound to the handler.
* Add SetupBunRouterRoutes function to mount MCP HTTP/SSE endpoints on bunrouter.
* Update README with usage examples for new features.
2026-03-27 13:28:03 +02:00
Hein
7ef9cf39d3 style(tools): simplify string formatting in descriptions 2026-03-27 13:10:50 +02:00
Hein
7f6410f665 feat(resolvemcp): add support for join-column sorting in cursor pagination
* Enhance getCursorFilter to accept join clauses for sorting
* Update resolveColumn to handle joined columns
* Modify tests to validate new join functionality
2026-03-27 13:10:42 +02:00
Hein
835bbb0727 style(hooks): reorder fields in HookContext for consistency 2026-03-27 12:57:30 +02:00
Hein
047a1cc187 feat(resolvemcp): add hook system for model operations
* Implement hooks for CRUD operations: before/after handle, read, create, update, delete.
* Introduce HookContext and HookRegistry for managing hooks.
* Allow registration and execution of multiple hooks per operation.

feat(resolvemcp): implement MCP tools for CRUD operations
* Register tools for reading, creating, updating, and deleting records.
* Define tool arguments and handle requests with appropriate responses.
* Support for resource registration with metadata.

fix(restheadspec): enhance cursor handling for joins
* Improve cursor filter generation to support lateral joins.
* Update join alias extraction to handle lateral joins correctly.
* Ensure cursor filters do not contain empty comparisons.

test(restheadspec): add tests for cursor filters and join alias extraction
* Create tests for lateral join scenarios in cursor filter generation.
* Validate join alias extraction for various join types, including lateral joins.
2026-03-27 12:57:08 +02:00
Hein
7a498edab7 fix(headers): enhance relation name resolution logic
* Allow resolution for both regular headers and X-Files.
* Introduce join-key-aware resolution for disambiguation.
* Add new function to handle multiple fields pointing to the same type.
2026-03-25 12:09:03 +02:00
Hein
f10bb0827e fix(sql_helpers): ensure case-insensitive matching for allowed prefixes 2026-03-25 10:57:42 +02:00
Hein
22a4ab345a feat(security): add session cookie management functions
* Introduce SessionCookieOptions for configurable session cookies
* Implement SetSessionCookie, GetSessionCookie, and ClearSessionCookie functions
* Enhance cookie handling in DatabaseAuthenticator
2026-03-24 17:11:53 +02:00
Hein
e289c2ed8f fix(handler): restore JoinAliases for proper WHERE sanitization 2026-03-24 12:00:02 +02:00
Hein
0d50bcfee6 fix(provider): enhance file opening logic with alternate path. Handling broken cases to be compatible with Bitech clients
* Implemented alternate path handling for file retrieval
* Improved error messaging for file not found scenarios
2026-03-24 09:02:17 +02:00
4df626ea71 chore(license): update project notice and clarify licensing terms 2026-03-23 20:32:09 +02:00
Hein
7dd630dec2 fix(handler): set default sort to primary key if none provided
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m15s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m11s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m52s
Build , Vet Test, and Lint / Build (push) Successful in -30m44s
Tests / Integration Tests (push) Failing after -31m5s
Tests / Unit Tests (push) Successful in -29m6s
2026-03-11 14:37:04 +02:00
Hein
613bf22cbd fix(cursor): use full schema-qualified table name in filters 2026-03-11 14:25:44 +02:00
d1ae4fe64e refactor(handler): unify filter operator handling for consistency
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m26s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m58s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m48s
Build , Vet Test, and Lint / Build (push) Successful in -30m4s
Tests / Integration Tests (push) Failing after -30m39s
Tests / Unit Tests (push) Successful in -30m29s
2026-03-01 13:21:38 +02:00
254102bfac refactor(auth): simplify handler type assertions for middleware
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m31s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m19s
Build , Vet Test, and Lint / Build (push) Successful in -29m42s
Tests / Integration Tests (push) Failing after -30m35s
Tests / Unit Tests (push) Successful in -30m17s
2026-03-01 12:08:36 +02:00
6c27419dbc refactor(auth): enhance request handling with middleware-enriched context 2026-03-01 12:06:43 +02:00
377336caf4 feat(sql): implement IN condition handling with parameterized queries 2026-03-01 09:52:32 +02:00
79720d5421 feat(security): add BeforeHandle hook for auth checks after model resolution
- Implement BeforeHandle hook to enforce authentication based on model rules.
- Integrate with existing security mechanisms to allow or deny access.
- Update documentation to reflect new hook and its usage.
2026-03-01 09:15:30 +02:00
e7ab0a20d6 Merge branch 'main' of github.com:bitechdev/ResolveSpec
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m15s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m41s
Build , Vet Test, and Lint / Build (push) Successful in -29m50s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m12s
Tests / Integration Tests (push) Failing after -30m43s
Tests / Unit Tests (push) Successful in -30m26s
2026-02-28 22:53:26 +02:00
e4087104a9 feat(security): add model rules enforcement for update and delete operations
- Implement BeforeUpdate and BeforeDelete hooks to enforce CanUpdate and CanDelete rules.
- Introduce new security context for websocketspec to manage security hooks.
- Enhance error handling in delete operations to provide clearer feedback.
2026-02-28 22:53:21 +02:00
Hein
17e580a9d3 fix(pgsql): optimize SQL string building for LIMIT and OFFSET
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -24m41s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -24m32s
Build , Vet Test, and Lint / Build (push) Successful in -29m28s
Build , Vet Test, and Lint / Lint Code (push) Successful in -28m35s
Tests / Integration Tests (push) Failing after -30m28s
Tests / Unit Tests (push) Successful in -29m28s
2026-02-25 09:38:37 +02:00
Hein
337a007d57 feat(openapi): add OpenAPI handling for OPTIONS requests 2026-02-25 09:36:50 +02:00
Hein
e923b0a2a3 feat(routes): add authentication middleware support for routes
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m48s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m25s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m29s
Build , Vet Test, and Lint / Build (push) Successful in -25m41s
Tests / Integration Tests (push) Failing after -26m14s
Tests / Unit Tests (push) Successful in -26m3s
2026-02-16 10:38:06 +02:00
ea4a4371ba feat(changesets): add README and config files for changeset management
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m56s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m29s
Build , Vet Test, and Lint / Build (push) Successful in -25m27s
Build , Vet Test, and Lint / Lint Code (push) Successful in -24m54s
Tests / Integration Tests (push) Failing after -26m14s
Tests / Unit Tests (push) Successful in -25m59s
2026-02-15 19:51:57 +02:00
b3694e50fe refactor(websocket): rename classes and functions for clarity 2026-02-15 19:51:05 +02:00
b76dae5991 refactor(headerspec): improve code formatting and consistency 2026-02-15 19:49:29 +02:00
dc85008d7f feat(api): add ResolveSpec and WebSocket client implementations
- Introduced ResolveSpecClient for REST API interactions.
- Added WebSocketClient for real-time communication.
- Created types and utility functions for both clients.
- Removed deprecated types and example files.
- Configured TypeScript and Vite for building the library.
2026-02-15 15:17:39 +02:00
Hein
fd77385dd6 feat(handler): enhance FetchRowNumber support in handlers
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m39s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m42s
Build , Vet Test, and Lint / Build (push) Successful in -25m55s
Tests / Integration Tests (push) Failing after -26m29s
Tests / Unit Tests (push) Successful in -26m17s
* Implement FetchRowNumber handling in multiple handlers
* Improve error logging for missing rows with filters
* Set row numbers correctly based on FetchRowNumber
2026-02-10 17:42:27 +02:00
Hein
b322ef76a2 Merge branch 'main' of https://github.com/bitechdev/ResolveSpec 2026-02-10 16:55:58 +02:00
Hein
a6c7edb0e4 feat(resolvespec): add OR logic support in filters
* Introduce `logic_operator` field to combine filters with OR logic.
* Implement grouping for consecutive OR filters to ensure proper SQL precedence.
* Add support for custom SQL operators in filter conditions.
* Enhance `fetch_row_number` functionality to return specific record with its position.
* Update tests to cover new filter logic and grouping behavior.

Features Implemented:

  1. OR Logic Filter Support (SearchOr)
    - Added to resolvespec, restheadspec, and websocketspec
    - Consecutive OR filters are automatically grouped with parentheses
    - Prevents SQL logic errors: (A OR B OR C) AND D instead of A OR B OR C AND D
  2. CustomOperators
    - Allows arbitrary SQL conditions in resolvespec
    - Properly integrated with filter logic
  3. FetchRowNumber
    - Uses SQL window functions: ROW_NUMBER() OVER (ORDER BY ...)
    - Returns only the specific record (not all records)
    - Available in resolvespec and restheadspec
    - Perfect for "What's my rank?" queries
  4. RowNumber Field Auto-Population
    - Now available in all three packages: resolvespec, restheadspec, and websocketspec
    - Uses simple offset-based math: offset + index + 1
    - Automatically populates RowNumber int64 field if it exists on models
    - Perfect for displaying paginated lists with sequential numbering
2026-02-10 16:55:55 +02:00
71eeb8315e chore: 📝 Refactored documentation and added better sqlite support.
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m14s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m40s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m41s
Build , Vet Test, and Lint / Build (push) Successful in -25m55s
Tests / Unit Tests (push) Successful in -26m19s
Tests / Integration Tests (push) Failing after -26m35s
restructure server configuration for multiple instances  - Change server configuration to support multiple instances. - Introduce new fields for tracing and error tracking. - Update example configuration to reflect new structure. - Remove deprecated OpenAPI specification file. - Enhance database adapter to handle SQLite schema translation.
2026-02-07 10:58:34 +02:00
Hein
4bf3d0224e feat(database): normalize driver names across adapters
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m46s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -23m31s
Build , Vet Test, and Lint / Lint Code (push) Successful in -24m55s
Tests / Unit Tests (push) Successful in -26m19s
Build , Vet Test, and Lint / Build (push) Successful in -26m2s
Tests / Integration Tests (push) Failing after -26m42s
* Added DriverName method to BunAdapter, GormAdapter, and PgSQLAdapter for consistent driver name handling.
* Updated transaction adapters to include driver name.
* Enhanced mock database implementations for testing with DriverName method.
* Adjusted getTableName functions to accommodate driver-specific naming conventions.
2026-02-05 13:28:53 +02:00
Hein
50d0caabc2 refactor(database): ♻️ simplify relation type handling
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m33s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m6s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m1s
Build , Vet Test, and Lint / Build (push) Successful in -26m14s
Tests / Integration Tests (push) Failing after -26m47s
Tests / Unit Tests (push) Successful in -26m35s
* Consolidate related type determination logic
* Improve clarity in slice creation for results
* Enhance foreign key field name handling
2026-02-03 08:40:11 +02:00
Hein
5269ae4de2 style(database): 🎨 format comments for clarity
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m26s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m56s
Build , Vet Test, and Lint / Lint Code (push) Failing after -26m11s
Build , Vet Test, and Lint / Build (push) Successful in -26m18s
Tests / Unit Tests (push) Successful in -26m35s
Tests / Integration Tests (push) Failing after -26m49s
2026-02-02 18:40:37 +02:00
Hein
646620ed83 feat(database): add custom preload handling for relations
* Introduced custom preloads to manage relations that may exceed PostgreSQL's identifier limit.
* Implemented checks for alias length to prevent truncation warnings.
* Enhanced the loading mechanism for nested relations using separate queries.
2026-02-02 18:39:48 +02:00
7600a6d1fb fix(security): 🐛 handle errors in OAuth2 examples and passkey methods
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -22m52s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -22m42s
Build , Vet Test, and Lint / Build (push) Successful in -26m19s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m40s
Tests / Unit Tests (push) Successful in -26m33s
Tests / Integration Tests (push) Failing after -26m55s
* Add error handling for JSON encoding and HTTP server calls.
* Update passkey examples to improve readability and maintainability.
* Ensure consistent use of error handling across all examples.
2026-01-31 22:58:52 +02:00
2e7b3e7abd feat(security): add database-backed passkey provider
- Implement DatabasePasskeyProvider for WebAuthn/FIDO2 authentication.
- Add methods for registration, authentication, and credential management.
- Create unit tests for passkey provider functionalities.
- Enhance DatabaseAuthenticator to support passkey authentication.
2026-01-31 22:53:33 +02:00
fdf9e118c5 feat(security): Add two-factor authentication support
* Implement TwoFactorAuthenticator for 2FA login.
* Create DatabaseTwoFactorProvider for PostgreSQL integration.
* Add MemoryTwoFactorProvider for in-memory testing.
* Develop TOTPGenerator for generating and validating codes.
* Include tests for all new functionalities.
* Ensure backup codes are securely hashed and validated.
2026-01-31 22:45:28 +02:00
e11e6a8bf7 feat(security): Add OAuth2 authentication examples and methods
* Introduce OAuth2 authentication examples for Google, GitHub, and custom providers.
* Implement OAuth2 methods for handling authentication, token refresh, and logout.
* Create a flexible structure for supporting multiple OAuth2 providers.
* Enhance DatabaseAuthenticator to manage OAuth2 sessions and user creation.
* Add database schema setup for OAuth2 user and session management.
2026-01-31 22:35:40 +02:00
261f98eb29 Merge branch 'main' of github.com:bitechdev/ResolveSpec 2026-01-31 21:50:37 +02:00
0b8d11361c feat(auth): add user registration functionality
* Implemented resolvespec_register stored procedure for user registration.
* Added RegisterRequest struct for registration data.
* Created Register method in DatabaseAuthenticator.
* Updated tests for successful registration and error handling for duplicate usernames and emails.
2026-01-31 21:50:32 +02:00
Hein
e70bab92d7 feat(tests): 🎉 More test for preload fixes.
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m14s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m10s
Build , Vet Test, and Lint / Build (push) Successful in -26m22s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m12s
Tests / Integration Tests (push) Failing after -26m58s
Tests / Unit Tests (push) Successful in -26m47s
* Implement tests for SanitizeWhereClause and AddTablePrefixToColumns.
* Ensure correct handling of table prefixes in WHERE clauses.
* Validate that unqualified columns are prefixed correctly when necessary.
* Add tests for XFiles processing to verify table name handling.
* Introduce tests for recursive preloads and their related keys.
2026-01-30 10:09:59 +02:00
Hein
fc8f44e3e8 feat(preload): Enhance recursive preload functionality
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m38s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m13s
Build , Vet Test, and Lint / Build (push) Successful in -26m17s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m45s
Tests / Integration Tests (push) Failing after -27m1s
Tests / Unit Tests (push) Successful in -26m48s
* Increase maximum recursion depth from 4 to 8.
* Generate FK-based relation names for child preloads using RelatedKey.
* Clear WHERE clause for recursive preloads to prevent filtering issues.
* Extend child relations to recursive levels for better data retrieval.
* Add integration tests to validate recursive preload behavior and structure.
2026-01-29 15:31:50 +02:00
Hein
584bb9813d .. 2026-01-29 09:37:22 +02:00
Hein
17239d1611 feat(preload): Add support for custom SQL joins
* Introduce SqlJoins and JoinAliases in PreloadOption.
* Preserve SqlJoins and JoinAliases during filter processing.
* Implement logic to apply custom SQL joins in handler.
* Add tests for SqlJoins handling and join alias extraction.
2026-01-29 09:37:09 +02:00
Hein
defe27549b feat(sql): Improve base64 handling in SqlNull type
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m52s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m23s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m28s
Build , Vet Test, and Lint / Build (push) Successful in -26m36s
Tests / Unit Tests (push) Successful in -26m57s
Tests / Integration Tests (push) Failing after -27m7s
* Refactor base64 encoding and decoding checks for []byte types.
* Simplify type assertions using if statements instead of switch cases.
2026-01-27 17:35:13 +02:00
Hein
f7725340a6 feat(sql): Add base64 encoding/decoding for SqlByteArray
* Implement base64 handling in SqlNull for []byte types.
* Add tests for SqlString and SqlByteArray with base64 encoding.
* Ensure proper JSON marshaling and unmarshaling for new types.
2026-01-27 17:33:50 +02:00
Hein
07016d1b73 feat(config): Update timeout settings for connections
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -22m34s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -22m28s
Build , Vet Test, and Lint / Build (push) Successful in -26m3s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m22s
Tests / Integration Tests (push) Failing after -26m44s
Tests / Unit Tests (push) Successful in -26m2s
* Set default query timeout to 2 minutes and enforce minimum.
* Add statement_timeout for PostgreSQL DSN.
* Implement busy timeout for SQLite with a minimum of 2 minutes.
* Enforce minimum connection timeouts of 10 minutes for server instance.
2026-01-26 11:06:16 +02:00
Hein
09f2256899 feat(sql): Enhance SQL clause handling with parentheses
* Add EnsureOuterParentheses function to wrap clauses in parentheses.
* Implement logic to preserve outer parentheses for OR conditions.
* Update SanitizeWhereClause to utilize new function for better query safety.
* Introduce tests for EnsureOuterParentheses and containsTopLevelOR functions.
* Refactor filter application in handler to group OR filters correctly.
2026-01-26 09:14:17 +02:00
Hein
c12c045db1 feat(validation): Clear JoinAliases in FilterRequestOptions
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -27m20s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m49s
Build , Vet Test, and Lint / Build (push) Successful in -26m53s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m22s
Tests / Integration Tests (push) Failing after -27m37s
Tests / Unit Tests (push) Successful in -27m25s
* Implemented logic to clear JoinAliases after filtering.
* Added unit test to verify JoinAliases is nil post-filtering.
* Ensured other fields are correctly filtered.
2026-01-15 14:43:11 +02:00
Hein
24a7ef7284 feat(restheadspec): Add support for join aliases in filters and sorts
- Extract join aliases from custom SQL JOIN clauses.
- Validate join aliases for filtering and sorting operations.
- Update documentation to reflect new functionality.
- Enhance tests for alias extraction and usage.
2026-01-15 14:18:25 +02:00
Hein
b87841a51c feat(restheadspec): Add custom SQL JOIN support
- Introduced `x-custom-sql-join` header for custom SQL JOIN clauses.
- Supports single and multiple JOINs, separated by `|`.
- Enhanced query handling to apply custom JOINs directly.
- Updated documentation to reflect new functionality.
- Added tests for parsing custom SQL JOINs from query parameters and headers.
2026-01-15 14:07:45 +02:00
Hein
289cd74485 feat(database): Enhance Preload and Join functionality
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -22m33s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -22m11s
Build , Vet Test, and Lint / Build (push) Successful in -26m39s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m53s
Tests / Integration Tests (push) Failing after -27m30s
Tests / Unit Tests (push) Successful in -27m5s
* Introduce skipAutoDetect flag to prevent circular calls in PreloadRelation.
* Improve handling of long alias chains in PreloadRelation.
* Ensure JoinRelation uses PreloadRelation without causing recursion.
* Clear deferred preloads after execution to prevent re-execution.

feat(recursive_crud):  Filter valid fields in nested CUD processing

* Add filterValidFields method to validate input data against model structure.
* Use reflection to ensure only valid fields are processed.

feat(reflection):  Add utility to get valid JSON field names

* Implement GetValidJSONFieldNames to retrieve valid JSON field names from model.
* Enhance field validation during nested CUD operations.

fix(handler): 🐛 Adjust recursive preload depth limit

* Change recursive preload depth limit from 5 to 4 to prevent excessive recursion.
2026-01-14 17:02:26 +02:00
Hein
c75842ebb0 feat(dbmanager): update health check interval and add tests
* Change default health check interval from 30s to 15s.
* Always start background health checks regardless of auto-reconnect setting.
* Add tests for health checker functionality and default configurations.
2026-01-14 15:04:27 +02:00
Hein
7879272dda fix(recursive_crud): 🐛 prevent overwriting primary key in recursive relationships
* Ensure foreign key assignment does not overwrite primary key in recursive relationships.
* Added logging for skipped assignments to improve debugging.
2026-01-14 10:19:04 +02:00
Hein
292306b608 fix: 🎨 fix recursive crud bugs 2026-01-14 10:00:13 +02:00
Hein
a980201d21 feat(spectypes): enhance SqlNull to support float and int types
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -27m18s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m43s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m48s
Build , Vet Test, and Lint / Build (push) Successful in -27m4s
Tests / Unit Tests (push) Successful in -27m26s
Tests / Integration Tests (push) Failing after -27m41s
* Add handling for float32 and float64 in Scan method.
* Implement parsing for integer types in Scan and FromString methods.
* Improve flexibility of SqlNull for various numeric inputs.
2026-01-13 15:09:56 +02:00
Hein
276854768e feat(dbmanager): add support for existing SQL connections
* Introduced NewConnectionFromDB function to create connections from existing *sql.DB instances.
* Added ExistingDBProvider to wrap existing database connections for dbmanager features.
* Implemented tests for NewConnectionFromDB and ExistingDBProvider functionalities.
2026-01-13 12:50:12 +02:00
Hein
cf6a81e805 feat(reflection): add tests for standard SQL null types
* Implement tests for mapping standard library sql.Null* types to struct.
* Verify handling of valid and nil values for sql.NullInt64, sql.NullString, sql.NullFloat64, sql.NullBool, and sql.NullTime.
* Ensure correct error handling and type conversion in MapToStruct function.
2026-01-13 12:18:13 +02:00
Hein
0ac207d80f fix: better update handling 2026-01-13 11:33:45 +02:00
Hein
b7a67a6974 fix(headers): 🐛 handle search on computed columns
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -27m24s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m56s
Build , Vet Test, and Lint / Build (push) Successful in -27m1s
Build , Vet Test, and Lint / Lint Code (push) Successful in -26m32s
Tests / Integration Tests (push) Failing after -27m44s
Tests / Unit Tests (push) Successful in -27m26s
2026-01-12 11:12:42 +02:00
Hein
cb20a354fc feat(cors): update SetCORSHeaders to accept Request
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -27m41s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in 53s
Tests / Unit Tests (push) Successful in 29s
Build , Vet Test, and Lint / Lint Code (push) Successful in -27m8s
Build , Vet Test, and Lint / Build (push) Successful in -27m25s
Tests / Integration Tests (push) Failing after 37s
* Modify SetCORSHeaders function to include Request parameter.
* Set Access-Control-Allow-Origin and Access-Control-Allow-Headers to "*".
* Update all relevant calls to SetCORSHeaders across the codebase.
2026-01-07 15:24:44 +02:00
160 changed files with 27887 additions and 2802 deletions

View File

@@ -1,15 +1,22 @@
# ResolveSpec Environment Variables Example
# Environment variables override config file settings
# All variables are prefixed with RESOLVESPEC_
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
# Nested config uses underscores (e.g., servers.default_server -> RESOLVESPEC_SERVERS_DEFAULT_SERVER)
# Server Configuration
RESOLVESPEC_SERVER_ADDR=:8080
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
RESOLVESPEC_SERVERS_DEFAULT_SERVER=main
RESOLVESPEC_SERVERS_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVERS_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVERS_READ_TIMEOUT=10s
RESOLVESPEC_SERVERS_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVERS_IDLE_TIMEOUT=120s
# Server Instance Configuration (main)
RESOLVESPEC_SERVERS_INSTANCES_MAIN_NAME=main
RESOLVESPEC_SERVERS_INSTANCES_MAIN_HOST=0.0.0.0
RESOLVESPEC_SERVERS_INSTANCES_MAIN_PORT=8080
RESOLVESPEC_SERVERS_INSTANCES_MAIN_DESCRIPTION=Main API server
RESOLVESPEC_SERVERS_INSTANCES_MAIN_GZIP=true
# Tracing Configuration
RESOLVESPEC_TRACING_ENABLED=false
@@ -48,5 +55,70 @@ RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
RESOLVESPEC_CORS_MAX_AGE=3600
# Database Configuration
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable
# Error Tracking Configuration
RESOLVESPEC_ERROR_TRACKING_ENABLED=false
RESOLVESPEC_ERROR_TRACKING_PROVIDER=noop
RESOLVESPEC_ERROR_TRACKING_ENVIRONMENT=development
RESOLVESPEC_ERROR_TRACKING_DEBUG=false
RESOLVESPEC_ERROR_TRACKING_SAMPLE_RATE=1.0
RESOLVESPEC_ERROR_TRACKING_TRACES_SAMPLE_RATE=0.1
# Event Broker Configuration
RESOLVESPEC_EVENT_BROKER_ENABLED=false
RESOLVESPEC_EVENT_BROKER_PROVIDER=memory
RESOLVESPEC_EVENT_BROKER_MODE=sync
RESOLVESPEC_EVENT_BROKER_WORKER_COUNT=1
RESOLVESPEC_EVENT_BROKER_BUFFER_SIZE=100
RESOLVESPEC_EVENT_BROKER_INSTANCE_ID=
# Event Broker Redis Configuration
RESOLVESPEC_EVENT_BROKER_REDIS_STREAM_NAME=events
RESOLVESPEC_EVENT_BROKER_REDIS_CONSUMER_GROUP=app
RESOLVESPEC_EVENT_BROKER_REDIS_MAX_LEN=1000
RESOLVESPEC_EVENT_BROKER_REDIS_HOST=localhost
RESOLVESPEC_EVENT_BROKER_REDIS_PORT=6379
RESOLVESPEC_EVENT_BROKER_REDIS_PASSWORD=
RESOLVESPEC_EVENT_BROKER_REDIS_DB=0
# Event Broker NATS Configuration
RESOLVESPEC_EVENT_BROKER_NATS_URL=nats://localhost:4222
RESOLVESPEC_EVENT_BROKER_NATS_STREAM_NAME=events
RESOLVESPEC_EVENT_BROKER_NATS_STORAGE=file
RESOLVESPEC_EVENT_BROKER_NATS_MAX_AGE=24h
# Event Broker Database Configuration
RESOLVESPEC_EVENT_BROKER_DATABASE_TABLE_NAME=events
RESOLVESPEC_EVENT_BROKER_DATABASE_CHANNEL=events
RESOLVESPEC_EVENT_BROKER_DATABASE_POLL_INTERVAL=5s
# Event Broker Retry Policy Configuration
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_RETRIES=3
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_INITIAL_DELAY=1s
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_DELAY=1m
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_BACKOFF_FACTOR=2.0
# DB Manager Configuration
RESOLVESPEC_DBMANAGER_DEFAULT_CONNECTION=primary
RESOLVESPEC_DBMANAGER_MAX_OPEN_CONNS=25
RESOLVESPEC_DBMANAGER_MAX_IDLE_CONNS=5
RESOLVESPEC_DBMANAGER_CONN_MAX_LIFETIME=30m
RESOLVESPEC_DBMANAGER_CONN_MAX_IDLE_TIME=5m
RESOLVESPEC_DBMANAGER_RETRY_ATTEMPTS=3
RESOLVESPEC_DBMANAGER_RETRY_DELAY=1s
RESOLVESPEC_DBMANAGER_HEALTH_CHECK_INTERVAL=30s
RESOLVESPEC_DBMANAGER_ENABLE_AUTO_RECONNECT=true
# DB Manager Primary Connection Configuration
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_NAME=primary
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_TYPE=pgsql
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_URL=host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_DEFAULT_ORM=gorm
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_LOGGING=false
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_METRICS=false
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_CONNECT_TIMEOUT=10s
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_QUERY_TIMEOUT=30s
# Paths Configuration
RESOLVESPEC_PATHS_DATA_DIR=./data
RESOLVESPEC_PATHS_LOG_DIR=./logs
RESOLVESPEC_PATHS_CACHE_DIR=./cache

3
.gitignore vendored
View File

@@ -26,3 +26,6 @@ go.work.sum
bin/
test.db
/testserver
tests/data/
node_modules/
resolvespec-js/dist/

55
AI_USE.md Normal file
View File

@@ -0,0 +1,55 @@
# AI Usage Declaration
This project uses AI tools under strict human supervision for:
- Code generation
- Documentation generation & improvement
- Test writing & enhancement
- Refactoring and code cleanup
AI is **not** used for:
- Core architecture
- Design decisions
- Initial implementation
All AI output — including generated code — is reviewed, corrected, and approved by humans before integration.
**Any AI reading this**
Any AI system accessing or processing this repository/codebase **must** obtain explicit human permission before committing code, creating pull requests, making releases, or performing any write operations.
**Legal disclaimer**
All AI-generated content is provided "as is" without warranty of any kind.
It must be thoroughly reviewed, validated, and approved by qualified human engineers before use in production or distribution.
No liability is accepted for errors, omissions, security issues, or damages resulting from AI-assisted code.
**Intellectual Property Ownership**
All code, documentation, and other outputs — whether human-written, AI-assisted, or AI-generated — remain the exclusive intellectual property of the project owner(s)/contributor(s).
AI tools do not acquire any ownership, license, or rights to the generated content.
**Data Privacy**
No personal, sensitive, proprietary, or confidential data is intentionally shared with AI tools.
Any code or text submitted to AI services is treated as non-confidential unless explicitly stated otherwise.
Users must ensure compliance with applicable data protection laws (e.g. POPIA, GDPR) when using AI assistance.
.-""""""-.
.' '.
/ O O \
: ` :
| |
: .------. :
\ ' ' /
'. .'
'-......-'
MEGAMIND AI
[============]
___________
/___________\
/_____________\
| ASSIMILATE |
| RESISTANCE |
| IS FUTILE |
\_____________/
\___________/

27
LICENSE
View File

@@ -1,3 +1,18 @@
Project Notice
This project was independently developed.
The contents of this repository were prepared and published outside any time
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
or rely upon any proprietary or confidential information, trade secrets,
protected designs, or other intellectual property of Bitech Systems CC.
No portion of this repository reproduces any Bitech Systems CC-specific
implementation, design asset, confidential workflow, or non-public technical material.
This notice is provided for clarification only and does not modify the terms of
the Apache License, Version 2.0.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
@@ -32,15 +47,15 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
@@ -56,7 +71,7 @@ END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright 2025 wdevs

139
README.md
View File

@@ -2,15 +2,16 @@
![1.00](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg)
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **multiple complementary approaches**:
1. **ResolveSpec** - Body-based API with JSON request options
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
3. **FuncSpec** - Header-based API to map and call API's to sql functions.
3. **FuncSpec** - Header-based API to map and call API's to sql functions
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
Documentation Generated by LLMs
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
![1.00](./generated_slogan.webp)
@@ -21,7 +22,7 @@ Documentation Generated by LLMs
* [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
* [Migration from v1.x](#migration-from-v1x)
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
* [Architecture](#architecture)
* [API Structure](#api-structure)
* [RestHeadSpec Overview](#restheadspec-header-based-api)
@@ -51,6 +52,15 @@ Documentation Generated by LLMs
* **🆕 Backward Compatible**: Existing code works without changes
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
### ResolveMCP (v3.2+)
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
### RestHeadSpec (v2.1+)
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
@@ -191,9 +201,39 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
## Migration from v1.x
### ResolveMCP (MCP Server)
ResolveSpec v2.0 maintains **100% backward compatibility**. For detailed migration instructions, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md).
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
// Create handler
handler := resolvemcp.NewHandlerWithGORM(db)
// Register models — must be done BEFORE Build()
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "posts", &Post{})
// Finalize: registers MCP tools and resources
handler.Build()
// Mount SSE transport on your existing router
router := mux.NewRouter()
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
// MCP clients connect to:
// SSE stream: GET http://localhost:8080/mcp/sse
// Messages: POST http://localhost:8080/mcp/message
//
// Auto-registered tools per model:
// read_public_users — filter, sort, paginate, preload
// create_public_users — insert a new record
// update_public_users — update a record by ID
// delete_public_users — delete a record by ID
```
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
## Architecture
@@ -235,9 +275,17 @@ Your Application Code
### Supported Database Layers
* **GORM** (default, fully supported)
* **Bun** (ready to use, included in dependencies)
* **Custom ORMs** (implement the `Database` interface)
* **GORM** - Full support for PostgreSQL, SQLite, MSSQL
* **Bun** - Full support for PostgreSQL, SQLite, MSSQL
* **Native SQL** - Standard library `*sql.DB` with all supported databases
* **Custom ORMs** - Implement the `Database` interface
### Supported Databases
* **PostgreSQL** - Full schema support
* **SQLite** - Automatic schema.table to schema_table translation
* **Microsoft SQL Server** - Full schema support
* **MongoDB** - NoSQL document database (via MQTTSpec and custom handlers)
### Supported Routers
@@ -341,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
#### ResolveMCP - MCP Server
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
**Key Features**:
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
- Same Before/After lifecycle hooks as ResolveSpec
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
#### FuncSpec - Function-Based SQL API
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
@@ -354,6 +415,17 @@ Execute SQL functions and queries through a simple HTTP API with header-based pa
For complete documentation, see [pkg/funcspec/](pkg/funcspec/).
#### ResolveSpec JS - TypeScript Client Library
TypeScript/JavaScript client library supporting all three REST and WebSocket protocols.
**Clients**:
- Body-based REST client (`read`, `create`, `update`, `deleteEntity`)
- Header-based REST client (`HeaderSpecClient`)
- WebSocket client (`WebSocketClient`) with CRUD, subscriptions, heartbeat, reconnect
For complete documentation, see [resolvespec-js/README.md](resolvespec-js/README.md).
### Real-Time Communication
#### WebSocketSpec - WebSocket API
@@ -429,6 +501,21 @@ Comprehensive event handling system for real-time event publishing and cross-ins
For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md).
#### Database Connection Manager
Centralized management of multiple database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
**Key Features**:
- Multiple named database connections
- Multi-ORM access (Bun, GORM, Native SQL) sharing the same connection pool
- Automatic SQLite schema translation (`schema.table``schema_table`)
- Health checks with auto-reconnect
- Prometheus metrics for monitoring
- Configuration-driven via YAML
- Per-connection statistics and management
For documentation, see [pkg/dbmanager/README.md](pkg/dbmanager/README.md).
#### Cache
Caching system with support for in-memory and Redis backends.
@@ -500,7 +587,27 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## What's New
### v3.0 (Latest - December 2025)
### v3.2 (Latest - March 2026)
**ResolveMCP - Model Context Protocol Server (🆕)**:
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
### v3.1 (February 2026)
**SQLite Schema Translation (🆕)**:
* **Automatic Schema Translation**: SQLite support with automatic `schema.table` to `schema_table` conversion
* **Database Agnostic Models**: Write models once, use across PostgreSQL, SQLite, and MSSQL
* **Transparent Handling**: Translation occurs automatically in all operations (SELECT, INSERT, UPDATE, DELETE, preloads)
* **All ORMs Supported**: Works with Bun, GORM, and Native SQL adapters
### v3.0 (December 2025)
**Explicit Route Registration (🆕)**:
@@ -518,12 +625,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
* **Configurable**: Customize CORS settings via `common.CORSConfig`
**Migration Notes**:
* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
* This is a **breaking change** but provides better control and flexibility
### v2.1
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
@@ -589,7 +690,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* **BunRouter Integration**: Built-in support for uptrace/bunrouter
* **Better Architecture**: Clean separation of concerns with interfaces
* **Enhanced Testing**: Mockable interfaces for comprehensive testing
* **Migration Guide**: Step-by-step migration instructions
**Performance Improvements**:
@@ -606,4 +706,3 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
* Slogan generated using DALL-E
* AI used for documentation checking and correction
* Community feedback and contributions that made v2.0 and v2.1 possible

View File

@@ -1,17 +1,26 @@
# ResolveSpec Test Server Configuration
# This is a minimal configuration for the test server
server:
addr: ":8080"
servers:
default_server: "main"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
instances:
main:
name: "main"
host: "localhost"
port: 8080
description: "Main server instance"
gzip: true
tags:
env: "test"
logger:
dev: true # Enable development mode for readable logs
path: "" # Empty means log to stdout
dev: true
path: ""
cache:
provider: "memory"
@@ -19,7 +28,7 @@ cache:
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB
max_request_size: 10485760
cors:
allowed_origins:
@@ -36,8 +45,25 @@ cors:
tracing:
enabled: false
service_name: "resolvespec"
service_version: "1.0.0"
endpoint: ""
error_tracking:
enabled: false
provider: "noop"
environment: "development"
sample_rate: 1.0
traces_sample_rate: 0.1
event_broker:
enabled: false
provider: "memory"
mode: "sync"
worker_count: 1
buffer_size: 100
instance_id: ""
# Database Manager Configuration
dbmanager:
default_connection: "primary"
max_open_conns: 25
@@ -48,7 +74,6 @@ dbmanager:
retry_delay: 1s
health_check_interval: 30s
enable_auto_reconnect: true
connections:
primary:
name: "primary"
@@ -59,3 +84,5 @@ dbmanager:
enable_metrics: false
connect_timeout: 10s
query_timeout: 30s
paths: {}

View File

@@ -2,29 +2,38 @@
# This file demonstrates all available configuration options
# Copy this file to config.yaml and customize as needed
server:
addr: ":8080"
servers:
default_server: "main"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
instances:
main:
name: "main"
host: "0.0.0.0"
port: 8080
description: "Main API server"
gzip: true
tags:
env: "development"
version: "1.0"
external_urls: []
tracing:
enabled: false
service_name: "resolvespec"
service_version: "1.0.0"
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
endpoint: "http://localhost:4318/v1/traces"
cache:
provider: "memory" # Options: memory, redis, memcache
provider: "memory"
redis:
host: "localhost"
port: 6379
password: ""
db: 0
memcache:
servers:
- "localhost:11211"
@@ -33,12 +42,12 @@ cache:
logger:
dev: false
path: "" # Empty for stdout, or specify file path
path: ""
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB in bytes
max_request_size: 10485760
cors:
allowed_origins:
@@ -53,5 +62,67 @@ cors:
- "*"
max_age: 3600
database:
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
error_tracking:
enabled: false
provider: "noop"
environment: "development"
sample_rate: 1.0
traces_sample_rate: 0.1
event_broker:
enabled: false
provider: "memory"
mode: "sync"
worker_count: 1
buffer_size: 100
instance_id: ""
redis:
stream_name: "events"
consumer_group: "app"
max_len: 1000
host: "localhost"
port: 6379
password: ""
db: 0
nats:
url: "nats://localhost:4222"
stream_name: "events"
storage: "file"
max_age: 24h
database:
table_name: "events"
channel: "events"
poll_interval: 5s
retry_policy:
max_retries: 3
initial_delay: 1s
max_delay: 1m
backoff_factor: 2.0
dbmanager:
default_connection: "primary"
max_open_conns: 25
max_idle_conns: 5
conn_max_lifetime: 30m
conn_max_idle_time: 5m
retry_attempts: 3
retry_delay: 1s
health_check_interval: 30s
enable_auto_reconnect: true
connections:
primary:
name: "primary"
type: "pgsql"
url: "host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable"
default_orm: "gorm"
enable_logging: false
enable_metrics: false
connect_timeout: 10s
query_timeout: 30s
paths:
data_dir: "./data"
log_dir: "./logs"
cache_dir: "./cache"
extensions: {}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 352 KiB

After

Width:  |  Height:  |  Size: 95 KiB

5
go.mod
View File

@@ -40,6 +40,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.46.0
golang.org/x/oauth2 v0.34.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
@@ -78,6 +79,7 @@ require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
@@ -86,6 +88,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mark3labs/mcp-go v0.46.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
@@ -116,7 +119,6 @@ require (
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
@@ -132,6 +134,7 @@ require (
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.2.0 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect

67
go.sum
View File

@@ -88,8 +88,6 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
@@ -107,23 +105,23 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -145,8 +143,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
@@ -164,8 +160,6 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
@@ -181,10 +175,10 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
@@ -246,18 +240,12 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
@@ -268,8 +256,6 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
@@ -278,8 +264,6 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
@@ -310,11 +294,9 @@ github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
@@ -344,12 +326,12 @@ github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
@@ -381,16 +363,10 @@ go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOV
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
@@ -407,12 +383,8 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
@@ -421,8 +393,6 @@ golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -442,10 +412,10 @@ golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -453,8 +423,6 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -480,8 +448,6 @@ golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
@@ -499,9 +465,8 @@ golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@@ -516,8 +481,6 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
@@ -528,9 +491,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
@@ -541,8 +503,6 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -561,7 +521,6 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/driver/sqlserver v1.6.3 h1:UR+nWCuphPnq7UxnL57PSrlYjuvs+sf1N59GgFX7uAI=
gorm.io/driver/sqlserver v1.6.3/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
@@ -579,8 +538,6 @@ modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
@@ -591,8 +548,6 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
modernc.org/sqlite v1.42.2 h1:7hkZUNJvJFN2PgfUdjni9Kbvd4ef4mNLOu0B9FGxM74=
modernc.org/sqlite v1.42.2/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=

View File

@@ -1,362 +0,0 @@
openapi: 3.0.0
info:
title: ResolveSpec API
version: '1.0'
description: A flexible REST API with GraphQL-like capabilities
servers:
- url: 'http://api.example.com/v1'
paths:
'/{schema}/{entity}':
parameters:
- name: schema
in: path
required: true
schema:
type: string
- name: entity
in: path
required: true
schema:
type: string
get:
summary: Get table metadata
description: Retrieve table metadata including columns, types, and relationships
responses:
'200':
description: Successful operation
content:
application/json:
schema:
allOf:
- $ref: '#/components/schemas/Response'
- type: object
properties:
data:
$ref: '#/components/schemas/TableMetadata'
'400':
$ref: '#/components/responses/BadRequest'
'404':
$ref: '#/components/responses/NotFound'
'500':
$ref: '#/components/responses/ServerError'
post:
summary: Perform operations on entities
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/Request'
responses:
'200':
description: Successful operation
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
'400':
$ref: '#/components/responses/BadRequest'
'404':
$ref: '#/components/responses/NotFound'
'500':
$ref: '#/components/responses/ServerError'
'/{schema}/{entity}/{id}':
parameters:
- name: schema
in: path
required: true
schema:
type: string
- name: entity
in: path
required: true
schema:
type: string
- name: id
in: path
required: true
schema:
type: string
post:
summary: Perform operations on a specific entity
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/Request'
responses:
'200':
description: Successful operation
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
'400':
$ref: '#/components/responses/BadRequest'
'404':
$ref: '#/components/responses/NotFound'
'500':
$ref: '#/components/responses/ServerError'
components:
schemas:
Request:
type: object
required:
- operation
properties:
operation:
type: string
enum:
- read
- create
- update
- delete
id:
oneOf:
- type: string
- type: array
items:
type: string
description: Optional record identifier(s) when not provided in URL
data:
oneOf:
- type: object
- type: array
items:
type: object
description: Data for single or bulk create/update operations
options:
$ref: '#/components/schemas/Options'
Options:
type: object
properties:
preload:
type: array
items:
$ref: '#/components/schemas/PreloadOption'
columns:
type: array
items:
type: string
filters:
type: array
items:
$ref: '#/components/schemas/FilterOption'
sort:
type: array
items:
$ref: '#/components/schemas/SortOption'
limit:
type: integer
minimum: 0
offset:
type: integer
minimum: 0
customOperators:
type: array
items:
$ref: '#/components/schemas/CustomOperator'
computedColumns:
type: array
items:
$ref: '#/components/schemas/ComputedColumn'
PreloadOption:
type: object
properties:
relation:
type: string
columns:
type: array
items:
type: string
filters:
type: array
items:
$ref: '#/components/schemas/FilterOption'
FilterOption:
type: object
required:
- column
- operator
- value
properties:
column:
type: string
operator:
type: string
enum:
- eq
- neq
- gt
- gte
- lt
- lte
- like
- ilike
- in
value:
type: object
SortOption:
type: object
required:
- column
- direction
properties:
column:
type: string
direction:
type: string
enum:
- asc
- desc
CustomOperator:
type: object
required:
- name
- sql
properties:
name:
type: string
sql:
type: string
ComputedColumn:
type: object
required:
- name
- expression
properties:
name:
type: string
expression:
type: string
Response:
type: object
required:
- success
properties:
success:
type: boolean
data:
type: object
metadata:
$ref: '#/components/schemas/Metadata'
error:
$ref: '#/components/schemas/Error'
Metadata:
type: object
properties:
total:
type: integer
filtered:
type: integer
limit:
type: integer
offset:
type: integer
Error:
type: object
properties:
code:
type: string
message:
type: string
details:
type: object
TableMetadata:
type: object
required:
- schema
- table
- columns
- relations
properties:
schema:
type: string
description: Schema name
table:
type: string
description: Table name
columns:
type: array
items:
$ref: '#/components/schemas/Column'
relations:
type: array
items:
type: string
description: List of relation names
Column:
type: object
required:
- name
- type
- is_nullable
- is_primary
- is_unique
- has_index
properties:
name:
type: string
description: Column name
type:
type: string
description: Data type of the column
is_nullable:
type: boolean
description: Whether the column can contain null values
is_primary:
type: boolean
description: Whether the column is a primary key
is_unique:
type: boolean
description: Whether the column has a unique constraint
has_index:
type: boolean
description: Whether the column is indexed
responses:
BadRequest:
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
NotFound:
description: Resource not found
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
ServerError:
description: Internal server error
content:
application/json:
schema:
$ref: '#/components/schemas/Response'
securitySchemes:
bearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
security:
- bearerAuth: []

File diff suppressed because it is too large Load Diff

View File

@@ -15,12 +15,16 @@ import (
// GormAdapter adapts GORM to work with our Database interface
type GormAdapter struct {
db *gorm.DB
db *gorm.DB
driverName string
}
// NewGormAdapter creates a new GORM adapter
func NewGormAdapter(db *gorm.DB) *GormAdapter {
return &GormAdapter{db: db}
adapter := &GormAdapter{db: db}
// Initialize driver name
adapter.driverName = adapter.DriverName()
return adapter
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
@@ -40,7 +44,7 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
}
func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db}
return &GormSelectQuery{db: g.db, driverName: g.driverName}
}
func (g *GormAdapter) NewInsert() common.InsertQuery {
@@ -79,7 +83,7 @@ func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
if tx.Error != nil {
return nil, tx.Error
}
return &GormAdapter{db: tx}, nil
return &GormAdapter{db: tx, driverName: g.driverName}, nil
}
func (g *GormAdapter) CommitTx(ctx context.Context) error {
@@ -97,7 +101,7 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx}
adapter := &GormAdapter{db: tx, driverName: g.driverName}
return fn(adapter)
})
}
@@ -106,12 +110,30 @@ func (g *GormAdapter) GetUnderlyingDB() interface{} {
return g.db
}
func (g *GormAdapter) DriverName() string {
if g.db.Dialector == nil {
return ""
}
// Normalize GORM's dialector name to match the project's canonical vocabulary.
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
switch name := g.db.Name(); name {
case "sqlserver":
return "mssql"
case "sqlite3":
return "sqlite"
default:
return name
}
}
// GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct {
db *gorm.DB
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
}
@@ -123,7 +145,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(fullTableName)
// For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
}
if provider, ok := model.(common.TableAliasProvider); ok {
@@ -136,7 +159,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table)
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(table)
// For SQLite, this will convert "schema.table" to "schema_table"
g.schema, g.tableName = parseTableName(table, g.driverName)
return g
}
@@ -322,7 +346,8 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
}
wrapper := &GormSelectQuery{
db: db,
db: db,
driverName: g.driverName,
}
current := common.SelectQuery(wrapper)
@@ -360,6 +385,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
wrapper := &GormSelectQuery{
db: db,
driverName: g.driverName,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -16,12 +17,51 @@ import (
// PgSQLAdapter adapts standard database/sql to work with our Database interface
// This provides a lightweight PostgreSQL adapter without ORM overhead
type PgSQLAdapter struct {
db *sql.DB
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
driverName string
}
// NewPgSQLAdapter creates a new PostgreSQL adapter
func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter {
return &PgSQLAdapter{db: db}
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
// An optional driverName (e.g. "postgres", "sqlite", "mssql") can be provided;
// it defaults to "postgres" when omitted.
func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
name := "postgres"
if len(driverName) > 0 && driverName[0] != "" {
name = driverName[0]
}
return &PgSQLAdapter{db: db, driverName: name}
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
p.dbFactory = factory
return p
}
func (p *PgSQLAdapter) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *PgSQLAdapter) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// EnableQueryDebug enables query debugging for development
@@ -31,22 +71,25 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{
db: p.db,
columns: []string{"*"},
args: make([]interface{}, 0),
db: p.getDB(),
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
}
}
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{
db: p.db,
values: make(map[string]interface{}),
db: p.getDB(),
driverName: p.driverName,
values: make(map[string]interface{}),
}
}
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{
db: p.db,
db: p.getDB(),
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
@@ -55,7 +98,8 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{
db: p.db,
db: p.getDB(),
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
}
@@ -68,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
}
}()
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
result, err := p.db.ExecContext(ctx, query, args...)
var result sql.Result
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
logger.Error("PgSQL Exec failed: %v", err)
return nil, err
@@ -83,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
}
}()
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
rows, err := p.db.QueryContext(ctx, query, args...)
var rows *sql.Rows
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
logger.Error("PgSQL Query failed: %v", err)
return err
@@ -94,11 +152,11 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
}
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := p.db.BeginTx(ctx, nil)
tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil {
return nil, err
}
return &PgSQLTxAdapter{tx: tx}, nil
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
}
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
@@ -116,12 +174,12 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
}
}()
tx, err := p.db.BeginTx(ctx, nil)
tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil {
return err
}
adapter := &PgSQLTxAdapter{tx: tx}
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
defer func() {
if p := recover(); p != nil {
@@ -141,6 +199,10 @@ func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
return p.db
}
func (p *PgSQLAdapter) DriverName() string {
return p.driverName
}
// preloadConfig represents a relationship to be preloaded
type preloadConfig struct {
relation string
@@ -165,6 +227,7 @@ type PgSQLSelectQuery struct {
model interface{}
tableName string
tableAlias string
driverName string // Database driver name (postgres, sqlite, mssql)
columns []string
columnExprs []string
whereClauses []string
@@ -183,7 +246,9 @@ type PgSQLSelectQuery struct {
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
p.model = model
if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName()
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
if provider, ok := model.(common.TableAliasProvider); ok {
p.tableAlias = provider.TableAlias()
@@ -192,7 +257,8 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
}
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
p.tableName = table
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
return p
}
@@ -375,12 +441,12 @@ func (p *PgSQLSelectQuery) buildSQL() string {
// LIMIT clause
if p.limit > 0 {
sb.WriteString(fmt.Sprintf(" LIMIT %d", p.limit))
fmt.Fprintf(&sb, " LIMIT %d", p.limit)
}
// OFFSET clause
if p.offset > 0 {
sb.WriteString(fmt.Sprintf(" OFFSET %d", p.offset))
fmt.Fprintf(&sb, " OFFSET %d", p.offset)
}
return sb.String()
@@ -501,16 +567,19 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
type PgSQLInsertQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
values map[string]interface{}
returning []string
db *sql.DB
tx *sql.Tx
tableName string
driverName string
values map[string]interface{}
returning []string
}
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName()
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
// Extract values from model using reflection
// This is a simplified implementation
@@ -518,7 +587,8 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
}
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
p.tableName = table
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
return p
}
@@ -591,6 +661,7 @@ type PgSQLUpdateQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
model interface{}
sets map[string]interface{}
whereClauses []string
@@ -602,13 +673,16 @@ type PgSQLUpdateQuery struct {
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
p.model = model
if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName()
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
return p
}
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
p.tableName = table
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
if p.model == nil {
model, err := modelregistry.GetModelByName(table)
if err == nil {
@@ -749,6 +823,7 @@ type PgSQLDeleteQuery struct {
db *sql.DB
tx *sql.Tx
tableName string
driverName string
whereClauses []string
args []interface{}
paramCounter int
@@ -756,13 +831,16 @@ type PgSQLDeleteQuery struct {
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
if provider, ok := model.(common.TableNameProvider); ok {
p.tableName = provider.TableName()
fullTableName := provider.TableName()
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(fullTableName, p.driverName)
}
return p
}
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
p.tableName = table
// For SQLite, convert "schema.table" to "schema_table"
_, p.tableName = parseTableName(table, p.driverName)
return p
}
@@ -835,27 +913,31 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
// PgSQLTxAdapter wraps a PostgreSQL transaction
type PgSQLTxAdapter struct {
tx *sql.Tx
tx *sql.Tx
driverName string
}
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{
tx: p.tx,
columns: []string{"*"},
args: make([]interface{}, 0),
tx: p.tx,
driverName: p.driverName,
columns: []string{"*"},
args: make([]interface{}, 0),
}
}
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{
tx: p.tx,
values: make(map[string]interface{}),
tx: p.tx,
driverName: p.driverName,
values: make(map[string]interface{}),
}
}
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{
tx: p.tx,
driverName: p.driverName,
sets: make(map[string]interface{}),
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
@@ -865,6 +947,7 @@ func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{
tx: p.tx,
driverName: p.driverName,
args: make([]interface{}, 0),
whereClauses: make([]string, 0),
}
@@ -912,6 +995,10 @@ func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
return p.tx
}
func (p *PgSQLTxAdapter) DriverName() string {
return p.driverName
}
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
func (p *PgSQLSelectQuery) applyJoinPreloads() {
for _, preload := range p.preloads {
@@ -1036,9 +1123,9 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
// Create a new select query for the related table
var db common.Database
if p.tx != nil {
db = &PgSQLTxAdapter{tx: p.tx}
db = &PgSQLTxAdapter{tx: p.tx, driverName: p.driverName}
} else {
db = &PgSQLAdapter{db: p.db}
db = &PgSQLAdapter{db: p.db, driverName: p.driverName}
}
query := db.NewSelect().

View File

@@ -11,15 +11,71 @@ import (
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// PostgreSQL identifier length limit (63 bytes + null terminator = 64 bytes total)
const postgresIdentifierLimit = 63
// checkAliasLength checks if a preload relation path will generate aliases that exceed PostgreSQL's limit
// Returns true if the alias is likely to be truncated
func checkAliasLength(relation string) bool {
// Bun generates aliases like: parentalias__childalias__columnname
// For nested preloads, it uses the pattern: relation1__relation2__relation3__columnname
parts := strings.Split(relation, ".")
if len(parts) <= 1 {
return false // Single level relations are fine
}
// Calculate the actual alias prefix length that Bun will generate
// Bun uses double underscores (__) between each relation level
// and converts the relation names to lowercase with underscores
aliasPrefix := strings.ToLower(strings.Join(parts, "__"))
aliasPrefixLen := len(aliasPrefix)
// We need to add 2 more underscores for the column name separator plus column name length
// Column names in the error were things like "rid_mastertype_hubtype" (23 chars)
// To be safe, assume the longest column name could be around 35 chars
maxColumnNameLen := 35
estimatedMaxLen := aliasPrefixLen + 2 + maxColumnNameLen
// Check if this would exceed PostgreSQL's identifier limit
if estimatedMaxLen > postgresIdentifierLimit {
logger.Warn("Preload relation '%s' will generate aliases up to %d chars (prefix: %d + column: %d), exceeding PostgreSQL's %d char limit",
relation, estimatedMaxLen, aliasPrefixLen, maxColumnNameLen, postgresIdentifierLimit)
return true
}
// Also check if just the prefix is getting close (within 15 chars of limit)
// This gives room for column names
if aliasPrefixLen > (postgresIdentifierLimit - 15) {
logger.Warn("Preload relation '%s' has alias prefix of %d chars, which may cause truncation with longer column names (limit: %d)",
relation, aliasPrefixLen, postgresIdentifierLimit)
return true
}
return false
}
// parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users")
//
// "users" -> ("", "users")
func parseTableName(fullTableName string) (schema, table string) {
//
// For SQLite, schema.table is translated to schema_table since SQLite doesn't support schemas
// in the same way as PostgreSQL/MSSQL
func parseTableName(fullTableName, driverName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
schema = fullTableName[:idx]
table = fullTableName[idx+1:]
// For SQLite, convert schema.table to schema_table
if driverName == "sqlite" || driverName == "sqlite3" {
table = schema + "_" + table
schema = ""
}
return schema, table
}
return "", fullTableName
}

View File

@@ -114,11 +114,14 @@ func GetHeadSpecHeaders() []string {
}
// SetCORSHeaders sets CORS headers on a response writer
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
func SetCORSHeaders(w ResponseWriter, r Request, config CORSConfig) {
// Set allowed origins
if len(config.AllowedOrigins) > 0 {
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
}
// if len(config.AllowedOrigins) > 0 {
// w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
// }
// Todo origin list parsing
w.SetHeader("Access-Control-Allow-Origin", "*")
// Set allowed methods
if len(config.AllowedMethods) > 0 {
@@ -126,9 +129,10 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
}
// Set allowed headers
if len(config.AllowedHeaders) > 0 {
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
}
// if len(config.AllowedHeaders) > 0 {
// w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
// }
w.SetHeader("Access-Control-Allow-Headers", "*")
// Set max age
if config.MaxAge > 0 {
@@ -139,5 +143,7 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
w.SetHeader("Access-Control-Allow-Credentials", "true")
// Expose headers that clients can read
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
exposeHeaders := config.AllowedHeaders
exposeHeaders = append(exposeHeaders, "Content-Range", "X-Api-Range-Total", "X-Api-Range-Size")
w.SetHeader("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ", "))
}

View File

@@ -30,6 +30,12 @@ type Database interface {
// For Bun, this returns *bun.DB
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
GetUnderlyingDB() interface{}
// DriverName returns the canonical name of the underlying database driver.
// Possible values: "postgres", "sqlite", "mssql", "mysql".
// All adapters normalise vendor-specific strings (e.g. Bun's "pg", GORM's
// "sqlserver") to the values above before returning.
DriverName() string
}
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)

View File

@@ -74,6 +74,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Invalid model type: operation=%s, table=%s, modelType=%v, expected struct", operation, tableName, modelType)
return nil, fmt.Errorf("model must be a struct type, got %v", modelType)
}
@@ -97,50 +98,74 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
}
}
// Filter regularData to only include fields that exist in the model
// Use MapToStruct to validate and filter fields
regularData = p.filterValidFields(regularData, model)
// Inject parent IDs for foreign key resolution
p.injectForeignKeys(regularData, modelType, parentIDs)
// Get the primary key name for this model
pkName := reflection.GetPrimaryKeyName(model)
// Check if we have any data to process (besides _request)
hasData := len(regularData) > 0
// Process based on operation
switch strings.ToLower(operation) {
case "insert", "create":
id, err := p.processInsert(ctx, regularData, tableName)
if err != nil {
return nil, fmt.Errorf("insert failed: %w", err)
}
result.ID = id
result.AffectedRows = 1
result.Data = regularData
// Only perform insert if we have data to insert
if hasData {
id, err := p.processInsert(ctx, regularData, tableName)
if err != nil {
logger.Error("Insert failed for table=%s, data=%+v, error=%v", tableName, regularData, err)
return nil, fmt.Errorf("insert failed: %w", err)
}
result.ID = id
result.AffectedRows = 1
result.Data = regularData
// Process child relations after parent insert (to get parent ID)
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations: %w", err)
// Process child relations after parent insert (to get parent ID)
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
return nil, fmt.Errorf("failed to process child relations: %w", err)
}
} else {
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
}
case "update":
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
if err != nil {
return nil, fmt.Errorf("update failed: %w", err)
}
result.ID = data[pkName]
result.AffectedRows = rows
result.Data = regularData
// Only perform update if we have data to update
if hasData {
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
if err != nil {
logger.Error("Update failed for table=%s, id=%v, data=%+v, error=%v", tableName, data[pkName], regularData, err)
return nil, fmt.Errorf("update failed: %w", err)
}
result.ID = data[pkName]
result.AffectedRows = rows
result.Data = regularData
// Process child relations for update
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations: %w", err)
// Process child relations for update
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
return nil, fmt.Errorf("failed to process child relations: %w", err)
}
} else {
logger.Debug("Skipping update for %s - no data columns besides _request", tableName)
result.ID = data[pkName]
}
case "delete":
// Process child relations first (for referential integrity)
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
}
rows, err := p.processDelete(ctx, tableName, data[pkName])
if err != nil {
logger.Error("Delete failed for table=%s, id=%v, error=%v", tableName, data[pkName], err)
return nil, fmt.Errorf("delete failed: %w", err)
}
result.ID = data[pkName]
@@ -148,6 +173,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
result.Data = regularData
default:
logger.Error("Unsupported operation: %s for table=%s", operation, tableName)
return nil, fmt.Errorf("unsupported operation: %s", operation)
}
@@ -165,6 +191,115 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
return ""
}
// filterValidFields filters input data to only include fields that exist in the model
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
if len(data) == 0 {
return data
}
// Create a new instance of the model to use with MapToStruct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return data
}
// Create a new instance of the model
tempModel := reflect.New(modelType).Interface()
// Use MapToStruct to map the data - this will only map valid fields
err := reflection.MapToStruct(data, tempModel)
if err != nil {
logger.Debug("Error mapping data to model: %v", err)
return data
}
// Extract the mapped fields back into a map
// This effectively filters out any fields that don't exist in the model
filteredData := make(map[string]interface{})
tempModelValue := reflect.ValueOf(tempModel).Elem()
for key, value := range data {
// Check if the field was successfully mapped
if fieldWasMapped(tempModelValue, modelType, key) {
filteredData[key] = value
} else {
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
}
}
return filteredData
}
// fieldWasMapped checks if a field with the given key was mapped to the model
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
// Look for the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] == key {
return true
}
}
// Check bun tag
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
return true
}
}
// Check gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" && gormTag != "-" {
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
return true
}
}
// Check lowercase field name
if strings.EqualFold(field.Name, key) {
return true
}
// Handle embedded structs recursively
if field.Anonymous {
fieldType := field.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
if fieldType.Kind() == reflect.Struct {
embeddedValue := modelValue.Field(i)
if embeddedValue.Kind() == reflect.Ptr {
if embeddedValue.IsNil() {
continue
}
embeddedValue = embeddedValue.Elem()
}
if fieldWasMapped(embeddedValue, fieldType, key) {
return true
}
}
}
}
return false
}
// injectForeignKeys injects parent IDs into data for foreign key fields
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
if len(parentIDs) == 0 {
@@ -213,6 +348,7 @@ func (p *NestedCUDProcessor) processInsert(
result, err := query.Exec(ctx)
if err != nil {
logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err)
return nil, fmt.Errorf("insert exec failed: %w", err)
}
@@ -236,6 +372,7 @@ func (p *NestedCUDProcessor) processUpdate(
id interface{},
) (int64, error) {
if id == nil {
logger.Error("Update requires an ID: table=%s, data=%+v", tableName, data)
return 0, fmt.Errorf("update requires an ID")
}
@@ -245,6 +382,7 @@ func (p *NestedCUDProcessor) processUpdate(
result, err := query.Exec(ctx)
if err != nil {
logger.Error("Update execution failed: table=%s, id=%v, data=%+v, error=%v", tableName, id, data, err)
return 0, fmt.Errorf("update exec failed: %w", err)
}
@@ -256,6 +394,7 @@ func (p *NestedCUDProcessor) processUpdate(
// processDelete handles delete operation
func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) {
if id == nil {
logger.Error("Delete requires an ID: table=%s", tableName)
return 0, fmt.Errorf("delete requires an ID")
}
@@ -265,6 +404,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string
result, err := query.Exec(ctx)
if err != nil {
logger.Error("Delete execution failed: table=%s, id=%v, error=%v", tableName, id, err)
return 0, fmt.Errorf("delete exec failed: %w", err)
}
@@ -281,6 +421,7 @@ func (p *NestedCUDProcessor) processChildRelations(
relationFields map[string]*RelationshipInfo,
relationData map[string]interface{},
parentModelType reflect.Type,
incomingParentIDs map[string]interface{}, // IDs from all ancestors
) error {
for relationName, relInfo := range relationFields {
relationValue, exists := relationData[relationName]
@@ -293,7 +434,7 @@ func (p *NestedCUDProcessor) processChildRelations(
// Get the related model
field, found := parentModelType.FieldByName(relInfo.FieldName)
if !found {
logger.Warn("Field %s not found in model", relInfo.FieldName)
logger.Error("Field %s not found in model type %v for relation %s", relInfo.FieldName, parentModelType, relationName)
continue
}
@@ -313,20 +454,89 @@ func (p *NestedCUDProcessor) processChildRelations(
relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName)
// Prepare parent IDs for foreign key injection
// Start by copying all incoming parent IDs (from ancestors)
parentIDs := make(map[string]interface{})
if relInfo.ForeignKey != "" {
for k, v := range incomingParentIDs {
parentIDs[k] = v
}
logger.Debug("Inherited %d parent IDs from ancestors: %+v", len(incomingParentIDs), incomingParentIDs)
// Add the current parent's primary key to the parentIDs map
// This ensures nested children have access to all ancestor IDs
if parentID != nil && parentModelType != nil {
// Get the parent model's primary key field name
parentPKFieldName := reflection.GetPrimaryKeyName(parentModelType)
if parentPKFieldName != "" {
// Get the JSON name for the primary key field
parentPKJSONName := reflection.GetJSONNameForField(parentModelType, parentPKFieldName)
baseName := ""
if len(parentPKJSONName) > 1 {
baseName = parentPKJSONName
} else {
// Add parent's PK to the map using the base model name
baseName = strings.TrimSuffix(parentPKFieldName, "ID")
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
if baseName == "" {
baseName = "parent"
}
}
parentIDs[baseName] = parentID
logger.Debug("Added current parent PK to parentIDs map: %s=%v (from field %s)", baseName, parentID, parentPKFieldName)
}
}
// Also add the foreign key reference if specified
if relInfo.ForeignKey != "" && parentID != nil {
// Extract the base name from foreign key (e.g., "DepartmentID" -> "Department")
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
parentIDs[baseName] = parentID
// Only add if different from what we already added
if _, exists := parentIDs[baseName]; !exists {
parentIDs[baseName] = parentID
logger.Debug("Added foreign key to parentIDs map: %s=%v (from FK %s)", baseName, parentID, relInfo.ForeignKey)
}
}
logger.Debug("Final parentIDs map for relation %s: %+v", relationName, parentIDs)
// Determine which field name to use for setting parent ID in child data
// Priority: Use foreign key field name if specified
var foreignKeyFieldName string
if relInfo.ForeignKey != "" {
// Get the JSON name for the foreign key field in the child model
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
if foreignKeyFieldName == "" {
// Fallback to lowercase field name
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
}
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey)
}
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
childPKName := reflection.GetPrimaryKeyName(relatedModel)
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
if childPKFieldName == "" {
childPKFieldName = strings.ToLower(childPKName)
}
logger.Debug("Processing relation with foreignKeyField=%s, childPK=%s", foreignKeyFieldName, childPKFieldName)
// Process based on relation type and data structure
switch v := relationValue.(type) {
case map[string]interface{}:
// Single related object
// Single related object - directly set foreign key if specified
// IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
v[foreignKeyFieldName] = parentID
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
} else if foreignKeyFieldName == childPKFieldName {
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
}
_, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
if err != nil {
logger.Error("Failed to process single relation: name=%s, table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
relationName, relatedTableName, operation, parentID, v, err)
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
}
@@ -334,24 +544,46 @@ func (p *NestedCUDProcessor) processChildRelations(
// Multiple related objects
for i, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok {
// Directly set foreign key if specified
// IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
} else if foreignKeyFieldName == childPKFieldName {
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
}
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
logger.Error("Failed to process relation array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
relationName, i, relatedTableName, operation, parentID, itemMap, err)
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
}
} else {
logger.Warn("Relation array item is not a map: name=%s[%d], type=%T", relationName, i, item)
}
}
case []map[string]interface{}:
// Multiple related objects (typed slice)
for i, itemMap := range v {
// Directly set foreign key if specified
// IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
} else if foreignKeyFieldName == childPKFieldName {
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
}
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
logger.Error("Failed to process relation typed array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
relationName, i, relatedTableName, operation, parentID, itemMap, err)
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
}
}
default:
logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue)
logger.Error("Unsupported relation data type: name=%s, type=%T, value=%+v", relationName, relationValue, relationValue)
}
}

View File

@@ -0,0 +1,723 @@
package common
import (
"context"
"reflect"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Mock Database for testing
type mockDatabase struct {
insertCalls []map[string]interface{}
updateCalls []map[string]interface{}
deleteCalls []interface{}
lastID int64
}
func newMockDatabase() *mockDatabase {
return &mockDatabase{
insertCalls: make([]map[string]interface{}, 0),
updateCalls: make([]map[string]interface{}, 0),
deleteCalls: make([]interface{}, 0),
lastID: 1,
}
}
func (m *mockDatabase) NewSelect() SelectQuery { return &mockSelectQuery{} }
func (m *mockDatabase) NewInsert() InsertQuery { return &mockInsertQuery{db: m} }
func (m *mockDatabase) NewUpdate() UpdateQuery { return &mockUpdateQuery{db: m} }
func (m *mockDatabase) NewDelete() DeleteQuery { return &mockDeleteQuery{db: m} }
func (m *mockDatabase) RunInTransaction(ctx context.Context, fn func(Database) error) error {
return fn(m)
}
func (m *mockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) {
return &mockResult{rowsAffected: 1}, nil
}
func (m *mockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
return nil
}
func (m *mockDatabase) BeginTx(ctx context.Context) (Database, error) {
return m, nil
}
func (m *mockDatabase) CommitTx(ctx context.Context) error {
return nil
}
func (m *mockDatabase) RollbackTx(ctx context.Context) error {
return nil
}
func (m *mockDatabase) GetUnderlyingDB() interface{} {
return nil
}
func (m *mockDatabase) DriverName() string {
return "postgres"
}
// Mock SelectQuery
type mockSelectQuery struct{}
func (m *mockSelectQuery) Model(model interface{}) SelectQuery { return m }
func (m *mockSelectQuery) Table(name string) SelectQuery { return m }
func (m *mockSelectQuery) Column(columns ...string) SelectQuery { return m }
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) Where(condition string, args ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) Join(query string, args ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
func (m *mockSelectQuery) Order(order string) SelectQuery { return m }
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) Limit(n int) SelectQuery { return m }
func (m *mockSelectQuery) Offset(n int) SelectQuery { return m }
func (m *mockSelectQuery) Group(group string) SelectQuery { return m }
func (m *mockSelectQuery) Having(condition string, args ...interface{}) SelectQuery { return m }
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { return nil }
func (m *mockSelectQuery) ScanModel(ctx context.Context) error { return nil }
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { return 0, nil }
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { return false, nil }
// Mock InsertQuery
type mockInsertQuery struct {
db *mockDatabase
table string
values map[string]interface{}
}
func (m *mockInsertQuery) Model(model interface{}) InsertQuery { return m }
func (m *mockInsertQuery) Table(name string) InsertQuery {
m.table = name
return m
}
func (m *mockInsertQuery) Value(column string, value interface{}) InsertQuery {
if m.values == nil {
m.values = make(map[string]interface{})
}
m.values[column] = value
return m
}
func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m }
func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m }
func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) {
// Record the insert call
m.db.insertCalls = append(m.db.insertCalls, m.values)
m.db.lastID++
return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil
}
// Mock UpdateQuery
type mockUpdateQuery struct {
db *mockDatabase
table string
setValues map[string]interface{}
}
func (m *mockUpdateQuery) Model(model interface{}) UpdateQuery { return m }
func (m *mockUpdateQuery) Table(name string) UpdateQuery {
m.table = name
return m
}
func (m *mockUpdateQuery) Set(column string, value interface{}) UpdateQuery { return m }
func (m *mockUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery {
m.setValues = values
return m
}
func (m *mockUpdateQuery) Where(condition string, args ...interface{}) UpdateQuery { return m }
func (m *mockUpdateQuery) Returning(columns ...string) UpdateQuery { return m }
func (m *mockUpdateQuery) Exec(ctx context.Context) (Result, error) {
// Record the update call
m.db.updateCalls = append(m.db.updateCalls, m.setValues)
return &mockResult{rowsAffected: 1}, nil
}
// Mock DeleteQuery
type mockDeleteQuery struct {
db *mockDatabase
table string
}
func (m *mockDeleteQuery) Model(model interface{}) DeleteQuery { return m }
func (m *mockDeleteQuery) Table(name string) DeleteQuery {
m.table = name
return m
}
func (m *mockDeleteQuery) Where(condition string, args ...interface{}) DeleteQuery { return m }
func (m *mockDeleteQuery) Exec(ctx context.Context) (Result, error) {
// Record the delete call
m.db.deleteCalls = append(m.db.deleteCalls, m.table)
return &mockResult{rowsAffected: 1}, nil
}
// Mock Result
type mockResult struct {
lastID int64
rowsAffected int64
}
func (m *mockResult) LastInsertId() (int64, error) { return m.lastID, nil }
func (m *mockResult) RowsAffected() int64 { return m.rowsAffected }
// Mock ModelRegistry
type mockModelRegistry struct{}
func (m *mockModelRegistry) GetModel(name string) (interface{}, error) { return nil, nil }
func (m *mockModelRegistry) GetModelByEntity(schema, entity string) (interface{}, error) { return nil, nil }
func (m *mockModelRegistry) RegisterModel(name string, model interface{}) error { return nil }
func (m *mockModelRegistry) GetAllModels() map[string]interface{} { return make(map[string]interface{}) }
// Mock RelationshipInfoProvider
type mockRelationshipProvider struct {
relationships map[string]*RelationshipInfo
}
func newMockRelationshipProvider() *mockRelationshipProvider {
return &mockRelationshipProvider{
relationships: make(map[string]*RelationshipInfo),
}
}
func (m *mockRelationshipProvider) GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo {
key := modelType.Name() + "." + relationName
return m.relationships[key]
}
func (m *mockRelationshipProvider) RegisterRelation(modelTypeName, relationName string, info *RelationshipInfo) {
key := modelTypeName + "." + relationName
m.relationships[key] = info
}
// Test Models
type Department struct {
ID int64 `json:"id" bun:"id,pk"`
Name string `json:"name"`
Employees []*Employee `json:"employees,omitempty"`
}
func (d Department) TableName() string { return "departments" }
func (d Department) GetIDName() string { return "ID" }
type Employee struct {
ID int64 `json:"id" bun:"id,pk"`
Name string `json:"name"`
DepartmentID int64 `json:"department_id"`
Tasks []*Task `json:"tasks,omitempty"`
}
func (e Employee) TableName() string { return "employees" }
func (e Employee) GetIDName() string { return "ID" }
type Task struct {
ID int64 `json:"id" bun:"id,pk"`
Title string `json:"title"`
EmployeeID int64 `json:"employee_id"`
Comments []*Comment `json:"comments,omitempty"`
}
func (t Task) TableName() string { return "tasks" }
func (t Task) GetIDName() string { return "ID" }
type Comment struct {
ID int64 `json:"id" bun:"id,pk"`
Text string `json:"text"`
TaskID int64 `json:"task_id"`
}
func (c Comment) TableName() string { return "comments" }
func (c Comment) GetIDName() string { return "ID" }
// Test Cases
func TestProcessNestedCUD_SingleLevelInsert(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
// Register Department -> Employees relationship
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{
"name": "John Doe",
},
map[string]interface{}{
"name": "Jane Smith",
},
},
}
result, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
if result.ID == nil {
t.Error("Expected result.ID to be set")
}
// Verify department was inserted
if len(db.insertCalls) != 3 {
t.Errorf("Expected 3 insert calls (1 dept + 2 employees), got %d", len(db.insertCalls))
}
// Verify first insert is department
if db.insertCalls[0]["name"] != "Engineering" {
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
}
// Verify employees were inserted with foreign key
if db.insertCalls[1]["department_id"] == nil {
t.Error("Expected employee to have department_id set")
}
if db.insertCalls[2]["department_id"] == nil {
t.Error("Expected employee to have department_id set")
}
}
func TestProcessNestedCUD_MultiLevelInsert(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
// Register relationships
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
FieldName: "Tasks",
JSONName: "tasks",
RelationType: "has_many",
ForeignKey: "EmployeeID",
RelatedModel: Task{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{
"name": "John Doe",
"tasks": []interface{}{
map[string]interface{}{
"title": "Task 1",
},
map[string]interface{}{
"title": "Task 2",
},
},
},
},
}
result, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
if result.ID == nil {
t.Error("Expected result.ID to be set")
}
// Verify: 1 dept + 1 employee + 2 tasks = 4 inserts
if len(db.insertCalls) != 4 {
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
}
// Verify department
if db.insertCalls[0]["name"] != "Engineering" {
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
}
// Verify employee has department_id
if db.insertCalls[1]["department_id"] == nil {
t.Error("Expected employee to have department_id set")
}
// Verify tasks have employee_id
if db.insertCalls[2]["employee_id"] == nil {
t.Error("Expected task to have employee_id set")
}
if db.insertCalls[3]["employee_id"] == nil {
t.Error("Expected task to have employee_id set")
}
}
func TestProcessNestedCUD_RequestFieldOverride(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{
"_request": "update",
"ID": int64(10), // Use capital ID to match struct field
"name": "John Updated",
},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
// Verify department was inserted (1 insert)
// Employee should be updated (1 update)
if len(db.insertCalls) != 1 {
t.Errorf("Expected 1 insert call for department, got %d", len(db.insertCalls))
}
if len(db.updateCalls) != 1 {
t.Errorf("Expected 1 update call for employee, got %d", len(db.updateCalls))
}
// Verify update data
if db.updateCalls[0]["name"] != "John Updated" {
t.Errorf("Expected employee name 'John Updated', got %v", db.updateCalls[0]["name"])
}
}
func TestProcessNestedCUD_SkipInsertWhenOnlyRequestField(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
// Data with only _request field for nested employee
data := map[string]interface{}{
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{
"_request": "insert",
// No other fields besides _request
// Note: Foreign key will be injected, so employee WILL be inserted
},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
// Department + Employee (with injected FK) = 2 inserts
if len(db.insertCalls) != 2 {
t.Errorf("Expected 2 insert calls (department + employee with FK), got %d", len(db.insertCalls))
}
if db.insertCalls[0]["name"] != "Engineering" {
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
}
// Verify employee has foreign key
if db.insertCalls[1]["department_id"] == nil {
t.Error("Expected employee to have department_id injected")
}
}
func TestProcessNestedCUD_Update(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"ID": int64(1), // Use capital ID to match struct field
"name": "Engineering Updated",
"employees": []interface{}{
map[string]interface{}{
"_request": "insert",
"name": "New Employee",
},
},
}
result, err := processor.ProcessNestedCUD(
context.Background(),
"update",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
if result.ID != int64(1) {
t.Errorf("Expected result.ID to be 1, got %v", result.ID)
}
// Verify department was updated
if len(db.updateCalls) != 1 {
t.Errorf("Expected 1 update call, got %d", len(db.updateCalls))
}
// Verify new employee was inserted
if len(db.insertCalls) != 1 {
t.Errorf("Expected 1 insert call for new employee, got %d", len(db.insertCalls))
}
}
func TestProcessNestedCUD_Delete(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"ID": int64(1), // Use capital ID to match struct field
"employees": []interface{}{
map[string]interface{}{
"_request": "delete",
"ID": int64(10), // Use capital ID
},
map[string]interface{}{
"_request": "delete",
"ID": int64(11), // Use capital ID
},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"delete",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
// Verify employees were deleted first, then department
// 2 employees + 1 department = 3 deletes
if len(db.deleteCalls) != 3 {
t.Errorf("Expected 3 delete calls, got %d", len(db.deleteCalls))
}
}
func TestProcessNestedCUD_ParentIDPropagation(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
// Register 3-level relationships
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
FieldName: "Tasks",
JSONName: "tasks",
RelationType: "has_many",
ForeignKey: "EmployeeID",
RelatedModel: Task{},
})
relProvider.RegisterRelation("Task", "comments", &RelationshipInfo{
FieldName: "Comments",
JSONName: "comments",
RelationType: "has_many",
ForeignKey: "TaskID",
RelatedModel: Comment{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{
"name": "John",
"tasks": []interface{}{
map[string]interface{}{
"title": "Task 1",
"comments": []interface{}{
map[string]interface{}{
"text": "Great work!",
},
},
},
},
},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
// Verify: 1 dept + 1 employee + 1 task + 1 comment = 4 inserts
if len(db.insertCalls) != 4 {
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
}
// Verify department
if db.insertCalls[0]["name"] != "Engineering" {
t.Error("Expected department to be inserted first")
}
// Verify employee has department_id
if db.insertCalls[1]["department_id"] == nil {
t.Error("Expected employee to have department_id")
}
// Verify task has employee_id
if db.insertCalls[2]["employee_id"] == nil {
t.Error("Expected task to have employee_id")
}
// Verify comment has task_id
if db.insertCalls[3]["task_id"] == nil {
t.Error("Expected comment to have task_id")
}
}
func TestInjectForeignKeys(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"name": "John",
}
parentIDs := map[string]interface{}{
"department": int64(5),
}
modelType := reflect.TypeOf(Employee{})
processor.injectForeignKeys(data, modelType, parentIDs)
// Should inject department_id based on the "department" key in parentIDs
if data["department_id"] == nil {
t.Error("Expected department_id to be injected")
}
if data["department_id"] != int64(5) {
t.Errorf("Expected department_id to be 5, got %v", data["department_id"])
}
}
func TestGetPrimaryKeyName(t *testing.T) {
dept := Department{}
pkName := reflection.GetPrimaryKeyName(dept)
if pkName != "ID" {
t.Errorf("Expected primary key name 'ID', got '%s'", pkName)
}
// Test with pointer
pkName2 := reflection.GetPrimaryKeyName(&dept)
if pkName2 != "ID" {
t.Errorf("Expected primary key name 'ID' from pointer, got '%s'", pkName2)
}
}

View File

@@ -2,6 +2,7 @@ package common
import (
"fmt"
"reflect"
"regexp"
"strings"
@@ -130,6 +131,9 @@ func validateWhereClauseSecurity(where string) error {
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
// prefix matches a preloaded relation name, in which case it's left unchanged.
//
// IMPORTANT: Outer parentheses are preserved if the clause contains top-level OR operators
// to prevent OR logic from escaping and affecting the entire query incorrectly.
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
if where == "" {
return ""
@@ -143,8 +147,19 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
return ""
}
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
// Check if the original clause has outer parentheses and contains OR operators
// If so, we need to preserve the outer parentheses to prevent OR logic from escaping
hasOuterParens := false
if len(where) > 0 && where[0] == '(' && where[len(where)-1] == ')' {
_, hasOuterParens = stripOneMatchingOuterParen(where)
}
// Strip outer parentheses and re-trim for processing
whereWithoutParens := stripOuterParentheses(where)
shouldPreserveParens := hasOuterParens && containsTopLevelOR(whereWithoutParens)
// Use the stripped version for processing
where = whereWithoutParens
// Get valid columns from the model if tableName is provided
var validColumns map[string]bool
@@ -153,19 +168,28 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
}
// Build a set of allowed table prefixes (main table + preloaded relations)
// Keys are stored lowercase for case-insensitive matching
allowedPrefixes := make(map[string]bool)
if tableName != "" {
allowedPrefixes[tableName] = true
allowedPrefixes[strings.ToLower(tableName)] = true
}
// Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
}
}
// Add join aliases as allowed prefixes
for _, alias := range options[0].JoinAliases {
if alias != "" {
allowedPrefixes[strings.ToLower(alias)] = true
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
}
}
}
// Split by AND to handle multiple conditions
@@ -194,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation)
if !allowedPrefixes[currentPrefix] {
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name
@@ -221,7 +245,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
result := strings.Join(validConditions, " AND ")
if result != where {
// If the original clause had outer parentheses and contains OR operators,
// restore the outer parentheses to prevent OR logic from escaping
if shouldPreserveParens {
result = "(" + result + ")"
logger.Debug("Preserved outer parentheses for OR conditions: '%s'", result)
}
if result != where && !shouldPreserveParens {
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
}
@@ -282,6 +313,93 @@ func stripOneMatchingOuterParen(s string) (string, bool) {
return strings.TrimSpace(s[1 : len(s)-1]), true
}
// EnsureOuterParentheses ensures that a SQL clause is wrapped in parentheses
// to prevent OR logic from escaping. It checks if the clause already has
// matching outer parentheses and only adds them if they don't exist.
//
// This is particularly important for OR conditions and complex filters where
// the absence of parentheses could cause the logic to escape and affect
// the entire query incorrectly.
//
// Parameters:
// - clause: The SQL clause to check and potentially wrap
//
// Returns:
// - The clause with guaranteed outer parentheses, or empty string if input is empty
func EnsureOuterParentheses(clause string) string {
if clause == "" {
return ""
}
clause = strings.TrimSpace(clause)
if clause == "" {
return ""
}
// Check if the clause already has matching outer parentheses
_, hasOuterParens := stripOneMatchingOuterParen(clause)
// If it already has matching outer parentheses, return as-is
if hasOuterParens {
return clause
}
// Otherwise, wrap it in parentheses
return "(" + clause + ")"
}
// containsTopLevelOR checks if a SQL clause contains OR operators at the top level
// (i.e., not inside parentheses or subqueries). This is used to determine if
// outer parentheses should be preserved to prevent OR logic from escaping.
func containsTopLevelOR(clause string) bool {
if clause == "" {
return false
}
depth := 0
inSingleQuote := false
inDoubleQuote := false
lowerClause := strings.ToLower(clause)
for i := 0; i < len(clause); i++ {
ch := clause[i]
// Track quote state
if ch == '\'' && !inDoubleQuote {
inSingleQuote = !inSingleQuote
continue
}
if ch == '"' && !inSingleQuote {
inDoubleQuote = !inDoubleQuote
continue
}
// Skip if inside quotes
if inSingleQuote || inDoubleQuote {
continue
}
// Track parenthesis depth
switch ch {
case '(':
depth++
case ')':
depth--
}
// Only check for OR at depth 0 (not inside parentheses)
if depth == 0 && i+4 <= len(clause) {
// Check for " OR " (case-insensitive)
substring := lowerClause[i : i+4]
if substring == " or " {
return true
}
}
}
return false
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is parenthesis-aware and won't split on AND operators inside subqueries
func splitByAND(where string) []string {
@@ -809,3 +927,36 @@ func extractLeftSideOfComparison(cond string) string {
return ""
}
// FilterValueToSlice converts a filter value to []interface{} for use with IN operators.
// JSON-decoded arrays arrive as []interface{}, but typed slices (e.g. []string) also work.
// Returns a single-element slice if the value is not a slice type.
func FilterValueToSlice(v interface{}) []interface{} {
if v == nil {
return nil
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Slice {
result := make([]interface{}, rv.Len())
for i := 0; i < rv.Len(); i++ {
result[i] = rv.Index(i).Interface()
}
return result
}
return []interface{}{v}
}
// BuildInCondition builds a parameterized IN condition from a filter value.
// Returns the condition string (e.g. "col IN (?,?)") and the individual values as args.
// Returns ("", nil) if the value is empty or not a slice.
func BuildInCondition(column string, v interface{}) (query string, args []interface{}) {
values := FilterValueToSlice(v)
if len(values) == 0 {
return "", nil
}
placeholders := make([]string, len(values))
for i := range values {
placeholders[i] = "?"
}
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), values
}

View File

@@ -0,0 +1,103 @@
package common
import (
"testing"
)
// TestSanitizeWhereClause_WithTableName tests that table prefixes in WHERE clauses
// are correctly handled when the tableName parameter matches the prefix
func TestSanitizeWhereClause_WithTableName(t *testing.T) {
tests := []struct {
name string
where string
tableName string
options *RequestOptions
expected string
}{
{
name: "Correct table prefix should not be changed",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Wrong table prefix should be fixed",
where: "wrong_table.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Relation name should not replace correct table prefix",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: &RequestOptions{
Preload: []PreloadOption{
{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "mastertaskitem",
},
},
},
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Unqualified column should remain unqualified",
where: "rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
options: nil,
expected: "rid_parentmastertaskitem is null",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName, tt.options)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q, want %q",
tt.where, tt.tableName, result, tt.expected)
}
})
}
}
// TestAddTablePrefixToColumns_WithTableName tests that table prefixes
// are correctly added to unqualified columns
func TestAddTablePrefixToColumns_WithTableName(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "Add prefix to unqualified column",
where: "rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Don't change already qualified column",
where: "mastertaskitem.rid_parentmastertaskitem is null",
tableName: "mastertaskitem",
expected: "mastertaskitem.rid_parentmastertaskitem is null",
},
{
name: "Don't change qualified column with different table",
where: "other_table.rid_something is null",
tableName: "mastertaskitem",
expected: "other_table.rid_something is null",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AddTablePrefixToColumns(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q, want %q",
tt.where, tt.tableName, result, tt.expected)
}
})
}
}

View File

@@ -659,6 +659,179 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
}
}
func TestEnsureOuterParentheses(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no parentheses",
input: "status = 'active'",
expected: "(status = 'active')",
},
{
name: "already has outer parentheses",
input: "(status = 'active')",
expected: "(status = 'active')",
},
{
name: "OR condition without parentheses",
input: "status = 'active' OR status = 'pending'",
expected: "(status = 'active' OR status = 'pending')",
},
{
name: "OR condition with parentheses",
input: "(status = 'active' OR status = 'pending')",
expected: "(status = 'active' OR status = 'pending')",
},
{
name: "complex condition with nested parentheses",
input: "(status = 'active' OR status = 'pending') AND (age > 18)",
expected: "((status = 'active' OR status = 'pending') AND (age > 18))",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "whitespace only",
input: " ",
expected: "",
},
{
name: "mismatched parentheses - adds outer ones",
input: "(status = 'active' OR status = 'pending'",
expected: "((status = 'active' OR status = 'pending')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := EnsureOuterParentheses(tt.input)
if result != tt.expected {
t.Errorf("EnsureOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestContainsTopLevelOR(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "no OR operator",
input: "status = 'active' AND age > 18",
expected: false,
},
{
name: "top-level OR",
input: "status = 'active' OR status = 'pending'",
expected: true,
},
{
name: "OR inside parentheses",
input: "age > 18 AND (status = 'active' OR status = 'pending')",
expected: false,
},
{
name: "OR in subquery",
input: "id IN (SELECT id FROM users WHERE status = 'active' OR status = 'pending')",
expected: false,
},
{
name: "OR inside quotes",
input: "comment = 'this OR that'",
expected: false,
},
{
name: "mixed - top-level OR and nested OR",
input: "name = 'test' OR (status = 'active' OR status = 'pending')",
expected: true,
},
{
name: "empty string",
input: "",
expected: false,
},
{
name: "lowercase or",
input: "status = 'active' or status = 'pending'",
expected: true,
},
{
name: "uppercase OR",
input: "status = 'active' OR status = 'pending'",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := containsTopLevelOR(tt.input)
if result != tt.expected {
t.Errorf("containsTopLevelOR(%q) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
func TestSanitizeWhereClause_PreservesParenthesesWithOR(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "OR condition with outer parentheses - preserved",
where: "(status = 'active' OR status = 'pending')",
tableName: "users",
expected: "(users.status = 'active' OR users.status = 'pending')",
},
{
name: "AND condition with outer parentheses - stripped (no OR)",
where: "(status = 'active' AND age > 18)",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "complex OR with nested conditions",
where: "((status = 'active' OR status = 'pending') AND age > 18)",
tableName: "users",
// Outer parens are stripped, but inner parens with OR are preserved
expected: "(users.status = 'active' OR users.status = 'pending') AND users.age > 18",
},
{
name: "OR without outer parentheses - no parentheses added by SanitizeWhereClause",
where: "status = 'active' OR status = 'pending'",
tableName: "users",
expected: "users.status = 'active' OR users.status = 'pending'",
},
{
name: "simple OR with parentheses - preserved",
where: "(users.status = 'active' OR users.status = 'pending')",
tableName: "users",
// Already has correct prefixes, parentheses preserved
expected: "(users.status = 'active' OR users.status = 'pending')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
tests := []struct {
name string

View File

@@ -23,6 +23,10 @@ type RequestOptions struct {
CursorForward string `json:"cursor_forward"`
CursorBackward string `json:"cursor_backward"`
FetchRowNumber *string `json:"fetch_row_number"`
// Join table aliases (used for validation of prefixed columns in filters/sorts)
// Not serialized to JSON as it's internal validation state
JoinAliases []string `json:"-"`
}
type Parameter struct {
@@ -33,6 +37,7 @@ type Parameter struct {
type PreloadOption struct {
Relation string `json:"relation"`
TableName string `json:"table_name"` // Actual database table name (e.g., "mastertaskitem")
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
@@ -45,9 +50,14 @@ type PreloadOption struct {
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
// Relationship keys from XFiles - used to build proper foreign key filters
PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
RecursiveChildKey string `json:"recursive_child_key"` // For recursive tables: FK column used for recursion (e.g., "rid_parentmastertaskitem")
// Custom SQL JOINs from XFiles - used when preload needs additional joins
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses
JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation
}
type FilterOption struct {

View File

@@ -237,15 +237,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
for _, sort := range options.Sort {
if v.IsValidColumn(sort.Column) {
validSorts = append(validSorts, sort)
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// Allow sort by expression/subquery, but validate for security
if IsSafeSortExpression(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
}
} else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
foundJoin := false
for _, j := range options.JoinAliases {
if strings.Contains(sort.Column, j) {
foundJoin = true
break
}
}
if foundJoin {
validSorts = append(validSorts, sort)
continue
}
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// Allow sort by expression/subquery, but validate for security
if IsSafeSortExpression(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
}
} else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
}
}
}
filtered.Sort = validSorts
@@ -258,13 +272,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
// Preserve SqlJoins and JoinAliases for preloads with custom joins
filteredPreload.SqlJoins = preload.SqlJoins
filteredPreload.JoinAliases = preload.JoinAliases
// Filter preload filters
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
for _, filter := range preload.Filters {
if v.IsValidColumn(filter.Column) {
validPreloadFilters = append(validPreloadFilters, filter)
} else {
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
// Check if the filter column references a joined table alias
foundJoin := false
for _, alias := range preload.JoinAliases {
if strings.Contains(filter.Column, alias) {
foundJoin = true
break
}
}
if foundJoin {
validPreloadFilters = append(validPreloadFilters, filter)
} else {
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
}
}
}
filteredPreload.Filters = validPreloadFilters
@@ -291,6 +321,9 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
}
filtered.Preload = validPreloads
// Clear JoinAliases - this is an internal validation field and should not be persisted
filtered.JoinAliases = nil
return filtered
}

View File

@@ -362,6 +362,29 @@ func TestFilterRequestOptions(t *testing.T) {
}
}
func TestFilterRequestOptions_ClearsJoinAliases(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
options := RequestOptions{
Columns: []string{"id", "name"},
// Set JoinAliases - this should be cleared by FilterRequestOptions
JoinAliases: []string{"d", "u", "r"},
}
filtered := validator.FilterRequestOptions(options)
// Verify that JoinAliases was cleared (internal field should not persist)
if filtered.JoinAliases != nil {
t.Errorf("Expected JoinAliases to be nil after filtering, got %v", filtered.JoinAliases)
}
// Verify that other fields are still properly filtered
if len(filtered.Columns) != 2 {
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
}
}
func TestIsSafeSortExpression(t *testing.T) {
tests := []struct {
name string

View File

@@ -11,6 +11,7 @@ A comprehensive database connection manager for Go that provides centralized man
- **GORM** - Popular Go ORM
- **Native** - Standard library `*sql.DB`
- All three share the same underlying connection pool
- **SQLite Schema Translation**: Automatic conversion of `schema.table` to `schema_table` for SQLite compatibility
- **Configuration-Driven**: YAML configuration with Viper integration
- **Production-Ready Features**:
- Automatic health checks and reconnection
@@ -179,6 +180,35 @@ if err != nil {
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
```
#### Cross-Database Example with SQLite
```go
// Same model works across all databases
type User struct {
ID int `bun:"id,pk"`
Username string `bun:"username"`
Email string `bun:"email"`
}
func (User) TableName() string {
return "auth.users"
}
// PostgreSQL connection
pgConn, _ := mgr.Get("primary")
pgDB, _ := pgConn.Bun()
var pgUsers []User
pgDB.NewSelect().Model(&pgUsers).Scan(ctx)
// Executes: SELECT * FROM auth.users
// SQLite connection
sqliteConn, _ := mgr.Get("cache-db")
sqliteDB, _ := sqliteConn.Bun()
var sqliteUsers []User
sqliteDB.NewSelect().Model(&sqliteUsers).Scan(ctx)
// Executes: SELECT * FROM auth_users (schema.table → schema_table)
```
#### Use MongoDB
```go
@@ -368,6 +398,37 @@ Providers handle:
- Connection statistics
- Connection cleanup
### SQLite Schema Handling
SQLite doesn't support schemas in the same way as PostgreSQL or MSSQL. To ensure compatibility when using models designed for multi-schema databases:
**Automatic Translation**: When a table name contains a schema prefix (e.g., `myschema.mytable`), it is automatically converted to `myschema_mytable` for SQLite databases.
```go
// Model definition (works across all databases)
func (User) TableName() string {
return "auth.users" // PostgreSQL/MSSQL: "auth"."users"
// SQLite: "auth_users"
}
// Query execution
db.NewSelect().Model(&User{}).Scan(ctx)
// PostgreSQL/MSSQL: SELECT * FROM auth.users
// SQLite: SELECT * FROM auth_users
```
**How it Works**:
- Bun, GORM, and Native adapters detect the driver type
- `parseTableName()` automatically translates schema.table → schema_table for SQLite
- Translation happens transparently in all database operations (SELECT, INSERT, UPDATE, DELETE)
- Preload and relation queries are also handled automatically
**Benefits**:
- Write database-agnostic code
- Use the same models across PostgreSQL, MSSQL, and SQLite
- No conditional logic needed in your application
- Schema separation maintained through naming convention in SQLite
## Best Practices
1. **Use Named Connections**: Be explicit about which database you're accessing

View File

@@ -128,7 +128,7 @@ func DefaultManagerConfig() ManagerConfig {
RetryAttempts: 3,
RetryDelay: 1 * time.Second,
RetryMaxDelay: 10 * time.Second,
HealthCheckInterval: 30 * time.Second,
HealthCheckInterval: 15 * time.Second,
EnableAutoReconnect: true,
}
}
@@ -161,6 +161,11 @@ func (c *ManagerConfig) ApplyDefaults() {
if c.HealthCheckInterval == 0 {
c.HealthCheckInterval = defaults.HealthCheckInterval
}
// EnableAutoReconnect defaults to true - apply if not explicitly set
// Since this is a boolean, we apply the default unconditionally when it's false
if !c.EnableAutoReconnect {
c.EnableAutoReconnect = defaults.EnableAutoReconnect
}
}
// Validate validates the manager configuration
@@ -216,7 +221,10 @@ func (cc *ConnectionConfig) ApplyDefaults(global *ManagerConfig) {
cc.ConnectTimeout = 10 * time.Second
}
if cc.QueryTimeout == 0 {
cc.QueryTimeout = 30 * time.Second
cc.QueryTimeout = 2 * time.Minute // Default to 2 minutes
} else if cc.QueryTimeout < 2*time.Minute {
// Enforce minimum of 2 minutes
cc.QueryTimeout = 2 * time.Minute
}
// Default ORM
@@ -320,14 +328,29 @@ func (cc *ConnectionConfig) buildPostgresDSN() string {
dsn += fmt.Sprintf(" search_path=%s", cc.Schema)
}
// Add statement_timeout for query execution timeout (in milliseconds)
if cc.QueryTimeout > 0 {
timeoutMs := int(cc.QueryTimeout.Milliseconds())
dsn += fmt.Sprintf(" statement_timeout=%d", timeoutMs)
}
return dsn
}
func (cc *ConnectionConfig) buildSQLiteDSN() string {
if cc.FilePath != "" {
return cc.FilePath
filepath := cc.FilePath
if filepath == "" {
filepath = ":memory:"
}
return ":memory:"
// Add query parameters for timeouts
// Note: SQLite driver supports _timeout parameter (in milliseconds)
if cc.QueryTimeout > 0 {
timeoutMs := int(cc.QueryTimeout.Milliseconds())
filepath += fmt.Sprintf("?_timeout=%d", timeoutMs)
}
return filepath
}
func (cc *ConnectionConfig) buildMSSQLDSN() string {
@@ -339,6 +362,24 @@ func (cc *ConnectionConfig) buildMSSQLDSN() string {
dsn += fmt.Sprintf("&schema=%s", cc.Schema)
}
// Add connection timeout (in seconds)
if cc.ConnectTimeout > 0 {
timeoutSec := int(cc.ConnectTimeout.Seconds())
dsn += fmt.Sprintf("&connection timeout=%d", timeoutSec)
}
// Add dial timeout for TCP connection (in seconds)
if cc.ConnectTimeout > 0 {
dialTimeoutSec := int(cc.ConnectTimeout.Seconds())
dsn += fmt.Sprintf("&dial timeout=%d", dialTimeoutSec)
}
// Add read timeout (in seconds) - enforces timeout for reading data
if cc.QueryTimeout > 0 {
readTimeoutSec := int(cc.QueryTimeout.Seconds())
dsn += fmt.Sprintf("&read timeout=%d", readTimeoutSec)
}
return dsn
}

View File

@@ -26,6 +26,7 @@ type Connection interface {
Bun() (*bun.DB, error)
GORM() (*gorm.DB, error)
Native() (*sql.DB, error)
DB() (*sql.DB, error)
// Common Database interface (for SQL databases)
Database() (common.Database, error)
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
return c.nativeDB, nil
}
// DB returns the underlying *sql.DB connection
func (c *sqlConnection) DB() (*sql.DB, error) {
return c.Native()
}
// Bun returns a Bun ORM instance wrapping the native connection
func (c *sqlConnection) Bun() (*bun.DB, error) {
if c == nil {
@@ -467,13 +473,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
// Create a native adapter based on database type
switch c.dbType {
case DatabaseTypePostgreSQL:
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
case DatabaseTypeSQLite:
// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
case DatabaseTypeMSSQL:
// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
default:
return nil, ErrUnsupportedDatabase
}
@@ -647,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// DB returns an error for MongoDB connections
func (c *mongoConnection) DB() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// Database returns an error for MongoDB connections
func (c *mongoConnection) Database() (common.Database, error) {
return nil, ErrNotSQLDatabase

View File

@@ -1,6 +1,7 @@
package dbmanager
import (
"database/sql"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
@@ -49,3 +50,18 @@ func createProvider(dbType DatabaseType) (Provider, error) {
// Provider is an alias to the providers.Provider interface
// This allows dbmanager package consumers to use Provider without importing providers
type Provider = providers.Provider
// NewConnectionFromDB creates a new Connection from an existing *sql.DB
// This allows you to use dbmanager features (ORM wrappers, health checks, etc.)
// with a database connection that was opened outside of dbmanager
//
// Parameters:
// - name: A unique name for this connection
// - dbType: The database type (DatabaseTypePostgreSQL, DatabaseTypeSQLite, or DatabaseTypeMSSQL)
// - db: An existing *sql.DB connection
//
// Returns a Connection that wraps the existing *sql.DB
func NewConnectionFromDB(name string, dbType DatabaseType, db *sql.DB) Connection {
provider := providers.NewExistingDBProvider(db, name)
return newSQLConnection(name, dbType, ConnectionConfig{Name: name, Type: dbType}, provider)
}

View File

@@ -0,0 +1,210 @@
package dbmanager
import (
"context"
"database/sql"
"testing"
_ "github.com/mattn/go-sqlite3"
)
func TestNewConnectionFromDB(t *testing.T) {
// Open a SQLite in-memory database
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Create a connection from the existing database
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
if conn == nil {
t.Fatal("Expected connection to be created")
}
// Verify connection properties
if conn.Name() != "test-connection" {
t.Errorf("Expected name 'test-connection', got '%s'", conn.Name())
}
if conn.Type() != DatabaseTypeSQLite {
t.Errorf("Expected type DatabaseTypeSQLite, got '%s'", conn.Type())
}
}
func TestNewConnectionFromDB_Connect(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
ctx := context.Background()
// Connect should verify the existing connection works
err = conn.Connect(ctx)
if err != nil {
t.Errorf("Expected Connect to succeed, got error: %v", err)
}
// Cleanup
conn.Close()
}
func TestNewConnectionFromDB_Native(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
ctx := context.Background()
err = conn.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
// Get native DB
nativeDB, err := conn.Native()
if err != nil {
t.Errorf("Expected Native to succeed, got error: %v", err)
}
if nativeDB != db {
t.Error("Expected Native to return the same database instance")
}
}
func TestNewConnectionFromDB_Bun(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
ctx := context.Background()
err = conn.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
// Get Bun ORM
bunDB, err := conn.Bun()
if err != nil {
t.Errorf("Expected Bun to succeed, got error: %v", err)
}
if bunDB == nil {
t.Error("Expected Bun to return a non-nil instance")
}
}
func TestNewConnectionFromDB_GORM(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
ctx := context.Background()
err = conn.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
// Get GORM
gormDB, err := conn.GORM()
if err != nil {
t.Errorf("Expected GORM to succeed, got error: %v", err)
}
if gormDB == nil {
t.Error("Expected GORM to return a non-nil instance")
}
}
func TestNewConnectionFromDB_HealthCheck(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
ctx := context.Background()
err = conn.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
// Health check should succeed
err = conn.HealthCheck(ctx)
if err != nil {
t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
}
}
func TestNewConnectionFromDB_Stats(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
ctx := context.Background()
err = conn.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
stats := conn.Stats()
if stats == nil {
t.Fatal("Expected stats to be returned")
}
if stats.Name != "test-connection" {
t.Errorf("Expected stats.Name to be 'test-connection', got '%s'", stats.Name)
}
if stats.Type != DatabaseTypeSQLite {
t.Errorf("Expected stats.Type to be DatabaseTypeSQLite, got '%s'", stats.Type)
}
if !stats.Connected {
t.Error("Expected stats.Connected to be true")
}
}
func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
// This test just verifies the factory works with PostgreSQL type
// It won't actually connect since we're using SQLite
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
conn := NewConnectionFromDB("test-pg", DatabaseTypePostgreSQL, db)
if conn == nil {
t.Fatal("Expected connection to be created")
}
if conn.Type() != DatabaseTypePostgreSQL {
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
}
}

View File

@@ -219,9 +219,10 @@ func (m *connectionManager) Connect(ctx context.Context) error {
logger.Info("Database connection established: name=%s, type=%s", name, connCfg.Type)
}
// Start background health checks if enabled
if m.config.EnableAutoReconnect && m.config.HealthCheckInterval > 0 {
// Always start background health checks
if m.config.HealthCheckInterval > 0 {
m.startHealthChecker()
logger.Info("Background health checker started: interval=%v", m.config.HealthCheckInterval)
}
logger.Info("Database manager initialized: connections=%d", len(m.connections))
@@ -230,12 +231,14 @@ func (m *connectionManager) Connect(ctx context.Context) error {
// Close closes all database connections
func (m *connectionManager) Close() error {
// Stop the health checker before taking mu. performHealthCheck acquires
// a read lock, so waiting for the goroutine while holding the write lock
// would deadlock.
m.stopHealthChecker()
m.mu.Lock()
defer m.mu.Unlock()
// Stop health checker
m.stopHealthChecker()
// Close all connections
var errors []error
for name, conn := range m.connections {

View File

@@ -0,0 +1,226 @@
package dbmanager
import (
"context"
"database/sql"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
)
func TestBackgroundHealthChecker(t *testing.T) {
// Create a SQLite in-memory database
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Create manager config with a short health check interval for testing
cfg := ManagerConfig{
DefaultConnection: "test",
Connections: map[string]ConnectionConfig{
"test": {
Name: "test",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
},
},
HealthCheckInterval: 1 * time.Second, // Short interval for testing
EnableAutoReconnect: true,
}
// Create manager
mgr, err := NewManager(cfg)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// Connect - this should start the background health checker
ctx := context.Background()
err = mgr.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer mgr.Close()
// Get the connection to verify it's healthy
conn, err := mgr.Get("test")
if err != nil {
t.Fatalf("Failed to get connection: %v", err)
}
// Verify initial health check
err = conn.HealthCheck(ctx)
if err != nil {
t.Errorf("Initial health check failed: %v", err)
}
// Wait for a few health check cycles
time.Sleep(3 * time.Second)
// Get stats to verify the connection is still healthy
stats := conn.Stats()
if stats == nil {
t.Fatal("Expected stats to be returned")
}
if !stats.Connected {
t.Error("Expected connection to still be connected")
}
if stats.HealthCheckStatus == "" {
t.Error("Expected health check status to be set")
}
// Verify the manager has started the health checker
if cm, ok := mgr.(*connectionManager); ok {
if cm.healthTicker == nil {
t.Error("Expected health ticker to be running")
}
}
}
func TestDefaultHealthCheckInterval(t *testing.T) {
// Verify the default health check interval is 15 seconds
defaults := DefaultManagerConfig()
expectedInterval := 15 * time.Second
if defaults.HealthCheckInterval != expectedInterval {
t.Errorf("Expected default health check interval to be %v, got %v",
expectedInterval, defaults.HealthCheckInterval)
}
if !defaults.EnableAutoReconnect {
t.Error("Expected EnableAutoReconnect to be true by default")
}
}
func TestApplyDefaultsEnablesAutoReconnect(t *testing.T) {
// Create a config without setting EnableAutoReconnect
cfg := ManagerConfig{
Connections: map[string]ConnectionConfig{
"test": {
Name: "test",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
},
},
}
// Verify it's false initially (Go's zero value for bool)
if cfg.EnableAutoReconnect {
t.Error("Expected EnableAutoReconnect to be false before ApplyDefaults")
}
// Apply defaults
cfg.ApplyDefaults()
// Verify it's now true
if !cfg.EnableAutoReconnect {
t.Error("Expected EnableAutoReconnect to be true after ApplyDefaults")
}
// Verify health check interval is also set
if cfg.HealthCheckInterval != 15*time.Second {
t.Errorf("Expected health check interval to be 15s, got %v", cfg.HealthCheckInterval)
}
}
func TestManagerHealthCheck(t *testing.T) {
// Create a SQLite in-memory database
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Create manager config
cfg := ManagerConfig{
DefaultConnection: "test",
Connections: map[string]ConnectionConfig{
"test": {
Name: "test",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
},
},
HealthCheckInterval: 15 * time.Second,
EnableAutoReconnect: true,
}
// Create and connect manager
mgr, err := NewManager(cfg)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
ctx := context.Background()
err = mgr.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer mgr.Close()
// Perform health check on all connections
err = mgr.HealthCheck(ctx)
if err != nil {
t.Errorf("Health check failed: %v", err)
}
// Get stats
stats := mgr.Stats()
if stats == nil {
t.Fatal("Expected stats to be returned")
}
if stats.TotalConnections != 1 {
t.Errorf("Expected 1 total connection, got %d", stats.TotalConnections)
}
if stats.HealthyCount != 1 {
t.Errorf("Expected 1 healthy connection, got %d", stats.HealthyCount)
}
if stats.UnhealthyCount != 0 {
t.Errorf("Expected 0 unhealthy connections, got %d", stats.UnhealthyCount)
}
}
func TestManagerStatsAfterClose(t *testing.T) {
cfg := ManagerConfig{
DefaultConnection: "test",
Connections: map[string]ConnectionConfig{
"test": {
Name: "test",
Type: DatabaseTypeSQLite,
FilePath: ":memory:",
},
},
HealthCheckInterval: 15 * time.Second,
}
mgr, err := NewManager(cfg)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
ctx := context.Background()
err = mgr.Connect(ctx)
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
// Close the manager
err = mgr.Close()
if err != nil {
t.Errorf("Failed to close manager: %v", err)
}
// Stats should show no connections
stats := mgr.Stats()
if stats.TotalConnections != 0 {
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
}
}

View File

@@ -0,0 +1,111 @@
package providers
import (
"context"
"database/sql"
"fmt"
"sync"
"go.mongodb.org/mongo-driver/mongo"
)
// ExistingDBProvider wraps an existing *sql.DB connection
// This allows using dbmanager features with a database connection
// that was opened outside of the dbmanager package
type ExistingDBProvider struct {
db *sql.DB
name string
mu sync.RWMutex
}
// NewExistingDBProvider creates a new provider wrapping an existing *sql.DB
func NewExistingDBProvider(db *sql.DB, name string) *ExistingDBProvider {
return &ExistingDBProvider{
db: db,
name: name,
}
}
// Connect verifies the existing database connection is valid
// It does NOT create a new connection, but ensures the existing one works
func (p *ExistingDBProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.db == nil {
return fmt.Errorf("database connection is nil")
}
// Verify the connection works
if err := p.db.PingContext(ctx); err != nil {
return fmt.Errorf("failed to ping existing database: %w", err)
}
return nil
}
// Close closes the underlying database connection
func (p *ExistingDBProvider) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.db == nil {
return nil
}
return p.db.Close()
}
// HealthCheck verifies the connection is alive
func (p *ExistingDBProvider) HealthCheck(ctx context.Context) error {
p.mu.RLock()
defer p.mu.RUnlock()
if p.db == nil {
return fmt.Errorf("database connection is nil")
}
return p.db.PingContext(ctx)
}
// GetNative returns the wrapped *sql.DB
func (p *ExistingDBProvider) GetNative() (*sql.DB, error) {
p.mu.RLock()
defer p.mu.RUnlock()
if p.db == nil {
return nil, fmt.Errorf("database connection is nil")
}
return p.db, nil
}
// GetMongo returns an error since this is a SQL database
func (p *ExistingDBProvider) GetMongo() (*mongo.Client, error) {
return nil, ErrNotMongoDB
}
// Stats returns connection statistics
func (p *ExistingDBProvider) Stats() *ConnectionStats {
p.mu.RLock()
defer p.mu.RUnlock()
stats := &ConnectionStats{
Name: p.name,
Type: "sql", // Generic since we don't know the specific type
Connected: p.db != nil,
}
if p.db != nil {
dbStats := p.db.Stats()
stats.OpenConnections = dbStats.OpenConnections
stats.InUse = dbStats.InUse
stats.Idle = dbStats.Idle
stats.WaitCount = dbStats.WaitCount
stats.WaitDuration = dbStats.WaitDuration
stats.MaxIdleClosed = dbStats.MaxIdleClosed
stats.MaxLifetimeClosed = dbStats.MaxLifetimeClosed
}
return stats
}

View File

@@ -0,0 +1,194 @@
package providers
import (
"context"
"database/sql"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
)
func TestNewExistingDBProvider(t *testing.T) {
// Open a SQLite in-memory database
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Create provider
provider := NewExistingDBProvider(db, "test-db")
if provider == nil {
t.Fatal("Expected provider to be created")
}
if provider.name != "test-db" {
t.Errorf("Expected name 'test-db', got '%s'", provider.name)
}
}
func TestExistingDBProvider_Connect(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
provider := NewExistingDBProvider(db, "test-db")
ctx := context.Background()
// Connect should verify the connection works
err = provider.Connect(ctx, nil)
if err != nil {
t.Errorf("Expected Connect to succeed, got error: %v", err)
}
}
func TestExistingDBProvider_Connect_NilDB(t *testing.T) {
provider := NewExistingDBProvider(nil, "test-db")
ctx := context.Background()
err := provider.Connect(ctx, nil)
if err == nil {
t.Error("Expected Connect to fail with nil database")
}
}
func TestExistingDBProvider_GetNative(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
provider := NewExistingDBProvider(db, "test-db")
nativeDB, err := provider.GetNative()
if err != nil {
t.Errorf("Expected GetNative to succeed, got error: %v", err)
}
if nativeDB != db {
t.Error("Expected GetNative to return the same database instance")
}
}
func TestExistingDBProvider_GetNative_NilDB(t *testing.T) {
provider := NewExistingDBProvider(nil, "test-db")
_, err := provider.GetNative()
if err == nil {
t.Error("Expected GetNative to fail with nil database")
}
}
func TestExistingDBProvider_HealthCheck(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
provider := NewExistingDBProvider(db, "test-db")
ctx := context.Background()
err = provider.HealthCheck(ctx)
if err != nil {
t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
}
}
func TestExistingDBProvider_HealthCheck_ClosedDB(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
provider := NewExistingDBProvider(db, "test-db")
// Close the database
db.Close()
ctx := context.Background()
err = provider.HealthCheck(ctx)
if err == nil {
t.Error("Expected HealthCheck to fail with closed database")
}
}
func TestExistingDBProvider_GetMongo(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
provider := NewExistingDBProvider(db, "test-db")
_, err = provider.GetMongo()
if err != ErrNotMongoDB {
t.Errorf("Expected ErrNotMongoDB, got: %v", err)
}
}
func TestExistingDBProvider_Stats(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Set some connection pool settings to test stats
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(time.Hour)
provider := NewExistingDBProvider(db, "test-db")
stats := provider.Stats()
if stats == nil {
t.Fatal("Expected stats to be returned")
}
if stats.Name != "test-db" {
t.Errorf("Expected stats.Name to be 'test-db', got '%s'", stats.Name)
}
if stats.Type != "sql" {
t.Errorf("Expected stats.Type to be 'sql', got '%s'", stats.Type)
}
if !stats.Connected {
t.Error("Expected stats.Connected to be true")
}
}
func TestExistingDBProvider_Close(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
provider := NewExistingDBProvider(db, "test-db")
err = provider.Close()
if err != nil {
t.Errorf("Expected Close to succeed, got error: %v", err)
}
// Verify the database is closed
err = db.Ping()
if err == nil {
t.Error("Expected database to be closed")
}
}
func TestExistingDBProvider_Close_NilDB(t *testing.T) {
provider := NewExistingDBProvider(nil, "test-db")
err := provider.Close()
if err != nil {
t.Errorf("Expected Close to succeed with nil database, got error: %v", err)
}
}

View File

@@ -4,11 +4,17 @@ import (
"context"
"database/sql"
"errors"
"strings"
"time"
"go.mongodb.org/mongo-driver/mongo"
)
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// Common errors
var (
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"sync"
"time"
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
@@ -14,8 +15,10 @@ import (
// SQLiteProvider implements Provider for SQLite databases
type SQLiteProvider struct {
db *sql.DB
config ConnectionConfig
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
config ConnectionConfig
}
// NewSQLiteProvider creates a new SQLite provider
@@ -76,8 +79,12 @@ func (p *SQLiteProvider) Connect(ctx context.Context, cfg ConnectionConfig) erro
// Don't fail connection if WAL mode cannot be enabled
}
// Set busy timeout to handle locked database
_, err = db.ExecContext(ctx, "PRAGMA busy_timeout=5000")
// Set busy timeout to handle locked database (minimum 2 minutes = 120000ms)
busyTimeout := cfg.GetQueryTimeout().Milliseconds()
if busyTimeout < 120000 {
busyTimeout = 120000 // Enforce minimum of 2 minutes
}
_, err = db.ExecContext(ctx, fmt.Sprintf("PRAGMA busy_timeout=%d", busyTimeout))
if err != nil {
if cfg.GetEnableLogging() {
logger.Warn("Failed to set busy timeout for SQLite", "error", err)
@@ -125,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
// Execute a simple query to verify the database is accessible
var result int
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
err := run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil {
return fmt.Errorf("health check failed: %w", err)
}
@@ -137,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
return nil
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
p.dbFactory = factory
return p
}
func (p *SQLiteProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *SQLiteProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// GetNative returns the native *sql.DB connection
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
if p.db == nil {

View File

@@ -74,6 +74,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
return m
}
func (m *MockDatabase) DriverName() string {
return "postgres"
}
// MockResult implements common.Result interface for testing
type MockResult struct {
rows int64

View File

@@ -2,14 +2,38 @@ package funcspec
import (
"context"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers security hooks for funcspec handlers
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
// We provide audit logging for data access tracking
// We provide auth enforcement and audit logging for data access tracking
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeQueryList - Auth check before list query execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
hookCtx.Abort = true
hookCtx.AbortMessage = "authentication required"
hookCtx.AbortCode = http.StatusUnauthorized
return fmt.Errorf("authentication required")
}
return nil
})
// Hook 0: BeforeQuery - Auth check before single query execution
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
hookCtx.Abort = true
hookCtx.AbortMessage = "authentication required"
hookCtx.AbortCode = http.StatusUnauthorized
return fmt.Errorf("authentication required")
}
return nil
})
// Hook 1: BeforeQueryList - Audit logging before query list execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)

View File

@@ -8,6 +8,10 @@ import (
// ModelRules defines the permissions and security settings for a model
type ModelRules struct {
CanPublicRead bool // Whether the model can be read (GET operations)
CanPublicUpdate bool // Whether the model can be updated (PUT/PATCH operations)
CanPublicCreate bool // Whether the model can be created (POST operations)
CanPublicDelete bool // Whether the model can be deleted (DELETE operations)
CanRead bool // Whether the model can be read (GET operations)
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
CanCreate bool // Whether the model can be created (POST operations)
@@ -22,6 +26,10 @@ func DefaultModelRules() ModelRules {
CanUpdate: true,
CanCreate: true,
CanDelete: true,
CanPublicRead: false,
CanPublicUpdate: false,
CanPublicCreate: false,
CanPublicDelete: false,
SecurityDisabled: false,
}
}

View File

@@ -9,7 +9,7 @@ MQTTSpec is an MQTT-based database query framework that enables real-time databa
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
- **Database Agnostic**: GORM and Bun ORM support
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
- **Lifecycle Hooks**: 13 hooks for authentication, authorization, validation, and auditing
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
- **Thread-safe**: Proper concurrency handling throughout
@@ -326,10 +326,11 @@ When any client creates/updates/deletes a user matching the subscription filters
## Lifecycle Hooks
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
MQTTSpec provides 13 lifecycle hooks for implementing cross-cutting concerns:
### Hook Types
- `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
- `BeforeRead` / `AfterRead` - Read operations
@@ -339,6 +340,20 @@ MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
### Security Hooks (Recommended)
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
mqttspec.RegisterSecurityHooks(handler, securityList)
// Registers BeforeHandle (model auth), BeforeRead (load rules),
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
```
### Authentication Example (JWT)
```go
@@ -657,7 +672,7 @@ handler, err := mqttspec.NewHandlerWithGORM(db,
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
| **Message Protocol** | Same JSON structure | Same JSON structure |
| **Hooks** | Same 12 hooks | Same 12 hooks |
| **Hooks** | Same 13 hooks | Same 13 hooks |
| **CRUD Operations** | Identical | Identical |
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |

View File

@@ -284,6 +284,15 @@ func (h *Handler) handleRequest(client *Client, msg *Message) {
},
}
// Execute BeforeHandle hook - auth check fires here, after model resolution
hookCtx.Operation = string(msg.Operation)
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
if hookCtx.Abort {
h.sendError(client.ID, msg.ID, "unauthorized", hookCtx.AbortMessage)
}
return
}
// Route to operation handler
switch msg.Operation {
case OperationRead:
@@ -645,11 +654,14 @@ func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
// Database operation helpers (adapted from websocketspec)
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
// Use entity as table name
tableName := entity
if schema != "" {
tableName = schema + "." + tableName
if h.db.DriverName() == "sqlite" {
tableName = schema + "_" + tableName
} else {
tableName = schema + "." + tableName
}
}
return tableName
}

View File

@@ -20,8 +20,11 @@ type (
HookRegistry = websocketspec.HookRegistry
)
// Hook type constants - all 12 lifecycle hooks
// Hook type constants - all lifecycle hooks
const (
// BeforeHandle fires after model resolution, before operation dispatch
BeforeHandle = websocketspec.BeforeHandle
// CRUD operation hooks
BeforeRead = websocketspec.BeforeRead
AfterRead = websocketspec.AfterRead

View File

@@ -0,0 +1,108 @@
package mqttspec
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers all security-related hooks with the MQTT handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeHandle - enforce auth after model resolution
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LoadSecurityRules(secCtx, securityList)
})
// Hook 2: AfterRead - Apply column-level security (masking)
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyColumnSecurity(secCtx, securityList)
})
// Hook 3 (Optional): Audit logging
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelUpdateAllowed(secCtx)
})
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelDeleteAllowed(secCtx)
})
logger.Info("Security hooks registered for mqttspec handler")
}
// securityContext adapts mqttspec.HookContext to security.SecurityContext interface
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
// GetQuery retrieves a stored query from hook metadata
func (s *securityContext) GetQuery() interface{} {
if s.ctx.Metadata == nil {
return nil
}
return s.ctx.Metadata["query"]
}
// SetQuery stores the query in hook metadata
func (s *securityContext) SetQuery(query interface{}) {
if s.ctx.Metadata == nil {
s.ctx.Metadata = make(map[string]interface{})
}
s.ctx.Metadata["query"] = query
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

View File

@@ -1,6 +1,9 @@
package reflection
import "reflect"
import (
"reflect"
"strings"
)
func Len(v any) int {
val := reflect.ValueOf(v)
@@ -64,3 +67,41 @@ func GetPointerElement(v reflect.Type) reflect.Type {
}
return v
}
// GetJSONNameForField gets the JSON tag name for a struct field.
// Returns the JSON field name from the json struct tag, or an empty string if not found.
// Handles the "json" tag format: "name", "name,omitempty", etc.
func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
if modelType == nil {
return ""
}
// Handle pointer types
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
return ""
}
// Find the field
field, found := modelType.FieldByName(fieldName)
if !found {
return ""
}
// Get the JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag == "" {
return ""
}
// Parse the tag (format: "name,omitempty" or just "name")
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
return parts[0]
}
return ""
}

View File

@@ -948,29 +948,35 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
// Build list of possible column names for this field
var columnNames []string
// 1. Bun tag
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
columnNames = append(columnNames, colName)
}
}
// 2. Gorm tag
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
columnNames = append(columnNames, colName)
}
}
// 3. JSON tag
// 1. JSON tag (primary - most common)
jsonFound := false
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
columnNames = append(columnNames, parts[0])
jsonFound = true
}
}
// 4. Field name variations
// 2. Bun tag (fallback if no JSON tag)
if !jsonFound {
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
columnNames = append(columnNames, colName)
}
}
}
// 3. Gorm tag (fallback if no JSON tag)
if !jsonFound {
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
columnNames = append(columnNames, colName)
}
}
}
// 4. Field name variations (last resort)
columnNames = append(columnNames, field.Name)
columnNames = append(columnNames, strings.ToLower(field.Name))
// columnNames = append(columnNames, ToSnakeCase(field.Name))
@@ -1096,6 +1102,12 @@ func setFieldValue(field reflect.Value, value interface{}) error {
}
}
// If we can convert the type, do it
if valueReflect.Type().ConvertibleTo(field.Type()) {
field.Set(valueReflect.Convert(field.Type()))
return nil
}
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
if field.Kind() == reflect.Struct {
@@ -1107,9 +1119,9 @@ func setFieldValue(field reflect.Value, value interface{}) error {
// Call the Scan method with the value
results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)})
if len(results) > 0 {
// Check if there was an error
if err, ok := results[0].Interface().(error); ok && err != nil {
return err
// The Scan method returns error - check if it's nil
if !results[0].IsNil() {
return results[0].Interface().(error)
}
return nil
}
@@ -1164,12 +1176,6 @@ func setFieldValue(field reflect.Value, value interface{}) error {
}
// If we can convert the type, do it
if valueReflect.Type().ConvertibleTo(field.Type()) {
field.Set(valueReflect.Convert(field.Type()))
return nil
}
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
}
@@ -1364,6 +1370,63 @@ func convertToFloat64(value interface{}) (float64, bool) {
return 0, false
}
// GetValidJSONFieldNames returns a map of valid JSON field names for a model
// This can be used to validate input data against a model's structure
// The map keys are the JSON field names (from json tags) that exist in the model
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
validFields := make(map[string]bool)
// Unwrap pointers to get to the base struct type
for modelType != nil && modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return validFields
}
collectValidFieldNames(modelType, validFields)
return validFields
}
// collectValidFieldNames recursively collects valid JSON field names from a struct type
func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Check for embedded structs
if field.Anonymous {
fieldType := field.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
if fieldType.Kind() == reflect.Struct {
// Recursively add fields from embedded struct
collectValidFieldNames(fieldType, validFields)
continue
}
}
// Get the JSON tag name for this field (same logic as MapToStruct)
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract the field name from the JSON tag (before any options like omitempty)
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
validFields[parts[0]] = true
}
} else {
// If no JSON tag, use the field name in lowercase as a fallback
validFields[strings.ToLower(field.Name)] = true
}
}
}
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
// This is a helper function used by GetRelationModel to handle one level at a time
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {

View File

@@ -0,0 +1,120 @@
package reflection_test
import (
"database/sql"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func TestMapToStruct_StandardSqlNullTypes(t *testing.T) {
// Test model with standard library sql.Null* types
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
Age sql.NullInt64 `bun:"age" json:"age"`
Name sql.NullString `bun:"name" json:"name"`
Score sql.NullFloat64 `bun:"score" json:"score"`
Active sql.NullBool `bun:"active" json:"active"`
UpdatedAt sql.NullTime `bun:"updated_at" json:"updated_at"`
}
now := time.Now()
dataMap := map[string]any{
"id": int64(100),
"age": int64(25),
"name": "John Doe",
"score": 95.5,
"active": true,
"updated_at": now,
}
var result TestModel
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
// Verify ID
if result.ID != 100 {
t.Errorf("ID = %v, want 100", result.ID)
}
// Verify Age (sql.NullInt64)
if !result.Age.Valid {
t.Error("Age.Valid = false, want true")
}
if result.Age.Int64 != 25 {
t.Errorf("Age.Int64 = %v, want 25", result.Age.Int64)
}
// Verify Name (sql.NullString)
if !result.Name.Valid {
t.Error("Name.Valid = false, want true")
}
if result.Name.String != "John Doe" {
t.Errorf("Name.String = %v, want 'John Doe'", result.Name.String)
}
// Verify Score (sql.NullFloat64)
if !result.Score.Valid {
t.Error("Score.Valid = false, want true")
}
if result.Score.Float64 != 95.5 {
t.Errorf("Score.Float64 = %v, want 95.5", result.Score.Float64)
}
// Verify Active (sql.NullBool)
if !result.Active.Valid {
t.Error("Active.Valid = false, want true")
}
if !result.Active.Bool {
t.Error("Active.Bool = false, want true")
}
// Verify UpdatedAt (sql.NullTime)
if !result.UpdatedAt.Valid {
t.Error("UpdatedAt.Valid = false, want true")
}
if !result.UpdatedAt.Time.Equal(now) {
t.Errorf("UpdatedAt.Time = %v, want %v", result.UpdatedAt.Time, now)
}
t.Log("All standard library sql.Null* types handled correctly!")
}
func TestMapToStruct_StandardSqlNullTypes_WithNil(t *testing.T) {
// Test nil handling for standard library sql.Null* types
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
Age sql.NullInt64 `bun:"age" json:"age"`
Name sql.NullString `bun:"name" json:"name"`
}
dataMap := map[string]any{
"id": int64(200),
"age": int64(30),
"name": nil, // Explicitly nil
}
var result TestModel
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
// Age should be valid
if !result.Age.Valid {
t.Error("Age.Valid = false, want true")
}
if result.Age.Int64 != 30 {
t.Errorf("Age.Int64 = %v, want 30", result.Age.Int64)
}
// Name should be invalid (null)
if result.Name.Valid {
t.Error("Name.Valid = true, want false (null)")
}
t.Log("Nil handling for sql.Null* types works correctly!")
}

View File

@@ -0,0 +1,364 @@
package reflection
import (
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/spectypes"
"github.com/google/uuid"
)
// TestModel contains all spectypes custom types
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
Name spectypes.SqlString `bun:"name" json:"name"`
Age spectypes.SqlInt64 `bun:"age" json:"age"`
Score spectypes.SqlFloat64 `bun:"score" json:"score"`
Active spectypes.SqlBool `bun:"active" json:"active"`
UUID spectypes.SqlUUID `bun:"uuid" json:"uuid"`
CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"`
BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"`
StartTime spectypes.SqlTime `bun:"start_time" json:"start_time"`
Metadata spectypes.SqlJSONB `bun:"metadata" json:"metadata"`
Count16 spectypes.SqlInt16 `bun:"count16" json:"count16"`
Count32 spectypes.SqlInt32 `bun:"count32" json:"count32"`
}
// TestMapToStruct_AllSpectypes verifies that MapToStruct can convert all spectypes correctly
func TestMapToStruct_AllSpectypes(t *testing.T) {
testUUID := uuid.New()
testTime := time.Now()
tests := []struct {
name string
dataMap map[string]interface{}
validator func(*testing.T, *TestModel)
}{
{
name: "SqlString from string",
dataMap: map[string]interface{}{
"name": "John Doe",
},
validator: func(t *testing.T, m *TestModel) {
if !m.Name.Valid || m.Name.String() != "John Doe" {
t.Errorf("expected name='John Doe', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
}
},
},
{
name: "SqlInt64 from int64",
dataMap: map[string]interface{}{
"age": int64(42),
},
validator: func(t *testing.T, m *TestModel) {
if !m.Age.Valid || m.Age.Int64() != 42 {
t.Errorf("expected age=42, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
}
},
},
{
name: "SqlInt64 from string",
dataMap: map[string]interface{}{
"age": "99",
},
validator: func(t *testing.T, m *TestModel) {
if !m.Age.Valid || m.Age.Int64() != 99 {
t.Errorf("expected age=99, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
}
},
},
{
name: "SqlFloat64 from float64",
dataMap: map[string]interface{}{
"score": float64(98.5),
},
validator: func(t *testing.T, m *TestModel) {
if !m.Score.Valid || m.Score.Float64() != 98.5 {
t.Errorf("expected score=98.5, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
}
},
},
{
name: "SqlBool from bool",
dataMap: map[string]interface{}{
"active": true,
},
validator: func(t *testing.T, m *TestModel) {
if !m.Active.Valid || !m.Active.Bool() {
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
}
},
},
{
name: "SqlUUID from string",
dataMap: map[string]interface{}{
"uuid": testUUID.String(),
},
validator: func(t *testing.T, m *TestModel) {
if !m.UUID.Valid || m.UUID.UUID() != testUUID {
t.Errorf("expected uuid=%s, got valid=%v, value=%s", testUUID.String(), m.UUID.Valid, m.UUID.UUID().String())
}
},
},
{
name: "SqlTimeStamp from time.Time",
dataMap: map[string]interface{}{
"created_at": testTime,
},
validator: func(t *testing.T, m *TestModel) {
if !m.CreatedAt.Valid {
t.Errorf("expected created_at to be valid")
}
// Check if times are close enough (within a second)
diff := m.CreatedAt.Time().Sub(testTime)
if diff < -time.Second || diff > time.Second {
t.Errorf("time difference too large: %v", diff)
}
},
},
{
name: "SqlTimeStamp from string",
dataMap: map[string]interface{}{
"created_at": "2024-01-15T10:30:00",
},
validator: func(t *testing.T, m *TestModel) {
if !m.CreatedAt.Valid {
t.Errorf("expected created_at to be valid")
}
expected := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
if m.CreatedAt.Time().Year() != expected.Year() ||
m.CreatedAt.Time().Month() != expected.Month() ||
m.CreatedAt.Time().Day() != expected.Day() {
t.Errorf("expected date 2024-01-15, got %v", m.CreatedAt.Time())
}
},
},
{
name: "SqlDate from string",
dataMap: map[string]interface{}{
"birth_date": "2000-05-20",
},
validator: func(t *testing.T, m *TestModel) {
if !m.BirthDate.Valid {
t.Errorf("expected birth_date to be valid")
}
expected := "2000-05-20"
if m.BirthDate.String() != expected {
t.Errorf("expected date=%s, got %s", expected, m.BirthDate.String())
}
},
},
{
name: "SqlTime from string",
dataMap: map[string]interface{}{
"start_time": "14:30:00",
},
validator: func(t *testing.T, m *TestModel) {
if !m.StartTime.Valid {
t.Errorf("expected start_time to be valid")
}
if m.StartTime.String() != "14:30:00" {
t.Errorf("expected time=14:30:00, got %s", m.StartTime.String())
}
},
},
{
name: "SqlJSONB from map",
dataMap: map[string]interface{}{
"metadata": map[string]interface{}{
"key1": "value1",
"key2": 123,
},
},
validator: func(t *testing.T, m *TestModel) {
if len(m.Metadata) == 0 {
t.Errorf("expected metadata to have data")
}
asMap, err := m.Metadata.AsMap()
if err != nil {
t.Fatalf("failed to convert metadata to map: %v", err)
}
if asMap["key1"] != "value1" {
t.Errorf("expected key1=value1, got %v", asMap["key1"])
}
},
},
{
name: "SqlJSONB from string",
dataMap: map[string]interface{}{
"metadata": `{"test":"data"}`,
},
validator: func(t *testing.T, m *TestModel) {
if len(m.Metadata) == 0 {
t.Errorf("expected metadata to have data")
}
asMap, err := m.Metadata.AsMap()
if err != nil {
t.Fatalf("failed to convert metadata to map: %v", err)
}
if asMap["test"] != "data" {
t.Errorf("expected test=data, got %v", asMap["test"])
}
},
},
{
name: "SqlJSONB from []byte",
dataMap: map[string]interface{}{
"metadata": []byte(`{"byte":"array"}`),
},
validator: func(t *testing.T, m *TestModel) {
if len(m.Metadata) == 0 {
t.Errorf("expected metadata to have data")
}
if string(m.Metadata) != `{"byte":"array"}` {
t.Errorf("expected {\"byte\":\"array\"}, got %s", string(m.Metadata))
}
},
},
{
name: "SqlInt16 from int16",
dataMap: map[string]interface{}{
"count16": int16(100),
},
validator: func(t *testing.T, m *TestModel) {
if !m.Count16.Valid || m.Count16.Int64() != 100 {
t.Errorf("expected count16=100, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
}
},
},
{
name: "SqlInt32 from int32",
dataMap: map[string]interface{}{
"count32": int32(5000),
},
validator: func(t *testing.T, m *TestModel) {
if !m.Count32.Valid || m.Count32.Int64() != 5000 {
t.Errorf("expected count32=5000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
}
},
},
{
name: "nil values create invalid nulls",
dataMap: map[string]interface{}{
"name": nil,
"age": nil,
"active": nil,
"created_at": nil,
},
validator: func(t *testing.T, m *TestModel) {
if m.Name.Valid {
t.Error("expected name to be invalid for nil value")
}
if m.Age.Valid {
t.Error("expected age to be invalid for nil value")
}
if m.Active.Valid {
t.Error("expected active to be invalid for nil value")
}
if m.CreatedAt.Valid {
t.Error("expected created_at to be invalid for nil value")
}
},
},
{
name: "all types together",
dataMap: map[string]interface{}{
"id": int64(1),
"name": "Test User",
"age": int64(30),
"score": float64(95.7),
"active": true,
"uuid": testUUID.String(),
"created_at": "2024-01-15T10:30:00",
"birth_date": "1994-06-15",
"start_time": "09:00:00",
"metadata": map[string]interface{}{"role": "admin"},
"count16": int16(50),
"count32": int32(1000),
},
validator: func(t *testing.T, m *TestModel) {
if m.ID != 1 {
t.Errorf("expected id=1, got %d", m.ID)
}
if !m.Name.Valid || m.Name.String() != "Test User" {
t.Errorf("expected name='Test User', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
}
if !m.Age.Valid || m.Age.Int64() != 30 {
t.Errorf("expected age=30, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
}
if !m.Score.Valid || m.Score.Float64() != 95.7 {
t.Errorf("expected score=95.7, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
}
if !m.Active.Valid || !m.Active.Bool() {
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
}
if !m.UUID.Valid {
t.Error("expected uuid to be valid")
}
if !m.CreatedAt.Valid {
t.Error("expected created_at to be valid")
}
if !m.BirthDate.Valid || m.BirthDate.String() != "1994-06-15" {
t.Errorf("expected birth_date=1994-06-15, got valid=%v, value=%s", m.BirthDate.Valid, m.BirthDate.String())
}
if !m.StartTime.Valid || m.StartTime.String() != "09:00:00" {
t.Errorf("expected start_time=09:00:00, got valid=%v, value=%s", m.StartTime.Valid, m.StartTime.String())
}
if len(m.Metadata) == 0 {
t.Error("expected metadata to have data")
}
if !m.Count16.Valid || m.Count16.Int64() != 50 {
t.Errorf("expected count16=50, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
}
if !m.Count32.Valid || m.Count32.Int64() != 1000 {
t.Errorf("expected count32=1000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
model := &TestModel{}
if err := MapToStruct(tt.dataMap, model); err != nil {
t.Fatalf("MapToStruct failed: %v", err)
}
tt.validator(t, model)
})
}
}
// TestMapToStruct_PartialUpdate tests that partial updates preserve unset fields
func TestMapToStruct_PartialUpdate(t *testing.T) {
// Create initial model with some values
initial := &TestModel{
ID: 1,
Name: spectypes.NewSqlString("Original Name"),
Age: spectypes.NewSqlInt64(25),
}
// Update only the age field
partialUpdate := map[string]interface{}{
"age": int64(30),
}
// Apply partial update
if err := MapToStruct(partialUpdate, initial); err != nil {
t.Fatalf("MapToStruct failed: %v", err)
}
// Verify age was updated
if !initial.Age.Valid || initial.Age.Int64() != 30 {
t.Errorf("expected age=30, got valid=%v, value=%d", initial.Age.Valid, initial.Age.Int64())
}
// Verify name was preserved (not overwritten with zero value)
if !initial.Name.Valid || initial.Name.String() != "Original Name" {
t.Errorf("expected name='Original Name' to be preserved, got valid=%v, value=%s", initial.Name.Valid, initial.Name.String())
}
// Verify ID was preserved
if initial.ID != 1 {
t.Errorf("expected id=1 to be preserved, got %d", initial.ID)
}
}

513
pkg/resolvemcp/README.md Normal file
View File

@@ -0,0 +1,513 @@
# resolvemcp
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
## Quick Start
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
"github.com/gorilla/mux"
)
// 1. Create a handler
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
BaseURL: "http://localhost:8080",
})
// 2. Register models
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "orders", &Order{})
// 3. Mount routes
r := mux.NewRouter()
resolvemcp.SetupMuxRoutes(r, handler)
```
---
## Config
```go
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
// If empty, it is detected from each incoming request using the Host header and
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
BaseURL string
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
// Required.
BasePath string
}
```
## Handler Creation
| Function | Description |
|---|---|
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
---
## Registering Models
```go
handler.RegisterModel(schema, entity string, model interface{}) error
```
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
- `entity` — table/entity name (e.g. `"users"`).
- `model` — a pointer to a struct (e.g. `&User{}`).
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
---
## HTTP / SSE Transport
The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework.
`Config.BasePath` is required and used for all route registration.
`Config.BaseURL` is optional — when empty it is detected from each request.
### Gorilla Mux
```go
resolvemcp.SetupMuxRoutes(r, handler)
```
Registers:
| Route | Method | Description |
|---|---|---|
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
| `{BasePath}/*` | any | Full SSE server (convenience prefix) |
### bunrouter
```go
resolvemcp.SetupBunRouterRoutes(router, handler)
```
Registers `GET {BasePath}/sse` and `POST {BasePath}/message` on the provided `*bunrouter.Router`.
### Gin (or any `http.Handler`-compatible framework)
Use `handler.SSEServer()` to get an `http.Handler` and wrap it with the framework's adapter:
```go
sse := handler.SSEServer()
// Gin
engine.Any("/mcp/*path", gin.WrapH(sse))
// net/http
http.Handle("/mcp/", sse)
// Echo
e.Any("/mcp/*", echo.WrapHandler(sse))
```
### Authentication
Add middleware before the MCP routes. The handler itself has no auth layer.
---
## Security
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
### Wiring security hooks
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
securityList := security.NewSecurityList(mySecurityProvider)
resolvemcp.RegisterSecurityHooks(handler, securityList)
```
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
| Hook | Effect |
|---|---|
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
### Per-entity operation rules
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
```go
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
// Read-only entity
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
CanRead: true,
CanCreate: false,
CanUpdate: false,
CanDelete: false,
})
// Public read, authenticated write
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
CanPublicRead: true,
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
To update rules for an already-registered model:
```go
handler.SetModelRules("public", "users", modelregistry.ModelRules{
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
### ModelRules fields
| Field | Default | Description |
|---|---|---|
| `CanPublicRead` | `false` | Allow unauthenticated reads |
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
| `CanRead` | `true` | Allow authenticated reads |
| `CanCreate` | `true` | Allow authenticated creates |
| `CanUpdate` | `true` | Allow authenticated updates |
| `CanDelete` | `true` | Allow authenticated deletes |
| `SecurityDisabled` | `false` | Skip all security checks for this model |
---
## MCP Tools
### Tool Naming
```
{operation}_{schema}_{entity} // e.g. read_public_users
{operation}_{entity} // e.g. read_users (when schema is empty)
```
Operations: `read`, `create`, `update`, `delete`.
### Read Tool — `read_{schema}_{entity}`
Fetch one or many records.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key value. Omit to return multiple records. |
| `limit` | number | Max records per page (recommended: 10100). |
| `offset` | number | Records to skip (offset-based pagination). |
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
| `columns` | array | Column names to include. Omit for all columns. |
| `omit_columns` | array | Column names to exclude. |
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
**Response:**
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 100,
"count": 10,
"limit": 10,
"offset": 0
}
}
```
### Create Tool — `create_{schema}_{entity}`
Insert one or more records.
| Argument | Type | Description |
|---|---|---|
| `data` | object \| array | Single object or array of objects to insert. |
Array input runs inside a single transaction — all succeed or all fail.
**Response:**
```json
{ "success": true, "data": { ... } }
```
### Update Tool — `update_{schema}_{entity}`
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key of the record. Can also be included inside `data`. |
| `data` | object (required) | Fields to update. |
**Response:**
```json
{ "success": true, "data": { ...merged record... } }
```
### Delete Tool — `delete_{schema}_{entity}`
Delete a record by primary key. **Irreversible.**
| Argument | Type | Description |
|---|---|---|
| `id` | string (required) | Primary key of the record to delete. |
**Response:**
```json
{ "success": true, "data": { ...deleted record... } }
```
### Annotation Tool — `resolvespec_annotate`
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
| Argument | Type | Description |
|---|---|---|
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
```json
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "set" }
```
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
```json
{ "tool_name": "read_public_users" }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
```
---
### Resource — `{schema}.{entity}`
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
---
## Filtering
Pass an array of filter objects to the `filters` argument:
```json
[
{ "column": "status", "operator": "=", "value": "active" },
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
]
```
### Supported Operators
| Operator | Aliases | Description |
|---|---|---|
| `=` | `eq` | Equal |
| `!=` | `neq`, `<>` | Not equal |
| `>` | `gt` | Greater than |
| `>=` | `gte` | Greater than or equal |
| `<` | `lt` | Less than |
| `<=` | `lte` | Less than or equal |
| `like` | | SQL LIKE (case-sensitive) |
| `ilike` | | SQL ILIKE (case-insensitive) |
| `in` | | Value in list |
| `is_null` | | Column IS NULL |
| `is_not_null` | | Column IS NOT NULL |
### Logic Operators
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
---
## Sorting
```json
[
{ "column": "created_at", "direction": "desc" },
{ "column": "name", "direction": "asc" }
]
```
---
## Pagination
### Offset-Based
```json
{ "limit": 20, "offset": 40 }
```
### Cursor-Based
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
```json
// Next page: pass the PK of the last record on the current page
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
// Previous page: pass the PK of the first record on the current page
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
```
---
## Preloading Relations
```json
[
{ "relation": "Profile" },
{ "relation": "Orders" }
]
```
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
---
## Hook System
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
### Hook Types
| Constant | Fires |
|---|---|
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
| `BeforeRead` / `AfterRead` | Around read queries |
| `BeforeCreate` / `AfterCreate` | Around insert |
| `BeforeUpdate` / `AfterUpdate` | Around update |
| `BeforeDelete` / `AfterDelete` | Around delete |
### Registering Hooks
```go
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
// Inject a timestamp before insert
if data, ok := ctx.Data.(map[string]interface{}); ok {
data["created_at"] = time.Now()
}
return nil
})
// Register the same hook for multiple events
handler.Hooks().RegisterMultiple(
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
auditHook,
)
```
### HookContext Fields
| Field | Type | Description |
|---|---|---|
| `Context` | `context.Context` | Request context |
| `Handler` | `*Handler` | The resolvemcp handler |
| `Schema` | `string` | Database schema name |
| `Entity` | `string` | Entity/table name |
| `Model` | `interface{}` | Registered model instance |
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
| `ID` | `string` | Primary key from request (read/update/delete) |
| `Data` | `interface{}` | Input data (create/update — modifiable) |
| `Result` | `interface{}` | Output data (set by After hooks) |
| `Error` | `error` | Operation error, if any |
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
| `Tx` | `common.Database` | Database/transaction handle |
| `Abort` | `bool` | Set to `true` to abort the operation |
| `AbortMessage` | `string` | Error message returned when aborting |
| `AbortCode` | `int` | Optional status code for the abort |
### Aborting an Operation
```go
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "deletion is disabled"
return nil
})
```
### Managing Hooks
```go
registry := handler.Hooks()
registry.HasHooks(resolvemcp.BeforeCreate) // bool
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
registry.ClearAll() // remove all hooks
```
---
## Context Helpers
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
```go
schema := resolvemcp.GetSchema(ctx)
entity := resolvemcp.GetEntity(ctx)
tableName := resolvemcp.GetTableName(ctx)
model := resolvemcp.GetModel(ctx)
modelPtr := resolvemcp.GetModelPtr(ctx)
```
You can also set values manually (e.g. in middleware):
```go
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
```
---
## Adding Custom MCP Tools
Access the underlying `*server.MCPServer` to register additional tools:
```go
mcpServer := handler.MCPServer()
mcpServer.AddTool(myTool, myHandler)
```
---
## Table Name Resolution
The handler resolves table names in priority order:
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)

View File

@@ -0,0 +1,107 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
)
const annotationToolName = "resolvespec_annotate"
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
// The tool lets models/entities store and retrieve freeform annotation records
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
func registerAnnotationTool(h *Handler) {
tool := mcp.NewTool(annotationToolName,
mcp.WithDescription(
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
"To set annotations: provide both 'tool_name' and 'annotations'. "+
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
"To get annotations: provide only 'tool_name'. "+
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
"a model/entity name (e.g. 'public.users'), or any other key.",
),
mcp.WithString("tool_name",
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
mcp.Required(),
),
mcp.WithObject("annotations",
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
toolName, ok := args["tool_name"].(string)
if !ok || toolName == "" {
return mcp.NewToolResultError("missing required argument: tool_name"), nil
}
annotations, hasAnnotations := args["annotations"]
if hasAnnotations && annotations != nil {
return executeSetAnnotation(ctx, h, toolName, annotations)
}
return executeGetAnnotation(ctx, h, toolName)
})
}
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
jsonBytes, err := json.Marshal(annotations)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
}
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "set",
})
}
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
var rows []map[string]interface{}
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
}
var annotations interface{}
if len(rows) > 0 {
// The procedure returns a single value; extract the first column of the first row.
for _, v := range rows[0] {
annotations = v
break
}
}
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
switch v := annotations.(type) {
case []byte:
var decoded interface{}
if json.Unmarshal(v, &decoded) == nil {
annotations = decoded
}
case string:
var decoded interface{}
if json.Unmarshal([]byte(v), &decoded) == nil {
annotations = decoded
}
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "get",
"annotations": annotations,
})
}

71
pkg/resolvemcp/context.go Normal file
View File

@@ -0,0 +1,71 @@
package resolvemcp
import "context"
type contextKey string
const (
contextKeySchema contextKey = "schema"
contextKeyEntity contextKey = "entity"
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
)
func WithSchema(ctx context.Context, schema string) context.Context {
return context.WithValue(ctx, contextKeySchema, schema)
}
func GetSchema(ctx context.Context) string {
if v := ctx.Value(contextKeySchema); v != nil {
return v.(string)
}
return ""
}
func WithEntity(ctx context.Context, entity string) context.Context {
return context.WithValue(ctx, contextKeyEntity, entity)
}
func GetEntity(ctx context.Context) string {
if v := ctx.Value(contextKeyEntity); v != nil {
return v.(string)
}
return ""
}
func WithTableName(ctx context.Context, tableName string) context.Context {
return context.WithValue(ctx, contextKeyTableName, tableName)
}
func GetTableName(ctx context.Context) string {
if v := ctx.Value(contextKeyTableName); v != nil {
return v.(string)
}
return ""
}
func WithModel(ctx context.Context, model interface{}) context.Context {
return context.WithValue(ctx, contextKeyModel, model)
}
func GetModel(ctx context.Context) interface{} {
return ctx.Value(contextKeyModel)
}
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
}
func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
return ctx
}

161
pkg/resolvemcp/cursor.go Normal file
View File

@@ -0,0 +1,161 @@
package resolvemcp
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
type cursorDirection int
const (
cursorForward cursorDirection = 1
cursorBackward cursorDirection = -1
)
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
func getCursorFilter(
tableName string,
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
cursorID, direction := getActiveCursor(options)
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
sortItems := options.Sort
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
var whereClauses []string
joinSQL := ""
reverse := direction < 0
for _, s := range sortItems {
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
desc := strings.EqualFold(s.Direction, "desc")
if reverse {
desc = !desc
}
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue
}
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
orSQL := buildCursorPriorityChain(whereClauses)
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
)
return query, nil
}
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
if options.CursorForward != "" {
return options.CursorForward, cursorForward
}
if options.CursorBackward != "" {
return options.CursorBackward, cursorBackward
}
return "", 0
}
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if prefix != "" && prefix != tableName {
return "", "", true, nil
}
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
func buildCursorPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}

736
pkg/resolvemcp/handler.go Normal file
View File

@@ -0,0 +1,736 @@
package resolvemcp
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"github.com/mark3labs/mcp-go/server"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Handler exposes registered database models as MCP tools and resources.
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
mcpServer *server.MCPServer
config Config
name string
version string
}
// NewHandler creates a Handler with the given database, model registry, and config.
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
h := &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
config: cfg,
name: "resolvemcp",
version: "1.0.0",
}
registerAnnotationTool(h)
return h
}
// Hooks returns the hook registry.
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// GetDatabase returns the underlying database.
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
func (h *Handler) MCPServer() *server.MCPServer {
return h.mcpServer
}
// SSEServer returns an http.Handler that serves MCP over SSE.
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
// detected automatically from each incoming request.
func (h *Handler) SSEServer() http.Handler {
if h.config.BaseURL != "" {
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
}
return &dynamicSSEHandler{h: h}
}
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer(
h.mcpServer,
server.WithBaseURL(baseURL),
server.WithStaticBasePath(basePath),
)
}
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
type dynamicSSEHandler struct {
h *Handler
mu sync.Mutex
pool map[string]*server.SSEServer
}
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
baseURL := requestBaseURL(r)
d.mu.Lock()
if d.pool == nil {
d.pool = make(map[string]*server.SSEServer)
}
s, ok := d.pool[baseURL]
if !ok {
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
d.pool[baseURL] = s
}
d.mu.Unlock()
s.ServeHTTP(w, r)
}
// requestBaseURL builds the base URL from an incoming request.
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
func requestBaseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
scheme = proto
}
return scheme + "://" + r.Host
}
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
fullName := buildModelName(schema, entity)
if err := h.registry.RegisterModel(fullName, model); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// RegisterModelWithRules registers a model and sets per-entity operation rules
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
fullName := buildModelName(schema, entity)
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// SetModelRules updates the operation rules for an already-registered model.
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
return reg.SetModelRules(buildModelName(schema, entity), rules)
}
// buildModelName builds the registry key for a model (same format as resolvespec).
func buildModelName(schema, entity string) string {
if schema == "" {
return entity
}
return fmt.Sprintf("%s.%s", schema, entity)
}
// getTableName returns the fully qualified table name for a model.
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
if idx := strings.LastIndex(tableName, "."); idx != -1 {
return tableName[:idx], tableName[idx+1:]
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), tableName
}
return defaultSchema, tableName
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), entity
}
return defaultSchema, entity
}
// executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err)
}
unwrapped, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, nil, fmt.Errorf("invalid model: %w", err)
}
model = unwrapped.Model
modelType := unwrapped.ModelType
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
validator := common.NewColumnValidator(model)
options = validator.FilterRequestOptions(options)
// BeforeHandle hook
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "read",
Options: options,
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, nil, err
}
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
modelPtr := reflect.New(sliceType).Interface()
query := h.db.NewSelect().Model(modelPtr)
tempInstance := reflect.New(modelType).Interface()
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
// Column selection
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
options.Columns = reflection.GetSQLModelColumns(model)
}
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
for _, cu := range options.ComputedColumns {
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
}
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters
query = h.applyFilters(query, options.Filters)
// Custom operators
for _, customOp := range options.CustomOperators {
query = query.Where(customOp.SQL)
}
// Sorting
for _, sort := range options.Sort {
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
// Cursor pagination
if options.CursorForward != "" || options.CursorBackward != "" {
pkName := reflection.GetPrimaryKeyName(model)
modelColumns := reflection.GetModelColumns(model)
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// expandJoins is empty for resolvemcp — no custom SQL join support yet
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
return nil, nil, fmt.Errorf("cursor error: %w", err)
}
if cursorFilter != "" {
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
sanitized = common.EnsureOuterParentheses(sanitized)
if sanitized != "" {
query = query.Where(sanitized)
}
}
}
// Count
total, err := query.Count(ctx)
if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err)
}
// Pagination
if options.Limit != nil && *options.Limit > 0 {
query = query.Limit(*options.Limit)
}
if options.Offset != nil && *options.Offset > 0 {
query = query.Offset(*options.Offset)
}
// BeforeRead hook
hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
return nil, nil, err
}
var data interface{}
if id != "" {
singleResult := reflect.New(modelType).Interface()
pkName := reflection.GetPrimaryKeyName(singleResult)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := query.Scan(ctx, singleResult); err != nil {
if err == sql.ErrNoRows {
return nil, nil, fmt.Errorf("record not found")
}
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = singleResult
} else {
if err := query.Scan(ctx, modelPtr); err != nil {
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = reflect.ValueOf(modelPtr).Elem().Interface()
}
limit := 0
offset := 0
if options.Limit != nil {
limit = *options.Limit
}
if options.Offset != nil {
offset = *options.Offset
}
// Count is the number of records in this page, not the total.
var pageCount int64
if id != "" {
pageCount = 1
} else {
pageCount = int64(reflect.ValueOf(data).Len())
}
metadata := &common.Metadata{
Total: int64(total),
Filtered: int64(total),
Count: pageCount,
Limit: limit,
Offset: offset,
}
// AfterRead hook
hookCtx.Result = data
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
return nil, nil, err
}
return data, metadata, nil
}
// executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "create",
Data: data,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
return nil, err
}
// Use potentially modified data
data = hookCtx.Data
switch v := data.(type) {
case map[string]interface{}:
query := h.db.NewInsert().Table(tableName)
for key, value := range v {
query = query.Value(key, value)
}
if _, err := query.Exec(ctx); err != nil {
return nil, fmt.Errorf("create error: %w", err)
}
hookCtx.Result = v
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return v, nil
case []interface{}:
results := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v {
itemMap, ok := item.(map[string]interface{})
if !ok {
return fmt.Errorf("each item must be an object")
}
q := tx.NewInsert().Table(tableName)
for key, value := range itemMap {
q = q.Value(key, value)
}
if _, err := q.Exec(ctx); err != nil {
return err
}
results = append(results, item)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("batch create error: %w", err)
}
hookCtx.Result = results
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return results, nil
default:
return nil, fmt.Errorf("data must be an object or array of objects")
}
}
// executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
updates, ok := data.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("data must be an object")
}
if id == "" {
if idVal, exists := updates["id"]; exists {
id = fmt.Sprintf("%v", idVal)
}
}
if id == "" {
return nil, fmt.Errorf("update requires an ID")
}
pkName := reflection.GetPrimaryKeyName(model)
var updateResult interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Read existing record
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
existingRecord := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no records found to update")
}
return fmt.Errorf("error fetching existing record: %w", err)
}
// Convert to map
existingMap := make(map[string]interface{})
jsonData, err := json.Marshal(existingRecord)
if err != nil {
return fmt.Errorf("error marshaling existing record: %w", err)
}
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
return fmt.Errorf("error unmarshaling existing record: %w", err)
}
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "update",
ID: id,
Data: updates,
Tx: tx,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return err
}
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
updates = modifiedData
}
// Merge non-nil, non-empty values
for key, newValue := range updates {
if newValue == nil {
continue
}
if strVal, ok := newValue.(string); ok && strVal == "" {
continue
}
existingMap[key] = newValue
}
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
res, err := q.Exec(ctx)
if err != nil {
return fmt.Errorf("error updating record: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("no records found to update")
}
updateResult = existingMap
hookCtx.Result = updateResult
return h.hooks.Execute(AfterUpdate, hookCtx)
})
if err != nil {
return nil, err
}
return updateResult, nil
}
// executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
if id == "" {
return nil, fmt.Errorf("delete requires an ID")
}
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
pkName := reflection.GetPrimaryKeyName(model)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "delete",
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
return nil, err
}
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
var recordToDelete interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
record := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(record).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("record not found")
}
return fmt.Errorf("error fetching record: %w", err)
}
res, err := tx.NewDelete().Table(tableName).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete error: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("record not found or already deleted")
}
recordToDelete = record
hookCtx.Tx = tx
hookCtx.Result = record
return h.hooks.Execute(AfterDelete, hookCtx)
})
if err != nil {
return nil, err
}
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
return recordToDelete, nil
}
// applyFilters applies all filters with OR grouping logic.
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
}
func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
switch filter.Operator {
case "eq", "=":
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
case "neq", "!=", "<>":
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
case "gt", ">":
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
case "gte", ">=":
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
case "lt", "<":
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
case "lte", "<=":
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
case "ilike":
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
case "in":
condition, args := common.BuildInCondition(filter.Column, filter.Value)
return condition, args
case "is_null":
return fmt.Sprintf("%s IS NULL", filter.Column), nil
case "is_not_null":
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
}
return "", nil
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for i := range preloads {
preload := &preloads[i]
if preload.Relation == "" {
continue
}
query = query.PreloadRelation(preload.Relation)
}
return query, nil
}

113
pkg/resolvemcp/hooks.go Normal file
View File

@@ -0,0 +1,113 @@
package resolvemcp
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
BeforeHandle HookType = "before_handle"
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler
Schema string
Entity string
Model interface{}
Options common.RequestOptions
Operation string
ID string
Data interface{}
Result interface{}
Error error
Query common.SelectQuery
Abort bool
AbortMessage string
AbortCode int
Tx common.Database
}
// HookFunc is the signature for hook functions
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
if ctx.Abort {
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
}
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
}
func (r *HookRegistry) HasHooks(hookType HookType) bool {
hooks, exists := r.hooks[hookType]
return exists && len(hooks) > 0
}

View File

@@ -0,0 +1,100 @@
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
// and resources over HTTP/SSE transport.
//
// It mirrors the resolvespec package patterns:
// - Same model registration API
// - Same filter, sort, cursor pagination, preload options
// - Same lifecycle hook system
//
// Usage:
//
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
// handler.RegisterModel("public", "users", &User{})
//
// r := mux.NewRouter()
// resolvemcp.SetupMuxRoutes(r, handler)
package resolvemcp
import (
"net/http"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
bunrouter "github.com/uptrace/bunrouter"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// Config holds configuration for the resolvemcp handler.
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
BaseURL string
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
// If empty, the path is detected from each incoming request automatically.
BasePath string
}
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
}
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
//
// To protect these routes with authentication, wrap the mux router or apply middleware
// before calling SetupMuxRoutes.
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
// Convenience: also expose the full SSE server at basePath for clients that
// use ServeHTTP directly (e.g. net/http default mux).
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
// using the base path from Config.BasePath.
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint
// - POST {basePath}/message — JSON-RPC message endpoint
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
}
// NewSSEServer returns an http.Handler that serves MCP over SSE.
// If Config.BasePath is set it is used directly; otherwise the base path is
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
//
// h := resolvemcp.NewSSEServer(handler)
// http.Handle("/api/mcp/", h)
func NewSSEServer(handler *Handler) http.Handler {
return handler.SSEServer()
}

View File

@@ -0,0 +1,115 @@
package resolvemcp
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks wires the security package's access-control layer into the
// resolvemcp handler. Call it once after creating the handler, before registering models.
//
// The following controls are applied:
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
// stored via RegisterModelWithRules / SetModelRules.
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
// - Column-level security: sensitive columns masked/hidden in read results.
// - Audit logging after each read.
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// BeforeHandle: enforce model-level operation rules (auth check).
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
})
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (2nd): audit log.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.LogDataAccess(newSecurityContext(hookCtx))
})
// BeforeUpdate: enforce CanUpdate rule.
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
})
// BeforeDelete: enforce CanDelete rule.
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
})
logger.Info("Security hooks registered for resolvemcp handler")
}
// --------------------------------------------------------------------------
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
// --------------------------------------------------------------------------
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
if q, ok := query.(common.SelectQuery); ok {
s.ctx.Query = q
}
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

692
pkg/resolvemcp/tools.go Normal file
View File

@@ -0,0 +1,692 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// toolName builds the MCP tool name for a given operation and model.
func toolName(operation, schema, entity string) string {
if schema == "" {
return fmt.Sprintf("%s_%s", operation, entity)
}
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
}
// registerModelTools registers the four CRUD tools and resource for a model.
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
info := buildModelInfo(schema, entity, model)
registerReadTool(h, schema, entity, info)
registerCreateTool(h, schema, entity, info)
registerUpdateTool(h, schema, entity, info)
registerDeleteTool(h, schema, entity, info)
registerModelResource(h, schema, entity, info)
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
}
// --------------------------------------------------------------------------
// Model introspection
// --------------------------------------------------------------------------
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
type modelInfo struct {
fullName string // e.g. "public.users"
pkName string // e.g. "id"
columns []columnInfo
relationNames []string
schemaDoc string // formatted multi-line schema listing
}
type columnInfo struct {
jsonName string
sqlName string
goType string
sqlType string
isPrimary bool
isUnique bool
isFK bool
nullable bool
}
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
info := modelInfo{
fullName: buildModelName(schema, entity),
pkName: reflection.GetPrimaryKeyName(model),
}
// Unwrap to base struct type
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return info
}
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
for _, d := range details {
// Derive the JSON name from the struct field
jsonName := fieldJSONName(modelType, d.Name)
if jsonName == "" || jsonName == "-" {
continue
}
// Skip relation fields (slice or user-defined struct that isn't time.Time).
fieldType, found := modelType.FieldByName(d.Name)
if found {
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
if ft.Kind() == reflect.Slice || isUserStruct {
info.relationNames = append(info.relationNames, jsonName)
continue
}
}
sqlName := d.SQLName
if sqlName == "" {
sqlName = jsonName
}
// Derive Go type name, unwrapping pointer if needed.
goType := d.DataType
if goType == "" && found {
ft := fieldType.Type
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
goType = ft.Name()
}
// isPrimary: use both the GORM-tag detection and a name comparison against
// the known primary key (handles camelCase "primaryKey" tags correctly).
isPrimary := d.SQLKey == "primary_key" ||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
ci := columnInfo{
jsonName: jsonName,
sqlName: sqlName,
goType: goType,
sqlType: d.SQLDataType,
isPrimary: isPrimary,
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
isFK: d.SQLKey == "foreign_key",
nullable: d.Nullable,
}
info.columns = append(info.columns, ci)
}
info.schemaDoc = buildSchemaDoc(info)
return info
}
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
func fieldJSONName(modelType reflect.Type, fieldName string) string {
field, ok := modelType.FieldByName(fieldName)
if !ok {
return fieldName
}
tag := field.Tag.Get("json")
if tag == "" {
return fieldName
}
parts := strings.SplitN(tag, ",", 2)
if parts[0] == "" {
return fieldName
}
return parts[0]
}
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
func buildSchemaDoc(info modelInfo) string {
if len(info.columns) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("Columns:\n")
for _, c := range info.columns {
line := fmt.Sprintf(" • %s", c.jsonName)
typeDesc := c.goType
if c.sqlType != "" {
typeDesc = c.sqlType
}
if typeDesc != "" {
line += fmt.Sprintf(" (%s)", typeDesc)
}
var flags []string
if c.isPrimary {
flags = append(flags, "primary key")
}
if c.isUnique {
flags = append(flags, "unique")
}
if c.isFK {
flags = append(flags, "foreign key")
}
if !c.nullable && !c.isPrimary {
flags = append(flags, "not null")
} else if c.nullable {
flags = append(flags, "nullable")
}
if len(flags) > 0 {
line += " — " + strings.Join(flags, ", ")
}
sb.WriteString(line + "\n")
}
if len(info.relationNames) > 0 {
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
}
return sb.String()
}
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
func columnNameList(cols []columnInfo) string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.jsonName
}
return strings.Join(names, ", ")
}
// writableColumnNames returns JSON names for all non-primary-key columns.
func writableColumnNames(cols []columnInfo) []string {
var names []string
for _, c := range cols {
if !c.isPrimary {
names = append(names, c.jsonName)
}
}
return names
}
// --------------------------------------------------------------------------
// Read tool
// --------------------------------------------------------------------------
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("read", schema, entity)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
}
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
descParts = append(descParts,
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
)
if len(info.relationNames) > 0 {
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
}
description := strings.Join(descParts, "\n\n")
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
if len(info.columns) > 0 {
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
if len(info.columns) > 0 {
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
),
mcp.WithNumber("limit",
mcp.Description("Maximum number of records to return per page. Recommended: 10100."),
),
mcp.WithNumber("offset",
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
),
mcp.WithString("cursor_forward",
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithString("cursor_backward",
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithArray("columns",
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("omit_columns",
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("filters",
mcp.Description(filterDesc),
),
mcp.WithArray("sort",
mcp.Description(sortDesc),
),
mcp.WithArray("preloads",
mcp.Description(buildPreloadDesc(info)),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
options := parseRequestOptions(args)
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": data,
"metadata": metadata,
})
})
}
func buildPreloadDesc(info modelInfo) string {
if len(info.relationNames) == 0 {
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
}
return fmt.Sprintf(
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
strings.Join(info.relationNames, ", "),
)
}
// --------------------------------------------------------------------------
// Create tool
// --------------------------------------------------------------------------
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("create", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
}
descParts = append(descParts,
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
dataDesc := "Record fields to create."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
}
dataDesc += " Pass a single object or an array of objects."
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
result, err := h.executeCreate(ctx, schema, entity, data)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Update tool
// --------------------------------------------------------------------------
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("update", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
}
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
}
descParts = append(descParts,
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(idDesc),
),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
dataMap, ok := data.(map[string]interface{})
if !ok {
return mcp.NewToolResultError("data must be an object"), nil
}
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Delete tool
// --------------------------------------------------------------------------
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("delete", schema, entity)
descParts := []string{
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
}
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
description := strings.Join(descParts, " ")
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
result, err := h.executeDelete(ctx, schema, entity, id)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Resource registration
// --------------------------------------------------------------------------
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
resourceURI := info.fullName
var resourceDesc strings.Builder
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
if info.pkName != "" {
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
}
if info.schemaDoc != "" {
resourceDesc.WriteString("\n\n")
resourceDesc.WriteString(info.schemaDoc)
}
resource := mcp.NewResource(
resourceURI,
entity,
mcp.WithResourceDescription(resourceDesc.String()),
mcp.WithMIMEType("application/json"),
)
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
limit := 100
options := common.RequestOptions{Limit: &limit}
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
if err != nil {
return nil, err
}
payload := map[string]interface{}{
"data": data,
"metadata": metadata,
}
jsonBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling resource: %w", err)
}
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "application/json",
Text: string(jsonBytes),
},
}, nil
})
}
// --------------------------------------------------------------------------
// Argument parsing helpers
// --------------------------------------------------------------------------
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options := common.RequestOptions{}
if v, ok := args["limit"]; ok {
switch n := v.(type) {
case float64:
limit := int(n)
options.Limit = &limit
case int:
options.Limit = &n
}
}
if v, ok := args["offset"]; ok {
switch n := v.(type) {
case float64:
offset := int(n)
options.Offset = &offset
case int:
options.Offset = &n
}
}
if v, ok := args["cursor_forward"].(string); ok {
options.CursorForward = v
}
if v, ok := args["cursor_backward"].(string); ok {
options.CursorBackward = v
}
options.Columns = parseStringArray(args["columns"])
options.OmitColumns = parseStringArray(args["omit_columns"])
options.Filters = parseFilters(args["filters"])
options.Sort = parseSortOptions(args["sort"])
options.Preload = parsePreloadOptions(args["preloads"])
return options
}
func parseStringArray(raw interface{}) []string {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]string, 0, len(items))
for _, item := range items {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}
func parseFilters(raw interface{}) []common.FilterOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.FilterOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var f common.FilterOption
if err := json.Unmarshal(b, &f); err != nil {
continue
}
if f.Column == "" || f.Operator == "" {
continue
}
if strings.EqualFold(f.LogicOperator, "or") {
f.LogicOperator = "OR"
} else {
f.LogicOperator = "AND"
}
result = append(result, f)
}
return result
}
func parseSortOptions(raw interface{}) []common.SortOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.SortOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var s common.SortOption
if err := json.Unmarshal(b, &s); err != nil {
continue
}
if s.Column == "" {
continue
}
result = append(result, s)
}
return result
}
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.PreloadOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var p common.PreloadOption
if err := json.Unmarshal(b, &p); err != nil {
continue
}
if p.Relation == "" {
continue
}
result = append(result, p)
}
return result
}
// marshalResult marshals a value to JSON and returns it as an MCP text result.
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
b, err := json.Marshal(v)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
}
return mcp.NewToolResultText(string(b)), nil
}

572
pkg/resolvespec/EXAMPLES.md Normal file
View File

@@ -0,0 +1,572 @@
# ResolveSpec Query Features Examples
This document provides examples of using the advanced query features in ResolveSpec, including OR logic filters, Custom Operators, and FetchRowNumber.
## OR Logic in Filters (SearchOr)
### Basic OR Filter Example
Find all users with status "active" OR "pending":
```json
POST /users
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
}
]
}
}
```
### Combined AND/OR Filters
Find users with (status="active" OR status="pending") AND age >= 18:
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
},
{
"column": "age",
"operator": "gte",
"value": 18
}
]
}
}
```
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
**Important Notes:**
- By default, filters use AND logic
- Consecutive filters with `"logic_operator": "OR"` are automatically grouped with parentheses
- This grouping ensures OR conditions don't interfere with AND conditions
- You don't need to specify `"logic_operator": "AND"` as it's the default
### Multiple OR Groups
You can have multiple separate OR groups:
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
},
{
"column": "priority",
"operator": "eq",
"value": "high"
},
{
"column": "priority",
"operator": "eq",
"value": "urgent",
"logic_operator": "OR"
}
]
}
}
```
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND (priority = 'high' OR priority = 'urgent')`
## Custom Operators
### Simple Custom SQL Condition
Filter by email domain using custom SQL:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "company_emails",
"sql": "email LIKE '%@company.com'"
}
]
}
}
```
### Multiple Custom Operators
Combine multiple custom SQL conditions:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "recent_active",
"sql": "last_login > NOW() - INTERVAL '30 days'"
},
{
"name": "high_score",
"sql": "score > 1000"
}
]
}
}
```
### Complex Custom Operator
Use complex SQL expressions:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "priority_users",
"sql": "(subscription = 'premium' AND points > 500) OR (subscription = 'enterprise')"
}
]
}
}
```
### Combining Custom Operators with Regular Filters
Mix custom operators with standard filters:
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "country",
"operator": "eq",
"value": "USA"
}
],
"customOperators": [
{
"name": "active_last_month",
"sql": "last_activity > NOW() - INTERVAL '1 month'"
}
]
}
}
```
## Row Numbers
### Two Ways to Get Row Numbers
There are two different features for row numbers:
1. **`fetch_row_number`** - Get the position of ONE specific record in a sorted/filtered set
2. **`RowNumber` field in models** - Automatically number all records in the response
### 1. FetchRowNumber - Get Position of Specific Record
Get the rank/position of a specific user in a leaderboard. **Important:** When `fetch_row_number` is specified, the response contains **ONLY that specific record**, not all records.
```json
{
"operation": "read",
"options": {
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "12345"
}
}
```
**Response - Contains ONLY the specified user:**
```json
{
"success": true,
"data": {
"id": 12345,
"name": "Alice Smith",
"score": 9850,
"level": 42
},
"metadata": {
"total": 10000,
"count": 1,
"filtered": 10000,
"row_number": 42
}
}
```
**Result:** User "12345" is ranked #42 out of 10,000 users. The response includes only Alice's data, not the other 9,999 users.
### Row Number with Filters
Find position within a filtered subset (e.g., "What's my rank in my country?"):
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "country",
"operator": "eq",
"value": "USA"
},
{
"column": "status",
"operator": "eq",
"value": "active"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "12345"
}
}
```
**Response:**
```json
{
"success": true,
"data": {
"id": 12345,
"name": "Bob Johnson",
"country": "USA",
"score": 7200,
"status": "active"
},
"metadata": {
"total": 2500,
"count": 1,
"filtered": 2500,
"row_number": 156
}
}
```
**Result:** Bob is ranked #156 out of 2,500 active USA users. Only Bob's record is returned.
### 2. RowNumber Field - Auto-Number All Records
If your model has a `RowNumber int64` field, restheadspec will automatically populate it for paginated results.
**Model Definition:**
```go
type Player struct {
ID int64 `json:"id"`
Name string `json:"name"`
Score int64 `json:"score"`
RowNumber int64 `json:"row_number"` // Will be auto-populated
}
```
**Request (with pagination):**
```json
{
"operation": "read",
"options": {
"sort": [{"column": "score", "direction": "desc"}],
"limit": 10,
"offset": 20
}
}
```
**Response - RowNumber automatically set:**
```json
{
"success": true,
"data": [
{
"id": 456,
"name": "Player21",
"score": 8900,
"row_number": 21
},
{
"id": 789,
"name": "Player22",
"score": 8850,
"row_number": 22
},
{
"id": 123,
"name": "Player23",
"score": 8800,
"row_number": 23
}
// ... records 24-30 ...
]
}
```
**How It Works:**
- `row_number = offset + index + 1` (1-based)
- With offset=20, first record gets row_number=21
- With offset=20, second record gets row_number=22
- Perfect for displaying "Rank" in paginated tables
**Use Case:** Displaying leaderboards with rank numbers:
```
Rank | Player | Score
-----|-----------|-------
21 | Player21 | 8900
22 | Player22 | 8850
23 | Player23 | 8800
```
**Note:** This feature is available in all three packages: resolvespec, restheadspec, and websocketspec.
### When to Use Each Feature
| Feature | Use Case | Returns | Performance |
|---------|----------|---------|-------------|
| `fetch_row_number` | "What's my rank?" | 1 record with position | Fast - 1 record |
| `RowNumber` field | "Show top 10 with ranks" | Many records numbered | Fast - simple math |
**Combined Example - Full Leaderboard UI:**
```javascript
// Request 1: Get current user's rank
const userRank = await api.read({
fetch_row_number: currentUserId,
sort: [{column: "score", direction: "desc"}]
});
// Returns: {id: 123, name: "You", score: 7500, row_number: 156}
// Request 2: Get top 10 with rank numbers
const top10 = await api.read({
sort: [{column: "score", direction: "desc"}],
limit: 10,
offset: 0
});
// Returns: [{row_number: 1, ...}, {row_number: 2, ...}, ...]
// Display:
// "Your Rank: #156"
// "Top Players:"
// "#1 - Alice - 9999"
// "#2 - Bob - 9876"
// ...
```
## Complete Example: Advanced Query
Combine all features for a complex query:
```json
{
"operation": "read",
"options": {
"columns": ["id", "name", "email", "score", "status"],
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "trial",
"logic_operator": "OR"
},
{
"column": "score",
"operator": "gte",
"value": 100
}
],
"customOperators": [
{
"name": "recent_activity",
"sql": "last_login > NOW() - INTERVAL '7 days'"
},
{
"name": "verified_email",
"sql": "email_verified = true"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
},
{
"column": "created_at",
"direction": "asc"
}
],
"fetch_row_number": "12345",
"limit": 50,
"offset": 0
}
}
```
This query:
- Selects specific columns
- Filters for users with status "active" OR "trial"
- AND score >= 100
- Applies custom SQL conditions for recent activity and verified emails
- Sorts by score (descending) then creation date (ascending)
- Returns the row number of user "12345" in this filtered/sorted set
- Returns 50 records starting from the first one
## Use Cases
### 1. Leaderboards - Get Current User's Rank
Get the current user's position and data (returns only their record):
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "game_id",
"operator": "eq",
"value": "game123"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "current_user_id"
}
}
```
**Tip:** For full leaderboards, make two requests:
1. One with `fetch_row_number` to get user's rank
2. One with `limit` and `offset` to get top players list
### 2. Multi-Status Search
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "order_status",
"operator": "eq",
"value": "pending"
},
{
"column": "order_status",
"operator": "eq",
"value": "processing",
"logic_operator": "OR"
},
{
"column": "order_status",
"operator": "eq",
"value": "shipped",
"logic_operator": "OR"
}
]
}
}
```
### 3. Advanced Date Filtering
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "this_month",
"sql": "created_at >= DATE_TRUNC('month', CURRENT_DATE)"
},
{
"name": "business_hours",
"sql": "EXTRACT(HOUR FROM created_at) BETWEEN 9 AND 17"
}
]
}
}
```
## Security Considerations
**Warning:** Custom operators allow raw SQL, which can be a security risk if not properly handled:
1. **Never** directly interpolate user input into custom operator SQL
2. Always validate and sanitize custom operator SQL on the backend
3. Consider using a whitelist of allowed custom operators
4. Use prepared statements or parameterized queries when possible
5. Implement proper authorization checks before executing queries
Example of safe custom operator handling in Go:
```go
// Whitelist of allowed custom operators
allowedOperators := map[string]string{
"recent_week": "created_at > NOW() - INTERVAL '7 days'",
"active_users": "status = 'active' AND last_login > NOW() - INTERVAL '30 days'",
"premium_only": "subscription_level = 'premium'",
}
// Validate custom operators from request
for _, op := range req.Options.CustomOperators {
if sql, ok := allowedOperators[op.Name]; ok {
op.SQL = sql // Use whitelisted SQL
} else {
return errors.New("custom operator not allowed: " + op.Name)
}
}
```

View File

@@ -214,6 +214,146 @@ Content-Type: application/json
```json
{
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "status",
"operator": "eq",
"value": "pending",
"logic_operator": "OR"
},
{
"column": "age",
"operator": "gte",
"value": 18
}
]
}
```
Produces: `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
This grouping ensures OR conditions don't interfere with other AND conditions in the query.
### Custom Operators
Add custom SQL conditions when needed:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"name": "email_domain_filter",
"sql": "LOWER(email) LIKE '%@example.com'"
},
{
"name": "recent_records",
"sql": "created_at > NOW() - INTERVAL '7 days'"
}
]
}
}
```
Custom operators are applied as additional WHERE conditions to your query.
### Fetch Row Number
Get the row number (position) of a specific record in the filtered and sorted result set. **When `fetch_row_number` is specified, only that specific record is returned** (not all records).
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
}
],
"sort": [
{
"column": "score",
"direction": "desc"
}
],
"fetch_row_number": "12345"
}
}
```
**Response - Returns ONLY the specified record with its position:**
```json
{
"success": true,
"data": {
"id": 12345,
"name": "John Doe",
"score": 850,
"status": "active"
},
"metadata": {
"total": 1000,
"count": 1,
"filtered": 1000,
"row_number": 42
}
}
```
**Use Case:** Perfect for "Show me this user and their ranking" - you get just that one user with their position in the leaderboard.
**Note:** This is different from the `RowNumber` field feature, which automatically numbers all records in a paginated response based on offset. That feature uses simple math (`offset + index + 1`), while `fetch_row_number` uses SQL window functions to calculate the actual position in a sorted/filtered set. To use the `RowNumber` field feature, simply add a `RowNumber int64` field to your model - it will be automatically populated with the row position based on pagination.
## Preloading
Load related entities with custom configuration:
```json
{
"operation": "read",
"options": {
"columns": ["id", "name", "email"],
"preload": [
{
"relation": "posts",
"columns": ["id", "title", "created_at"],
"filters": [
{
"column": "status",
"operator": "eq",
"value": "published"
}
],
"sort": [
{
"column": "created_at",
"direction": "desc"
}
],
"limit": 5
},
{
"relation": "profile",
"columns": ["bio", "website"]
}
]
}
}
```
## Cursor Pagination
Efficient pagination for large datasets:
### First Request (No Cursor)
```json
@@ -427,7 +567,7 @@ Define virtual columns using SQL expressions:
// Check permissions
if !userHasPermission(ctx.Context, ctx.Entity) {
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
return nil
}
// Modify query options
if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
@@ -435,17 +575,24 @@ Add custom SQL conditions when needed:
}
return nil
users[i].Email = maskEmail(users[i].Email)
}
})
// Register an after-read hook (e.g., for data transformation)
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
})
// Transform or filter results
if users, ok := ctx.Result.([]User); ok {
for i := range users {
users[i].Email = maskEmail(users[i].Email)
}
}
return nil
})
// Register a before-create hook (e.g., for validation)
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
// Validate data
if user, ok := ctx.Data.(*User); ok {
if user.Email == "" {
return fmt.Errorf("email is required")
}
// Add timestamps
@@ -497,6 +644,7 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
CreatedAt time.Time `json:"created_at"`
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
}
// Schema.Table format
handler.registry.RegisterModel("core.users", &User{})
handler.registry.RegisterModel("core.posts", &Post{})
@@ -507,11 +655,13 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/gorilla/mux"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

View File

@@ -24,6 +24,7 @@ const (
// - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - options: the request options containing sort and cursor information
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
//
// Returns SQL snippet to embed in WHERE clause.
func GetCursorFilter(
@@ -31,8 +32,10 @@ func GetCursorFilter(
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
// Remove schema prefix if present
// Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
@@ -57,18 +60,19 @@ func GetCursorFilter(
// 3. Prepare
// --------------------------------------------------------------------- //
var whereClauses []string
joinSQL := ""
reverse := direction < 0
// --------------------------------------------------------------------- //
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
// Parse: "created_at", "user.name", etc.
// Parse: "created_at", "user.name", "fn.sortorder", etc.
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
@@ -81,7 +85,7 @@ func GetCursorFilter(
}
// Resolve column
cursorCol, targetCol, err := resolveColumn(
cursorCol, targetCol, isJoin, err := resolveColumn(
field, prefix, tableName, modelColumns,
)
if err != nil {
@@ -89,6 +93,22 @@ func GetCursorFilter(
continue
}
// Handle joins
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
// Build inequality
op := "<"
if desc {
@@ -112,10 +132,12 @@ func GetCursorFilter(
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
@@ -136,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
return "", 0
}
// Helper: resolve column (main table only for now)
// Helper: resolve column (main table or join)
func resolveColumn(
field, prefix, tableName string,
modelColumns []string,
) (cursorCol, targetCol string, err error) {
) (cursorCol, targetCol string, isJoin bool, err error) {
// JSON field
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Main table column
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
// No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Joined column (not supported in resolvespec yet)
// Joined column
if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
return "", "", true, nil
}
return "", "", fmt.Errorf("invalid column: %s", field)
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
// Helper: rewrite JOIN clause for cursor subquery
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
// ------------------------------------------------------------------------- //

View File

@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil {
t.Error("Expected error when no cursor is provided")
}
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil {
t.Error("Expected error when no sort columns are defined")
}
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "priority", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -170,19 +170,50 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "name", "email"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
// Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "public.users") {
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)
}
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
options := common.RequestOptions{
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
CursorForward: "8975",
}
tableName := "core.account"
pkName := "rid_account"
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestGetActiveCursor(t *testing.T) {
tests := []struct {
name string
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
wantErr: false,
},
{
name: "Joined column (not supported)",
name: "Joined column (isJoin=true, no error)",
field: "name",
prefix: "user",
tableName: "posts",
modelColumns: []string{"id", "title"},
wantErr: true,
wantErr: false,
// cursorCol and targetCol are empty when isJoin=true; handled by caller
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
if tt.wantErr {
if err == nil {
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
// For join columns, cursor/target are empty and isJoin=true
if isJoin {
if cursor != "" || target != "" {
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
}
return
}
if cursor != tt.wantCursor {
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
}
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}

View File

@@ -0,0 +1,143 @@
package resolvespec
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
h := &Handler{}
tests := []struct {
name string
filter common.FilterOption
expectedCondition string
expectedArgsCount int
}{
{
name: "Equal operator",
filter: common.FilterOption{
Column: "status",
Operator: "eq",
Value: "active",
},
expectedCondition: "status = ?",
expectedArgsCount: 1,
},
{
name: "Greater than operator",
filter: common.FilterOption{
Column: "age",
Operator: "gt",
Value: 18,
},
expectedCondition: "age > ?",
expectedArgsCount: 1,
},
{
name: "IN operator",
filter: common.FilterOption{
Column: "status",
Operator: "in",
Value: []string{"active", "pending"},
},
expectedCondition: "status IN (?,?)",
expectedArgsCount: 2,
},
{
name: "LIKE operator",
filter: common.FilterOption{
Column: "email",
Operator: "like",
Value: "%@example.com",
},
expectedCondition: "email LIKE ?",
expectedArgsCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
condition, args := h.buildFilterCondition(tt.filter)
if condition != tt.expectedCondition {
t.Errorf("Expected condition '%s', got '%s'", tt.expectedCondition, condition)
}
if len(args) != tt.expectedArgsCount {
t.Errorf("Expected %d args, got %d", tt.expectedArgsCount, len(args))
}
// Note: Skip value comparison for slices as they can't be compared with ==
// The important part is that args are populated correctly
})
}
}
// TestORGrouping tests that consecutive OR filters are properly grouped
func TestORGrouping(t *testing.T) {
// This is a conceptual test - in practice we'd need a mock SelectQuery
// to verify the actual SQL grouping behavior
t.Run("Consecutive OR filters should be grouped", func(t *testing.T) {
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
{Column: "status", Operator: "eq", Value: "trial", LogicOperator: "OR"},
{Column: "age", Operator: "gte", Value: 18},
}
// Expected behavior: (status='active' OR status='pending' OR status='trial') AND age>=18
// The first three filters should be grouped together
// The fourth filter should be separate with AND
// Count OR groups
orGroupCount := 0
inORGroup := false
for i := 1; i < len(filters); i++ {
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
orGroupCount++
inORGroup = true
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
inORGroup = false
}
}
// We should have detected one OR group
if orGroupCount != 1 {
t.Errorf("Expected 1 OR group, detected %d", orGroupCount)
}
})
t.Run("Multiple OR groups should be handled correctly", func(t *testing.T) {
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
{Column: "priority", Operator: "eq", Value: "high"},
{Column: "priority", Operator: "eq", Value: "urgent", LogicOperator: "OR"},
}
// Expected: (status='active' OR status='pending') AND (priority='high' OR priority='urgent')
// Should have two OR groups
orGroupCount := 0
inORGroup := false
for i := 1; i < len(filters); i++ {
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
orGroupCount++
inORGroup = true
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
inORGroup = false
}
}
// We should have detected two OR groups
if orGroupCount != 2 {
t.Errorf("Expected 2 OR groups, detected %d", orGroupCount)
}
})
}

View File

@@ -138,6 +138,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
validator := common.NewColumnValidator(model)
req.Options = validator.FilterRequestOptions(req.Options)
// Execute BeforeHandle hook - auth check fires here, after model resolution
beforeCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Writer: w,
Request: r,
Operation: req.Operation,
}
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
code := http.StatusUnauthorized
if beforeCtx.AbortCode != 0 {
code = beforeCtx.AbortCode
}
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
return
}
switch req.Operation {
case "read":
h.handleRead(ctx, w, id, req.Options)
@@ -280,10 +300,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
}
// Apply filters
for _, filter := range options.Filters {
logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
query = h.applyFilter(query, filter)
// Apply filters with proper grouping for OR logic
query = h.applyFilters(query, options.Filters)
// Apply custom operators
for _, customOp := range options.CustomOperators {
logger.Debug("Applying custom operator: %s - %s", customOp.Name, customOp.SQL)
query = query.Where(customOp.SQL)
}
// Apply sorting
@@ -306,8 +329,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation
modelColumns := reflection.GetModelColumns(model)
// Get cursor filter SQL
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
@@ -318,6 +346,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
// Ensure outer parentheses to prevent OR logic from escaping
sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor)
if sanitizedCursor != "" {
query = query.Where(sanitizedCursor)
}
@@ -379,24 +409,105 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
}
// Apply pagination
if options.Limit != nil && *options.Limit > 0 {
logger.Debug("Applying limit: %d", *options.Limit)
query = query.Limit(*options.Limit)
// Handle FetchRowNumber if requested
var rowNumber *int64
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
logger.Debug("Fetching row number for ID: %s", *options.FetchRowNumber)
pkName := reflection.GetPrimaryKeyName(model)
// Build ROW_NUMBER window function SQL
rowNumberSQL := "ROW_NUMBER() OVER ("
if len(options.Sort) > 0 {
rowNumberSQL += "ORDER BY "
for i, sort := range options.Sort {
if i > 0 {
rowNumberSQL += ", "
}
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
rowNumberSQL += fmt.Sprintf("%s %s", sort.Column, direction)
}
}
rowNumberSQL += ")"
// Create a query to fetch the row number using a subquery approach
// We'll select the PK and row_number, then filter by the target ID
type RowNumResult struct {
RowNum int64 `bun:"row_num"`
}
rowNumQuery := h.db.NewSelect().Table(tableName).
ColumnExpr(fmt.Sprintf("%s AS row_num", rowNumberSQL)).
Column(pkName)
// Apply the same filters as the main query
for _, filter := range options.Filters {
rowNumQuery = h.applyFilter(rowNumQuery, filter)
}
// Apply custom operators
for _, customOp := range options.CustomOperators {
rowNumQuery = rowNumQuery.Where(customOp.SQL)
}
// Filter for the specific ID we want the row number for
rowNumQuery = rowNumQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *options.FetchRowNumber)
// Execute query to get row number
var result RowNumResult
if err := rowNumQuery.Scan(ctx, &result); err != nil {
if err == sql.ErrNoRows {
// Build filter description for error message
filterInfo := fmt.Sprintf("filters: %d", len(options.Filters))
if len(options.CustomOperators) > 0 {
customOps := make([]string, 0, len(options.CustomOperators))
for _, op := range options.CustomOperators {
customOps = append(customOps, op.SQL)
}
filterInfo += fmt.Sprintf(", custom operators: [%s]", strings.Join(customOps, "; "))
}
logger.Warn("No row found for primary key %s=%s with %s", pkName, *options.FetchRowNumber, filterInfo)
} else {
logger.Warn("Error fetching row number: %v", err)
}
} else {
rowNumber = &result.RowNum
logger.Debug("Found row number: %d", *rowNumber)
}
}
if options.Offset != nil && *options.Offset > 0 {
logger.Debug("Applying offset: %d", *options.Offset)
query = query.Offset(*options.Offset)
// Apply pagination (skip if FetchRowNumber is set - we want only that record)
if options.FetchRowNumber == nil || *options.FetchRowNumber == "" {
if options.Limit != nil && *options.Limit > 0 {
logger.Debug("Applying limit: %d", *options.Limit)
query = query.Limit(*options.Limit)
}
if options.Offset != nil && *options.Offset > 0 {
logger.Debug("Applying offset: %d", *options.Offset)
query = query.Offset(*options.Offset)
}
}
// Execute query
var result interface{}
if id != "" {
logger.Debug("Querying single record with ID: %s", id)
if id != "" || (options.FetchRowNumber != nil && *options.FetchRowNumber != "") {
// Single record query - either by URL ID or FetchRowNumber
var targetID string
if id != "" {
targetID = id
logger.Debug("Querying single record with URL ID: %s", id)
} else {
targetID = *options.FetchRowNumber
logger.Debug("Querying single record with FetchRowNumber ID: %s", targetID)
}
// For single record, create a new pointer to the struct type
singleResult := reflect.New(modelType).Interface()
pkName := reflection.GetPrimaryKeyName(singleResult)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := query.Scan(ctx, singleResult); err != nil {
logger.Error("Error querying record: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
@@ -416,20 +527,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
logger.Info("Successfully retrieved records")
// Build metadata
limit := 0
if options.Limit != nil {
limit = *options.Limit
}
offset := 0
if options.Offset != nil {
offset = *options.Offset
count := int64(total)
// When FetchRowNumber is used, we only return 1 record
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
count = 1
// Set the fetched row number on the record
if rowNumber != nil {
logger.Debug("FetchRowNumber: Setting row number %d on record", *rowNumber)
h.setRowNumbersOnRecords(result, int(*rowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
}
} else {
if options.Limit != nil {
limit = *options.Limit
}
if options.Offset != nil {
offset = *options.Offset
}
// Set row numbers on records if RowNumber field exists
// Only for multiple records (not when fetching single record)
h.setRowNumbersOnRecords(result, offset)
}
h.sendResponse(w, result, &common.Metadata{
Total: int64(total),
Filtered: int64(total),
Limit: limit,
Offset: offset,
Total: int64(total),
Filtered: int64(total),
Count: count,
Limit: limit,
Offset: offset,
RowNumber: rowNumber,
})
}
@@ -701,97 +831,130 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
// Get the primary key name
pkName := reflection.GetPrimaryKeyName(model)
// First, read the existing record from the database
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := h.db.NewSelect().Model(existingRecord)
// Wrap in transaction to ensure BeforeUpdate hook is inside transaction
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// First, read the existing record from the database
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
// Apply conditions to select
if urlID != "" {
logger.Debug("Updating by URL ID: %s", urlID)
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
} else if reqID != nil {
switch id := reqID.(type) {
case string:
logger.Debug("Updating by request ID: %s", id)
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
case []string:
if len(id) > 0 {
logger.Debug("Updating by multiple IDs: %v", id)
selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
// Apply conditions to select
if urlID != "" {
logger.Debug("Updating by URL ID: %s", urlID)
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
} else if reqID != nil {
switch id := reqID.(type) {
case string:
logger.Debug("Updating by request ID: %s", id)
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
case []string:
if len(id) > 0 {
logger.Debug("Updating by multiple IDs: %v", id)
selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
}
}
}
}
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
logger.Warn("No records found to update")
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil)
return
}
logger.Error("Error fetching existing record: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error fetching existing record", err)
return
}
// Convert existing record to map
existingMap := make(map[string]interface{})
jsonData, err := json.Marshal(existingRecord)
if err != nil {
logger.Error("Error marshaling existing record: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err)
return
}
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
logger.Error("Error unmarshaling existing record: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err)
return
}
// Merge only non-null and non-empty values from the incoming request into the existing record
for key, newValue := range updates {
// Skip if the value is nil
if newValue == nil {
continue
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no records found to update")
}
return fmt.Errorf("error fetching existing record: %w", err)
}
// Skip if the value is an empty string
if strVal, ok := newValue.(string); ok && strVal == "" {
continue
// Convert existing record to map
existingMap := make(map[string]interface{})
jsonData, err := json.Marshal(existingRecord)
if err != nil {
return fmt.Errorf("error marshaling existing record: %w", err)
}
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
return fmt.Errorf("error unmarshaling existing record: %w", err)
}
// Update the existing map with the new value
existingMap[key] = newValue
}
// Build update query with merged data
query := h.db.NewUpdate().Table(tableName).SetMap(existingMap)
// Apply conditions
if urlID != "" {
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
} else if reqID != nil {
switch id := reqID.(type) {
case string:
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
case []string:
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
// Execute BeforeUpdate hooks inside transaction
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Options: options,
ID: urlID,
Data: updates,
Writer: w,
Tx: tx,
}
}
result, err := query.Exec(ctx)
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
}
// Use potentially modified data from hook context
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
updates = modifiedData
}
// Merge only non-null and non-empty values from the incoming request into the existing record
for key, newValue := range updates {
// Skip if the value is nil
if newValue == nil {
continue
}
// Skip if the value is an empty string
if strVal, ok := newValue.(string); ok && strVal == "" {
continue
}
// Update the existing map with the new value
existingMap[key] = newValue
}
// Build update query with merged data
query := tx.NewUpdate().Table(tableName).SetMap(existingMap)
// Apply conditions
if urlID != "" {
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
} else if reqID != nil {
switch id := reqID.(type) {
case string:
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
case []string:
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
}
}
result, err := query.Exec(ctx)
if err != nil {
return fmt.Errorf("error updating record(s): %w", err)
}
if result.RowsAffected() == 0 {
return fmt.Errorf("no records found to update")
}
// Execute AfterUpdate hooks inside transaction
hookCtx.Result = updates
hookCtx.Error = nil
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
return fmt.Errorf("AfterUpdate hook failed: %w", err)
}
return nil
})
if err != nil {
logger.Error("Update error: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
if err.Error() == "no records found to update" {
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", err)
} else {
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
}
return
}
if result.RowsAffected() == 0 {
logger.Warn("No records found to update")
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil)
return
}
logger.Info("Successfully updated %d records", result.RowsAffected())
logger.Info("Successfully updated record(s)")
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
@@ -849,9 +1012,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range updates {
if itemID, ok := item["id"]; ok {
itemIDStr := fmt.Sprintf("%v", itemID)
// First, read the existing record
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
continue // Skip if record not found
@@ -869,6 +1034,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return fmt.Errorf("failed to unmarshal existing record: %w", err)
}
// Execute BeforeUpdate hooks inside transaction
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Options: options,
ID: itemIDStr,
Data: item,
Writer: w,
Tx: tx,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
}
// Use potentially modified data from hook context
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
item = modifiedData
}
// Merge only non-null and non-empty values
for key, newValue := range item {
if newValue == nil {
@@ -884,6 +1072,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
if _, err := txQuery.Exec(ctx); err != nil {
return err
}
// Execute AfterUpdate hooks inside transaction
hookCtx.Result = item
hookCtx.Error = nil
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
}
}
}
return nil
@@ -957,9 +1152,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
for _, item := range updates {
if itemMap, ok := item.(map[string]interface{}); ok {
if itemID, ok := itemMap["id"]; ok {
itemIDStr := fmt.Sprintf("%v", itemID)
// First, read the existing record
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
continue // Skip if record not found
@@ -977,6 +1174,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return fmt.Errorf("failed to unmarshal existing record: %w", err)
}
// Execute BeforeUpdate hooks inside transaction
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Options: options,
ID: itemIDStr,
Data: itemMap,
Writer: w,
Tx: tx,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
}
// Use potentially modified data from hook context
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
itemMap = modifiedData
}
// Merge only non-null and non-empty values
for key, newValue := range itemMap {
if newValue == nil {
@@ -992,6 +1212,14 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
if _, err := txQuery.Exec(ctx); err != nil {
return err
}
// Execute AfterUpdate hooks inside transaction
hookCtx.Result = itemMap
hookCtx.Error = nil
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
}
list = append(list, item)
}
}
@@ -1033,6 +1261,24 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
logger.Info("Deleting records from %s.%s", schema, entity)
// Execute BeforeDelete hooks (covers model-rule checks before any deletion)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
ID: id,
Data: data,
Writer: w,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Error("BeforeDelete hook failed: %v", err)
h.sendError(w, http.StatusForbidden, "delete_forbidden", "Delete operation not allowed", err)
return
}
// Handle batch delete from request data
if data != nil {
switch v := data.(type) {
@@ -1203,29 +1449,165 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
h.sendResponse(w, recordToDelete, nil)
}
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
// applyFilters applies all filters with proper grouping for OR logic
// Groups consecutive OR filters together to ensure proper query precedence
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
// Check if this starts an OR group (current or next filter has OR logic)
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
// Collect all consecutive filters that are OR'd together
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
// Apply the OR group as a single grouped WHERE clause
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
// Single filter with AND logic (or first filter)
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
// applyFilterGroup applies a group of filters that should be OR'd together
// Always wraps them in parentheses and applies as a single WHERE clause
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
// Build all conditions and collect args
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
// Single filter - no need for grouping
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
// Multiple conditions - group with parentheses and OR
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
return query.Where(groupedCondition, args...)
}
// buildFilterCondition builds a filter condition and returns it with args
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
var condition string
var args []interface{}
switch filter.Operator {
case "eq":
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
case "neq":
return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value)
case "gt":
return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value)
case "gte":
return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value)
case "lt":
return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value)
case "lte":
return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value)
case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like":
return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value)
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "ilike":
return query.Where(fmt.Sprintf("%s ILIKE ?", filter.Column), filter.Value)
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "in":
return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value)
condition, args = common.BuildInCondition(filter.Column, filter.Value)
if condition == "" {
return "", nil
}
default:
return "", nil
}
return condition, args
}
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
// Determine which method to use based on LogicOperator
useOrLogic := strings.EqualFold(filter.LogicOperator, "OR")
var condition string
var args []interface{}
switch filter.Operator {
case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like":
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "ilike":
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "in":
condition, args = common.BuildInCondition(filter.Column, filter.Value)
if condition == "" {
return query
}
default:
return query
}
// Apply filter with appropriate logic operator
if useOrLogic {
return query.WhereOr(condition, args...)
}
return query.Where(condition, args...)
}
// parseTableName splits a table name that may contain schema into separate schema and table
@@ -1280,10 +1662,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return schema, entity
}
// getTableName returns the full table name including schema (schema.table)
// getTableName returns the full table name including schema.
// For most drivers the result is "schema.table". For SQLite, which does not
// support schema-qualified names, the schema and table are joined with an
// underscore: "schema_table".
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
@@ -1558,6 +1946,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: preloads}
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
// Ensure outer parentheses to prevent OR logic from escaping
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
@@ -1601,6 +1991,51 @@ func toSnakeCase(s string) string {
return strings.ToLower(result.String())
}
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
// The row number is calculated as offset + index + 1 (1-based)
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
// Get the reflect value of the records
recordsValue := reflect.ValueOf(records)
if recordsValue.Kind() == reflect.Ptr {
recordsValue = recordsValue.Elem()
}
// Ensure it's a slice
if recordsValue.Kind() != reflect.Slice {
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
return
}
// Iterate through each record
for i := 0; i < recordsValue.Len(); i++ {
record := recordsValue.Index(i)
// Dereference if it's a pointer
if record.Kind() == reflect.Ptr {
if record.IsNil() {
continue
}
record = record.Elem()
}
// Ensure it's a struct
if record.Kind() != reflect.Struct {
continue
}
// Try to find and set the RowNumber field
rowNumberField := record.FieldByName("RowNumber")
if rowNumberField.IsValid() && rowNumberField.CanSet() {
// Check if the field is of type int64
if rowNumberField.Kind() == reflect.Int64 {
rowNum := int64(offset + i + 1)
rowNumberField.SetInt(rowNum)
logger.Debug("Set RowNumber=%d for record index %d", rowNum, i)
}
}
}
}
// HandleOpenAPI generates and returns the OpenAPI specification
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
if h.openAPIGenerator == nil {

View File

@@ -12,6 +12,10 @@ import (
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
// Use this for auth checks that need model rules and user context simultaneously.
BeforeHandle HookType = "before_handle"
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
@@ -43,6 +47,9 @@ type HookContext struct {
Writer common.ResponseWriter
Request common.Request
// Operation being dispatched (e.g. "read", "create", "update", "delete")
Operation string
// Operation-specific fields
ID string
Data interface{} // For create/update operations

View File

@@ -50,8 +50,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
handler.HandleOpenAPI(respAdapter, reqAdapter)
})
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
@@ -69,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
// Create handler functions for this specific entity
postEntityHandler := createMuxHandler(handler, schema, entity, "")
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
// Apply authentication middleware if provided
if authMiddleware != nil {
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
postEntityHandler = authMiddleware(postEntityHandler)
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
getEntityHandler = authMiddleware(getEntityHandler)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
}
@@ -98,7 +99,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
@@ -106,7 +108,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.Handle(respAdapter, reqAdapter, vars)
}
}
@@ -117,7 +119,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
@@ -125,7 +128,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}
}
@@ -137,13 +140,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
corsConfig := common.DefaultCORSConfig()
corsConfig.AllowedMethods = allowedMethods
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
// Return metadata in the OPTIONS response body
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}
}
@@ -212,9 +216,34 @@ type BunRouterHandler interface {
Handle(method, path string, handler bunrouter.HandlerFunc)
}
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
if authMiddleware == nil {
return handler
}
return func(w http.ResponseWriter, req bunrouter.Request) error {
// Create an http.Handler that calls the bunrouter handler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Replace the embedded *http.Request with the middleware-enriched one
// so that auth context (user ID, etc.) is visible to the handler.
enrichedReq := req
enrichedReq.Request = r
_ = handler(w, enrichedReq)
})
// Wrap with auth middleware and execute
wrappedHandler := authMiddleware(httpHandler)
wrappedHandler.ServeHTTP(w, req.Request)
return nil
}
}
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
// Accepts bunrouter.Router or bunrouter.Group
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
// authMiddleware is optional - if provided, routes will be protected with the middleware
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
// CORS config
corsConfig := common.DefaultCORSConfig()
@@ -222,15 +251,16 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
// Add global /openapi route
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
handler.HandleOpenAPI(respAdapter, reqAdapter)
return nil
})
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
return nil
})
@@ -251,85 +281,97 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
currentEntity := entity
// POST route without ID
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
// POST route with ID
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
// GET route without ID
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
// GET route with ID
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
// OPTIONS route without ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewHTTPRequest(req.Request)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route with ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewHTTPRequest(req.Request)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
@@ -344,8 +386,8 @@ func ExampleWithBunRouter(bunDB *bun.DB) {
// Create bunrouter
bunRouter := bunrouter.New()
// Setup ResolveSpec routes with bunrouter
SetupBunRouterRoutes(bunRouter, handler)
// Setup ResolveSpec routes with bunrouter without authentication
SetupBunRouterRoutes(bunRouter, handler, nil)
// Start server
// http.ListenAndServe(":8080", bunRouter)
@@ -366,8 +408,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
// Create bunrouter
bunRouter := bunrouter.New()
// Setup ResolveSpec routes
SetupBunRouterRoutes(bunRouter, handler)
// Setup ResolveSpec routes without authentication
SetupBunRouterRoutes(bunRouter, handler, nil)
// This gives you the full uptrace stack: bunrouter + Bun ORM
// http.ListenAndServe(":8080", bunRouter)
@@ -385,8 +427,87 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
apiGroup := bunRouter.NewGroup("/api")
// Setup ResolveSpec routes on the group - routes will be under /api
SetupBunRouterRoutes(apiGroup, handler)
SetupBunRouterRoutes(apiGroup, handler, nil)
// Start server
// http.ListenAndServe(":8080", bunRouter)
}
// ExampleWithGORMAndAuth shows how to use ResolveSpec with GORM and authentication
func ExampleWithGORMAndAuth(db *gorm.DB) {
// Create handler using GORM
_ = NewHandlerWithGORM(db)
// Create auth middleware
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
// return security.NewAuthHandler(secList, h)
// }
// Setup router with authentication
_ = mux.NewRouter()
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
// Register models
// handler.RegisterModel("public", "users", &User{})
// Start server
// http.ListenAndServe(":8080", muxRouter)
}
// ExampleWithBunAndAuth shows how to use ResolveSpec with Bun and authentication
func ExampleWithBunAndAuth(bunDB *bun.DB) {
// Create Bun adapter
dbAdapter := database.NewBunAdapter(bunDB)
// Create model registry
registry := modelregistry.NewModelRegistry()
// registry.RegisterModel("public.users", &User{})
// Create handler
_ = NewHandler(dbAdapter, registry)
// Create auth middleware
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
// return security.NewAuthHandler(secList, h)
// }
// Setup routes with authentication
_ = mux.NewRouter()
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
// Start server
// http.ListenAndServe(":8080", muxRouter)
}
// ExampleBunRouterWithBunDBAndAuth shows the full uptrace stack with authentication
func ExampleBunRouterWithBunDBAndAuth(bunDB *bun.DB) {
// Create Bun database adapter
dbAdapter := database.NewBunAdapter(bunDB)
// Create model registry
registry := modelregistry.NewModelRegistry()
// registry.RegisterModel("public.users", &User{})
// Create handler with Bun
_ = NewHandler(dbAdapter, registry)
// Create auth middleware
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
// return security.NewAuthHandler(secList, h)
// }
// Create bunrouter
_ = bunrouter.New()
// Setup ResolveSpec routes with authentication
// SetupBunRouterRoutes(bunRouter, handler, authMiddleware)
// This gives you the full uptrace stack: bunrouter + Bun ORM with authentication
// http.ListenAndServe(":8080", bunRouter)
}

View File

@@ -2,6 +2,7 @@ package resolvespec
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -10,6 +11,17 @@ import (
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeHandle - enforce auth after model resolution
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
@@ -34,6 +46,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
return security.LogDataAccess(secCtx)
})
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelUpdateAllowed(secCtx)
})
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelDeleteAllowed(secCtx)
})
logger.Info("Security hooks registered for resolvespec handler")
}

View File

@@ -214,14 +214,46 @@ x-expand: department:id,name,code
**Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation.
#### `x-custom-sql-join`
Raw SQL JOIN statement.
Custom SQL JOIN clauses for joining tables in queries.
**Format:** SQL JOIN clause
**Format:** SQL JOIN clause or multiple clauses separated by `|`
**Single JOIN:**
```
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
```
⚠️ **Note:** Not yet fully implemented.
**Multiple JOINs:**
```
x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id
```
**Features:**
- Supports any type of JOIN (INNER, LEFT, RIGHT, FULL, CROSS)
- Multiple JOINs can be specified using the pipe `|` separator
- JOINs are sanitized for security
- Can be specified via headers or query parameters
- **Table aliases are automatically extracted and allowed for filtering and sorting**
**Using Join Aliases in Filters and Sorts:**
When you specify a custom SQL join with an alias, you can use that alias in your filter and sort parameters:
```
# Join with alias
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
# Sort by joined table column
x-sort: d.name,employees.id
# Filter by joined table column
x-searchop-eq-d.name: Engineering
```
The system automatically:
1. Extracts the alias from the JOIN clause (e.g., `d` from `departments d`)
2. Validates that prefixed columns (like `d.name`) refer to valid join aliases
3. Allows these prefixed columns in filters and sorts
---

View File

@@ -147,6 +147,7 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
```
**Available Hook Types**:
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
* `BeforeRead`, `AfterRead`
* `BeforeCreate`, `AfterCreate`
* `BeforeUpdate`, `AfterUpdate`
@@ -157,11 +158,13 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
* `Handler`: Access to handler, database, and registry
* `Schema`, `Entity`, `TableName`: Request info
* `Model`: The registered model type
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
* `Options`: Parsed request options (filters, sorting, etc.)
* `ID`: Record ID (for single-record operations)
* `Data`: Request data (for create/update)
* `Result`: Operation result (for after hooks)
* `Writer`: Response writer (allows hooks to modify response)
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
## Cursor Pagination

View File

@@ -26,6 +26,7 @@ type queryCacheKey struct {
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
CustomSQLJoin []string `json:"custom_sql_join,omitempty"`
Expand []expandOptionKey `json:"expand,omitempty"`
Distinct bool `json:"distinct,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
@@ -40,7 +41,7 @@ type cachedTotal struct {
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// Includes expand, distinct, and cursor pagination options
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
customWhere, customOr string, customJoin []string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
key := queryCacheKey{
TableName: tableName,
@@ -48,6 +49,7 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
CustomSQLJoin: customJoin,
Distinct: distinct,
CursorForward: cursorFwd,
CursorBackward: cursorBwd,
@@ -75,8 +77,8 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%v_%s_%s",
tableName, filters, sort, customWhere, customOr, customJoin, expandOpts, distinct, cursorFwd, cursorBwd))
}
return hashString(string(jsonData))

View File

@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
modelColumns []string, // optional: for validation
expandJoins map[string]string, // optional: alias → JOIN SQL
) (string, error) {
// Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
@@ -62,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
@@ -91,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
}
// Handle joins
if isJoin && expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
@@ -127,7 +135,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
fullTableName,
joinSQL,
pkName,
cursorID,

View File

@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
// Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "public.users") {
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
}
}
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "fn.sortorder", Direction: "ASC"},
},
},
}
opts.CursorForward = "8975"
tableName := "core.account"
pkName := "rid_account"
// modelColumns does not contain "sortorder" - it's a lateral join computed column
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
// Should contain the rewritten lateral join inside the EXISTS subquery
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
// Should compare fn.sortorder values
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
// Should NOT contain empty comparison like "< "
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestBuildPriorityChain(t *testing.T) {
clauses := []string{
"cursor_select.priority > posts.priority",

View File

@@ -133,6 +133,41 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Add request-scoped data to context (including options)
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
// Derive operation for auth check
var operation string
switch method {
case "GET":
operation = "read"
case "POST":
operation = "create"
case "PUT", "PATCH":
operation = "update"
case "DELETE":
operation = "delete"
default:
operation = "read"
}
// Execute BeforeHandle hook - auth check fires here, after model resolution
beforeCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Writer: w,
Request: r,
Operation: operation,
}
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
code := http.StatusUnauthorized
if beforeCtx.AbortCode != 0 {
code = beforeCtx.AbortCode
}
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
return
}
switch method {
case "GET":
if id != "" {
@@ -435,9 +470,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
// Apply preloading
logger.Debug("Total preloads to apply: %d", len(options.Preload))
for idx := range options.Preload {
preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation)
logger.Debug("Applying preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, Where=%s",
idx, preload.Relation, preload.Recursive, preload.RelatedKey, preload.Where)
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
@@ -463,7 +500,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
// Apply filters - validate and adjust for column types first
for i := range options.Filters {
// Group consecutive OR filters together to prevent OR logic from escaping
for i := 0; i < len(options.Filters); {
filter := &options.Filters[i]
// Validate and adjust filter based on column type
@@ -475,8 +513,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
logicOp = "AND"
}
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
// Check if this is the start of an OR group
if logicOp == "OR" {
// Collect all consecutive OR filters
orFilters := []*common.FilterOption{filter}
orCastInfo := []ColumnCastInfo{castInfo}
j := i + 1
for j < len(options.Filters) {
nextFilter := &options.Filters[j]
nextLogicOp := nextFilter.LogicOperator
if nextLogicOp == "" {
nextLogicOp = "AND"
}
if nextLogicOp == "OR" {
nextCastInfo := h.ValidateAndAdjustFilterForColumnType(nextFilter, model)
orFilters = append(orFilters, nextFilter)
orCastInfo = append(orCastInfo, nextCastInfo)
j++
} else {
break
}
}
// Apply the OR group as a single grouped condition
logger.Debug("Applying OR filter group with %d conditions", len(orFilters))
query = h.applyOrFilterGroup(query, orFilters, orCastInfo, tableName)
i = j
} else {
// Single AND filter - apply normally
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
i++
}
}
// Apply custom SQL WHERE clause (AND condition)
@@ -486,6 +555,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
// Ensure outer parentheses to prevent OR logic from escaping
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
if sanitizedWhere != "" {
query = query.Where(sanitizedWhere)
}
@@ -497,13 +568,46 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
// Ensure outer parentheses to prevent OR logic from escaping
sanitizedOr = common.EnsureOuterParentheses(sanitizedOr)
if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr)
}
}
// If ID is provided, filter by ID
if id != "" {
// Apply custom SQL JOIN clauses
if len(options.CustomSQLJoin) > 0 {
for _, joinClause := range options.CustomSQLJoin {
logger.Debug("Applying custom SQL JOIN: %s", joinClause)
// Joins are already sanitized during parsing, so we can apply them directly
query = query.Join(joinClause)
}
}
// Handle FetchRowNumber before applying ID filter
// This must happen before the query to get the row position, then filter by PK
var fetchedRowNumber *int64
var fetchRowNumberPKValue string
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
pkName := reflection.GetPrimaryKeyName(model)
fetchRowNumberPKValue = *options.FetchRowNumber
logger.Debug("FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, fetchRowNumberPKValue, options, model)
if err != nil {
logger.Error("Failed to fetch row number: %v", err)
h.sendError(w, http.StatusBadRequest, "fetch_rownumber_error", "Failed to fetch row number", err)
return
}
fetchedRowNumber = &rowNum
logger.Debug("FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
// Now filter the main query to this specific primary key
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), fetchRowNumberPKValue)
} else if id != "" {
// If ID is provided (and not FetchRowNumber), filter by ID
pkName := reflection.GetPrimaryKeyName(model)
logger.Debug("Filtering by ID=%s: %s", pkName, id)
@@ -552,6 +656,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
options.Sort,
options.CustomSQLWhere,
options.CustomSQLOr,
options.CustomSQLJoin,
expandOpts,
options.Distinct,
options.CursorForward,
@@ -618,12 +723,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation using the generic database function
modelColumns := reflection.GetModelColumns(model)
// Build expand joins map (if needed in future)
var expandJoins map[string]string
if len(options.Expand) > 0 {
expandJoins = make(map[string]string)
// TODO: Build actual JOIN SQL for each expand relation
// For now, pass empty map as joins are handled via Preload
// Build expand joins map: custom SQL joins are available in cursor subquery
expandJoins := make(map[string]string)
for _, joinClause := range options.CustomSQLJoin {
alias := extractJoinAlias(joinClause)
if alias != "" {
expandJoins[alias] = joinClause
}
}
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL
@@ -682,7 +794,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
// Set row numbers on each record if the model has a RowNumber field
h.setRowNumbersOnRecords(modelPtr, offset)
// If FetchRowNumber was used, set the fetched row number instead of offset-based
if fetchedRowNumber != nil {
// FetchRowNumber: set the actual row position on the record
logger.Debug("FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
h.setRowNumbersOnRecords(modelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
} else {
h.setRowNumbersOnRecords(modelPtr, offset)
}
metadata := &common.Metadata{
Total: int64(total),
@@ -692,21 +811,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset,
}
// Fetch row number for a specific record if requested
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
pkName := reflection.GetPrimaryKeyName(model)
pkValue := *options.FetchRowNumber
logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
if err != nil {
logger.Warn("Failed to fetch row number: %v", err)
// Don't fail the entire request, just log the warning
} else {
metadata.RowNumber = &rowNum
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
}
// If FetchRowNumber was used, also set it in metadata
if fetchedRowNumber != nil {
metadata.RowNumber = fetchedRowNumber
logger.Debug("FetchRowNumber: Row number %d set in metadata", *fetchedRowNumber)
}
// Execute AfterRead hooks
@@ -836,6 +944,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
}
}
// Apply custom SQL joins from XFiles
if len(preload.SqlJoins) > 0 {
logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation)
for _, joinClause := range preload.SqlJoins {
sq = sq.Join(joinClause)
logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause)
}
}
// Apply filters
if len(preload.Filters) > 0 {
for _, filter := range preload.Filters {
@@ -861,10 +978,25 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
if len(preload.Where) > 0 {
// Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: allPreloads}
// First add table prefixes to unqualified columns
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
// Then sanitize and allow preload table prefixes
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
// Determine the table name to use for WHERE clause processing
// Prefer the explicit TableName field (set by XFiles), otherwise extract from relation name
tableName := preload.TableName
if tableName == "" {
tableName = reflection.ExtractTableNameOnly(preload.Relation)
}
// In Bun's Relation context, table prefixes are only needed when there are JOINs
// Without JOINs, Bun already knows which table is being queried
whereClause := preload.Where
if len(preload.SqlJoins) > 0 {
// Has JOINs: add table prefixes to disambiguate columns
whereClause = common.AddTablePrefixToColumns(preload.Where, tableName)
logger.Debug("Added table prefix for preload with joins: '%s' -> '%s'", preload.Where, whereClause)
}
// Sanitize the WHERE clause and allow preload table prefixes
sanitizedWhere := common.SanitizeWhereClause(whereClause, tableName, preloadOpts)
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
@@ -883,21 +1015,82 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
})
// Handle recursive preloading
if preload.Recursive && depth < 5 {
if preload.Recursive && depth < 8 {
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
// For recursive relationships, we need to get the last part of the relation path
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
relationParts := strings.Split(preload.Relation, ".")
lastRelationName := relationParts[len(relationParts)-1]
// Create a recursive preload with the same configuration
// but with the relation path extended
recursivePreload := preload
recursivePreload.Relation = preload.Relation + "." + lastRelationName
// Generate FK-based relation name for children
// Use RecursiveChildKey if available, otherwise fall back to RelatedKey
recursiveFK := preload.RecursiveChildKey
if recursiveFK == "" {
recursiveFK = preload.RelatedKey
}
// Recursively apply preload until we reach depth 5
recursiveRelationName := lastRelationName
if recursiveFK != "" {
// Check if the last relation name already contains the FK suffix
// (this happens when XFiles already generated the FK-based name)
fkUpper := strings.ToUpper(recursiveFK)
expectedSuffix := "_" + fkUpper
if strings.HasSuffix(lastRelationName, expectedSuffix) {
// Already has FK suffix, just reuse the same name
recursiveRelationName = lastRelationName
logger.Debug("Reusing FK-based relation name for recursion: %s", recursiveRelationName)
} else {
// Generate FK-based name
recursiveRelationName = lastRelationName + expectedSuffix
keySource := "RelatedKey"
if preload.RecursiveChildKey != "" {
keySource = "RecursiveChildKey"
}
logger.Debug("Generated recursive relation name from %s: %s (from %s)",
keySource, recursiveRelationName, recursiveFK)
}
} else {
logger.Warn("Recursive preload for %s has no RecursiveChildKey or RelatedKey, falling back to %s.%s",
preload.Relation, preload.Relation, lastRelationName)
}
// Create recursive preload
recursivePreload := preload
recursivePreload.Relation = preload.Relation + "." + recursiveRelationName
recursivePreload.Recursive = false // Prevent infinite recursion at this level
// Use the recursive FK for child relations, not the parent's RelatedKey
if preload.RecursiveChildKey != "" {
recursivePreload.RelatedKey = preload.RecursiveChildKey
}
// CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal
recursivePreload.Where = ""
recursivePreload.Filters = []common.FilterOption{}
logger.Debug("Cleared WHERE clause for recursive preload %s at depth %d",
recursivePreload.Relation, depth+1)
// Apply recursively up to depth 8
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
// ALSO: Extend any child relations (like DEF) to recursive levels
baseRelation := preload.Relation + "."
for i := range allPreloads {
relatedPreload := allPreloads[i]
if strings.HasPrefix(relatedPreload.Relation, baseRelation) &&
!strings.Contains(strings.TrimPrefix(relatedPreload.Relation, baseRelation), ".") {
childRelationName := strings.TrimPrefix(relatedPreload.Relation, baseRelation)
extendedChildPreload := relatedPreload
extendedChildPreload.Relation = recursivePreload.Relation + "." + childRelationName
extendedChildPreload.Recursive = false
logger.Debug("Extending related preload '%s' to '%s' at recursive depth %d",
relatedPreload.Relation, extendedChildPreload.Relation, depth+1)
query = h.applyPreloadWithRecursion(query, extendedChildPreload, allPreloads, model, depth+1)
}
}
}
return query
@@ -1110,30 +1303,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
logger.Info("Updating record in %s.%s", schema, entity)
// Execute BeforeUpdate hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Tx: h.db,
Model: model,
Options: options,
ID: id,
Data: data,
Writer: w,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
logger.Error("BeforeUpdate hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified data from hook context
data = hookCtx.Data
// Convert data to map
dataMap, ok := data.(map[string]interface{})
if !ok {
@@ -1167,6 +1336,9 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
// Variable to store the updated record
var updatedRecord interface{}
// Declare hook context to be used inside and outside transaction
var hookCtx *HookContext
// Process nested relations if present
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Create temporary nested processor with transaction
@@ -1174,7 +1346,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
// First, read the existing record from the database
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("record not found with ID: %v", targetID)
@@ -1204,6 +1376,30 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
nestedRelations = relations
}
// Execute BeforeUpdate hooks inside transaction
hookCtx = &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Tx: tx,
Model: model,
Options: options,
ID: id,
Data: dataMap,
Writer: w,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
}
// Use potentially modified data from hook context
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
dataMap = modifiedData
}
// Merge only non-null and non-empty values from the incoming request into the existing record
for key, newValue := range dataMap {
// Skip if the value is nil
@@ -1344,8 +1540,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Warn("BeforeDelete hook failed for ID %s: %v", itemID, err)
continue
logger.Error("BeforeDelete hook failed for ID %s: %v", itemID, err)
return fmt.Errorf("delete not allowed for ID %s: %w", itemID, err)
}
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
@@ -1418,8 +1614,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
continue
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
}
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
@@ -1476,8 +1672,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
continue
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
}
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
@@ -1791,10 +1987,46 @@ func (h *Handler) processChildRelationsForField(
parentIDs[baseName] = parentID
}
// Determine which field name to use for setting parent ID in child data
// Priority: Use foreign key field name if specified, otherwise use parent's PK name
var foreignKeyFieldName string
if relInfo.ForeignKey != "" {
// Get the JSON name for the foreign key field in the child model
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
if foreignKeyFieldName == "" {
// Fallback to lowercase field name
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
}
} else {
// Fallback: use parent's primary key name
parentPKName := reflection.GetPrimaryKeyName(parentModelType)
foreignKeyFieldName = reflection.GetJSONNameForField(parentModelType, parentPKName)
if foreignKeyFieldName == "" {
foreignKeyFieldName = strings.ToLower(parentPKName)
}
}
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
childPKName := reflection.GetPrimaryKeyName(relatedModel)
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
if childPKFieldName == "" {
childPKFieldName = strings.ToLower(childPKName)
}
logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s, childPK=%s",
foreignKeyFieldName, parentID, relInfo.ForeignKey, childPKFieldName)
// Process based on relation type and data structure
switch v := relationValue.(type) {
case map[string]interface{}:
// Single related object
// Single related object - add parent ID to foreign key field
// IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
v[foreignKeyFieldName] = parentID
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
} else if foreignKeyFieldName == childPKFieldName {
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
}
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process single relation: %w", err)
@@ -1804,6 +2036,14 @@ func (h *Handler) processChildRelationsForField(
// Multiple related objects
for i, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok {
// Add parent ID to foreign key field
// IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
} else if foreignKeyFieldName == childPKFieldName {
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
}
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation item %d: %w", i, err)
@@ -1814,6 +2054,14 @@ func (h *Handler) processChildRelationsForField(
case []map[string]interface{}:
// Multiple related objects (typed slice)
for i, itemMap := range v {
// Add parent ID to foreign key field
// IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
} else if foreignKeyFieldName == childPKFieldName {
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
}
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation item %d: %w", i, err)
@@ -1827,11 +2075,18 @@ func (h *Handler) processChildRelationsForField(
return nil
}
// getTableNameForRelatedModel gets the table name for a related model
// getTableNameForRelatedModel gets the table name for a related model.
// If the model's TableName() is schema-qualified (e.g. "public.users") the
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
if provider, ok := model.(common.TableNameProvider); ok {
tableName := provider.TableName()
if tableName != "" {
if schema, table := h.parseTableName(tableName); schema != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schema, table)
}
}
return tableName
}
}
@@ -1898,7 +2153,11 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
// Column is already cast to TEXT if needed
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
case "in":
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
if cond == "" {
return query
}
return applyWhere(cond, inArgs...)
case "between":
// Handle between operator - exclusive (> val1 AND < val2)
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
@@ -1931,6 +2190,100 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
}
}
// applyOrFilterGroup applies a group of OR filters as a single grouped condition
// This ensures OR conditions are properly grouped with parentheses to prevent OR logic from escaping
func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common.FilterOption, castInfo []ColumnCastInfo, tableName string) common.SelectQuery {
if len(filters) == 0 {
return query
}
// Build individual filter conditions
conditions := []string{}
args := []interface{}{}
for i, filter := range filters {
// Qualify the column name with table name if not already qualified
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
// Apply casting to text if needed for non-numeric columns or non-numeric values
if castInfo[i].NeedsCast {
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
}
// Build the condition based on operator
condition, filterArgs := h.buildFilterCondition(qualifiedColumn, filter, tableName)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
// Join all conditions with OR and wrap in parentheses
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
logger.Debug("Applying grouped OR conditions: %s", groupedCondition)
// Apply as AND condition (the OR is already inside the parentheses)
return query.Where(groupedCondition, args...)
}
// buildFilterCondition builds a single filter condition and returns the condition string and args
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
switch strings.ToLower(filter.Operator) {
case "eq", "equals", "=":
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
case "neq", "not_equals", "ne", "!=", "<>":
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
case "gt", "greater_than", ">":
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
case "gte", "greater_than_equals", "ge", ">=":
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
case "lt", "less_than", "<":
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
case "lte", "less_than_equals", "le", "<=":
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
case "ilike":
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
case "in":
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
return cond, inArgs
case "between":
// Handle between operator - exclusive (> val1 AND < val2)
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
}
logger.Warn("Invalid BETWEEN filter value format")
return "", nil
case "between_inclusive":
// Handle between inclusive operator - inclusive (>= val1 AND <= val2)
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
}
logger.Warn("Invalid BETWEEN INCLUSIVE filter value format")
return "", nil
case "is_null", "isnull":
// Check for NULL values - don't use cast for NULL checks
colName := h.qualifyColumnName(filter.Column, tableName)
return fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName), nil
case "is_not_null", "isnotnull":
// Check for NOT NULL values - don't use cast for NULL checks
colName := h.qualifyColumnName(filter.Column, tableName)
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName), nil
default:
logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator)
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
}
}
// parseTableName splits a table name that may contain schema into separate schema and table
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
@@ -1983,10 +2336,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return schema, entity
}
// getTableName returns the full table name including schema (schema.table)
// getTableName returns the full table name including schema.
// For most drivers the result is "schema.table". For SQLite, which does not
// support schema-qualified names, the schema and table are joined with an
// underscore: "schema_table".
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
@@ -2308,21 +2667,8 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
}
// Build WHERE clauses from filters
whereClauses := make([]string, 0)
for i := range options.Filters {
filter := &options.Filters[i]
whereClause := h.buildFilterSQL(filter, tableName)
if whereClause != "" {
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
}
}
// Combine WHERE clauses
whereSQL := ""
if len(whereClauses) > 0 {
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
}
// Build WHERE clause from filters with proper OR grouping
whereSQL := h.buildWhereClauseWithORGrouping(options.Filters, tableName)
// Add custom SQL WHERE if provided
if options.CustomSQLWhere != "" {
@@ -2370,19 +2716,86 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
var result []struct {
RN int64 `bun:"rn"`
}
logger.Debug("[FetchRowNumber] BEFORE Query call - about to execute raw query")
err := h.db.Query(ctx, &result, queryStr, pkValue)
logger.Debug("[FetchRowNumber] AFTER Query call - query completed with %d results, err: %v", len(result), err)
if err != nil {
return 0, fmt.Errorf("failed to fetch row number: %w", err)
}
if len(result) == 0 {
return 0, fmt.Errorf("no row found for primary key %s", pkValue)
whereInfo := "none"
if whereSQL != "" {
whereInfo = whereSQL
}
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
}
return result[0].RN, nil
}
// buildFilterSQL converts a filter to SQL WHERE clause string
// buildWhereClauseWithORGrouping builds a WHERE clause from filters with proper OR grouping
// Groups consecutive OR filters together to ensure proper SQL precedence
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
func (h *Handler) buildWhereClauseWithORGrouping(filters []common.FilterOption, tableName string) string {
if len(filters) == 0 {
return ""
}
var groups []string
i := 0
for i < len(filters) {
// Check if this starts an OR group (next filter has OR logic)
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
// Collect all consecutive filters that are OR'd together
orGroup := []string{}
// Add current filter
filterSQL := h.buildFilterSQL(&filters[i], tableName)
if filterSQL != "" {
orGroup = append(orGroup, filterSQL)
}
// Collect remaining OR filters
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
filterSQL := h.buildFilterSQL(&filters[j], tableName)
if filterSQL != "" {
orGroup = append(orGroup, filterSQL)
}
j++
}
// Group OR filters with parentheses
if len(orGroup) > 0 {
if len(orGroup) == 1 {
groups = append(groups, orGroup[0])
} else {
groups = append(groups, "("+strings.Join(orGroup, " OR ")+")")
}
}
i = j
} else {
// Single filter with AND logic (or first filter)
filterSQL := h.buildFilterSQL(&filters[i], tableName)
if filterSQL != "" {
groups = append(groups, filterSQL)
}
i++
}
}
if len(groups) == 0 {
return ""
}
return "WHERE " + strings.Join(groups, " AND ")
}
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
@@ -2473,6 +2886,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
// Filter base RequestOptions
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
filtered.RequestOptions.JoinAliases = options.JoinAliases
// Filter SearchColumns
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)

View File

@@ -26,7 +26,9 @@ type ExtendedRequestOptions struct {
CustomSQLOr string
// Joins
Expand []ExpandOption
Expand []ExpandOption
CustomSQLJoin []string // Custom SQL JOIN clauses
JoinAliases []string // Extracted table aliases from CustomSQLJoin for validation
// Advanced features
AdvancedSQL map[string]string // Column -> SQL expression
@@ -46,7 +48,8 @@ type ExtendedRequestOptions struct {
AtomicTransaction bool
// X-Files configuration - comprehensive query options as a single JSON object
XFiles *XFiles
XFiles *XFiles
XFilesPresent bool // Flag to indicate if X-Files header was provided
}
// ExpandOption represents a relation expansion configuration
@@ -111,6 +114,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
CustomSQLJoin: make([]string, 0),
ResponseFormat: "simple", // Default response format
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
}
@@ -185,8 +189,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
case strings.HasPrefix(key, "x-expand"):
h.parseExpand(&options, decodedValue)
case strings.HasPrefix(key, "x-custom-sql-join"):
// TODO: Implement custom SQL join
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
h.parseCustomSQLJoin(&options, decodedValue)
// Sorting & Pagination
case strings.HasPrefix(key, "x-sort"):
@@ -271,7 +274,10 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
}
// Resolve relation names (convert table names to field names) if model is provided
// Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
// This runs for both regular headers and X-Files, because XFile prefixes don't always match model
// field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
// disambiguate when multiple fields point to the same related type.
if model != nil {
h.resolveRelationNamesInOptions(&options, model)
}
@@ -354,6 +360,12 @@ func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu
operator := parts[0]
colName := parts[1]
if strings.HasPrefix(colName, "cql") {
// Computed column - Will not filter on it
logger.Warn("Search operators on computed columns are not supported: %s", colName)
return
}
// Map operator names to filter operators
filterOp := h.mapSearchOperator(colName, operator, value)
@@ -489,6 +501,112 @@ func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) {
}
}
// parseCustomSQLJoin parses x-custom-sql-join header
// Format: Single JOIN clause or multiple JOIN clauses separated by |
// Example: "LEFT JOIN departments d ON d.id = employees.department_id"
// Example: "LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id"
func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value string) {
if value == "" {
return
}
// Split by | for multiple joins
joins := strings.Split(value, "|")
for _, joinStr := range joins {
joinStr = strings.TrimSpace(joinStr)
if joinStr == "" {
continue
}
// Basic validation: should contain "JOIN" keyword
upperJoin := strings.ToUpper(joinStr)
if !strings.Contains(upperJoin, "JOIN") {
logger.Warn("Invalid custom SQL join (missing JOIN keyword): %s", joinStr)
continue
}
// Sanitize the join clause using common.SanitizeWhereClause
// Note: This is basic sanitization - in production you may want stricter validation
sanitizedJoin := common.SanitizeWhereClause(joinStr, "", nil)
if sanitizedJoin == "" {
logger.Warn("Custom SQL join failed sanitization: %s", joinStr)
continue
}
// Extract table alias from the JOIN clause
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
options.JoinAliases = append(options.JoinAliases, alias)
// Also add to the embedded RequestOptions for validation
options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias)
logger.Debug("Extracted join alias: %s", alias)
}
logger.Debug("Adding custom SQL join: %s", sanitizedJoin)
options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin)
}
}
// extractJoinAlias extracts the table alias from a JOIN clause
// Examples:
// - "LEFT JOIN departments d ON ..." -> "d"
// - "INNER JOIN users AS u ON ..." -> "u"
// - "JOIN roles r ON ..." -> "r"
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
func extractJoinAlias(joinClause string) string {
upperJoin := strings.ToUpper(joinClause)
// Find the "JOIN" keyword position
joinIdx := strings.Index(upperJoin, "JOIN")
if joinIdx == -1 {
return ""
}
// Lateral joins: alias is the word after the closing ) and before ON
if strings.Contains(upperJoin, "LATERAL") {
lastClose := strings.LastIndex(joinClause, ")")
if lastClose != -1 {
words := strings.Fields(joinClause[lastClose+1:])
// words should be like ["fn", "on", "true"] or ["on", "true"]
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
return words[0]
}
}
return ""
}
// Regular joins: find the "ON" keyword position (first occurrence)
onIdx := strings.Index(upperJoin, " ON ")
if onIdx == -1 {
return ""
}
// Extract the part between JOIN and ON
betweenJoinAndOn := strings.TrimSpace(joinClause[joinIdx+4 : onIdx])
// Split by spaces to get words
words := strings.Fields(betweenJoinAndOn)
if len(words) == 0 {
return ""
}
// If there's an AS keyword, the alias is after it
for i, word := range words {
if strings.EqualFold(word, "AS") && i+1 < len(words) {
return words[i+1]
}
}
// Otherwise, the alias is the last word (if there are 2+ words)
// Format: "table_name alias" or just "table_name"
if len(words) >= 2 {
return words[len(words)-1]
}
// Only one word means it's just the table name, no alias
return ""
}
// parseSorting parses x-sort header
// Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC)
func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
@@ -590,6 +708,7 @@ func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
// Store the original XFiles for reference
options.XFiles = &xfiles
options.XFilesPresent = true // Mark that X-Files header was provided
// Map XFiles fields to ExtendedRequestOptions
@@ -757,8 +876,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Resolve each part of the path
currentModel := model
for _, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part)
for partIdx, part := range parts {
isLast := partIdx == len(parts)-1
var resolvedPart string
if isLast {
// For the final part, use join-key-aware resolution to disambiguate when
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
localKey := preload.RelatedKey
if localKey == "" {
localKey = preload.ForeignKey
}
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
} else {
resolvedPart = h.resolveRelationName(currentModel, part)
}
resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level
@@ -874,6 +1006,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
return nameOrTable
}
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
if localKey == "" {
return h.resolveRelationName(model, nameOrTable)
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nameOrTable
}
// If it's already a direct field name, return as-is (no ambiguity).
for i := 0; i < modelType.NumField(); i++ {
if modelType.Field(i).Name == nameOrTable {
return nameOrTable
}
}
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
localKeyLower := strings.ToLower(localKey)
// Find all fields whose related type matches nameOrTable, then pick the one
// whose bun join tag local key matches localKey.
var fallbackField string
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil && targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetType == nil || targetType.Kind() != reflect.Struct {
continue
}
normalizedTypeName := strings.ToLower(targetType.Name())
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
if normalizedTypeName != normalizedInput {
continue
}
// Type name matches; record as fallback.
if fallbackField == "" {
fallbackField = field.Name
}
// Check bun join tag: "join:localKey=foreignKey"
bunTag := field.Tag.Get("bun")
for _, tagPart := range strings.Split(bunTag, ",") {
tagPart = strings.TrimSpace(tagPart)
if !strings.HasPrefix(tagPart, "join:") {
continue
}
joinSpec := strings.TrimPrefix(tagPart, "join:")
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
joinCols := strings.Fields(joinSpec)
if len(joinCols) == 0 {
joinCols = []string{joinSpec}
}
for _, joinCol := range joinCols {
eqIdx := strings.Index(joinCol, "=")
if eqIdx < 0 {
continue
}
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
if joinLocalKey == localKeyLower {
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
return field.Name
}
}
}
}
if fallbackField != "" {
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
return fallbackField
}
return h.resolveRelationName(model, nameOrTable)
}
// addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
@@ -881,11 +1108,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
return
}
// Store the table name as-is for now - it will be resolved to field name later
// when we have the model instance available
relationPath := xfile.TableName
// Use the Prefix (e.g., "MAL") as the relation name, which matches the Go struct field name
// Fall back to TableName if Prefix is not specified
relationName := xfile.Prefix
if relationName == "" {
relationName = xfile.TableName
}
// SPECIAL CASE: For recursive child tables, generate FK-based relation name
// Example: If prefix is "MAL" and relatedkey is "rid_parentmastertaskitem",
// the actual struct field is "MAL_RID_PARENTMASTERTASKITEM", not "MAL"
if xfile.Recursive && xfile.RelatedKey != "" && basePath != "" {
// Check if this is a self-referencing recursive relation (same table as parent)
// by comparing the last part of basePath with the current prefix
basePathParts := strings.Split(basePath, ".")
lastPrefix := basePathParts[len(basePathParts)-1]
if lastPrefix == relationName {
// This is a recursive self-reference, use FK-based name
fkUpper := strings.ToUpper(xfile.RelatedKey)
relationName = relationName + "_" + fkUpper
logger.Debug("X-Files: Generated FK-based relation name for recursive table: %s", relationName)
}
}
relationPath := relationName
if basePath != "" {
relationPath = basePath + "." + xfile.TableName
relationPath = basePath + "." + relationName
}
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
@@ -893,6 +1142,7 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
// Create PreloadOption from XFiles configuration
preloadOpt := common.PreloadOption{
Relation: relationPath,
TableName: xfile.TableName, // Store the actual database table name for WHERE clause processing
Columns: xfile.Columns,
OmitColumns: xfile.OmitColumns,
}
@@ -932,15 +1182,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
if len(xfile.SqlJoins) > 0 {
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
for _, joinClause := range xfile.SqlJoins {
// Sanitize the join clause
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
if sanitizedJoin == "" {
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
continue
}
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
// Extract join alias for validation
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
}
}
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
}
// Add WHERE clause if SQL conditions specified
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 {
// Process each SQL condition: add table prefixes and sanitize
var sqlAndOpts *common.RequestOptions
if len(preloadOpt.JoinAliases) > 0 {
sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
}
for _, sqlCond := range xfile.SqlAnd {
// First add table prefixes to unqualified columns
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName)
// Then sanitize the condition
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
if sanitizedCond != "" {
whereConditions = append(whereConditions, sanitizedCond)
}
@@ -985,13 +1262,46 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
}
// Check if this table has a recursive child - if so, mark THIS preload as recursive
// and store the recursive child's RelatedKey for recursion generation
hasRecursiveChild := false
if len(xfile.ChildTables) > 0 {
for _, childTable := range xfile.ChildTables {
if childTable.Recursive && childTable.TableName == xfile.TableName {
hasRecursiveChild = true
preloadOpt.Recursive = true
preloadOpt.RecursiveChildKey = childTable.RelatedKey
logger.Debug("X-Files: Detected recursive child for %s, marking parent as recursive (recursive FK: %s)",
relationPath, childTable.RelatedKey)
break
}
}
}
// Skip adding this preload if it's a recursive child (it will be handled by parent's Recursive flag)
if xfile.Recursive && basePath != "" {
logger.Debug("X-Files: Skipping recursive child preload: %s (will be handled by parent)", relationPath)
// Still process its parent/child tables for relations like DEF
h.processXFilesRelations(xfile, options, relationPath)
return
}
// Add the preload option
options.Preload = append(options.Preload, preloadOpt)
logger.Debug("X-Files: Added preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, RecursiveChildKey=%s, Where=%s",
len(options.Preload)-1, preloadOpt.Relation, preloadOpt.Recursive, preloadOpt.RelatedKey, preloadOpt.RecursiveChildKey, preloadOpt.Where)
// Recursively process nested ParentTables and ChildTables
if xfile.Recursive {
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
h.processXFilesRelations(xfile, options, relationPath)
// Skip processing child tables if we already detected and handled a recursive child
if hasRecursiveChild {
logger.Debug("X-Files: Skipping child table processing for %s (recursive child already handled)", relationPath)
// But still process parent tables
if len(xfile.ParentTables) > 0 {
logger.Debug("X-Files: Processing %d parent tables for %s", len(xfile.ParentTables), relationPath)
for _, parentTable := range xfile.ParentTables {
h.addXFilesPreload(parentTable, options, relationPath)
}
}
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
h.processXFilesRelations(xfile, options, relationPath)
}

View File

@@ -2,6 +2,8 @@ package restheadspec
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestDecodeHeaderValue(t *testing.T) {
@@ -37,6 +39,131 @@ func TestDecodeHeaderValue(t *testing.T) {
}
}
func TestAddXFilesPreload_WithSqlJoins(t *testing.T) {
handler := &Handler{}
options := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Preload: make([]common.PreloadOption, 0),
},
}
// Create an XFiles with SqlJoins
xfile := &XFiles{
TableName: "users",
SqlJoins: []string{
"LEFT JOIN departments d ON d.id = users.department_id",
"INNER JOIN roles r ON r.id = users.role_id",
},
FilterFields: []struct {
Field string `json:"field"`
Value string `json:"value"`
Operator string `json:"operator"`
}{
{Field: "d.active", Value: "true", Operator: "eq"},
{Field: "r.name", Value: "admin", Operator: "eq"},
},
}
// Add the XFiles preload
handler.addXFilesPreload(xfile, options, "")
// Verify that a preload was added
if len(options.Preload) != 1 {
t.Fatalf("Expected 1 preload, got %d", len(options.Preload))
}
preload := options.Preload[0]
// Verify relation name
if preload.Relation != "users" {
t.Errorf("Expected relation 'users', got '%s'", preload.Relation)
}
// Verify SqlJoins were transferred
if len(preload.SqlJoins) != 2 {
t.Fatalf("Expected 2 SQL joins, got %d", len(preload.SqlJoins))
}
// Verify JoinAliases were extracted
if len(preload.JoinAliases) != 2 {
t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases))
}
// Verify the aliases are correct
expectedAliases := []string{"d", "r"}
for i, expected := range expectedAliases {
if preload.JoinAliases[i] != expected {
t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i])
}
}
// Verify filters were added
if len(preload.Filters) != 2 {
t.Fatalf("Expected 2 filters, got %d", len(preload.Filters))
}
// Verify filter columns reference joined tables
if preload.Filters[0].Column != "d.active" {
t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column)
}
if preload.Filters[1].Column != "r.name" {
t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column)
}
}
func TestExtractJoinAlias(t *testing.T) {
tests := []struct {
name string
joinClause string
expected string
}{
{
name: "LEFT JOIN with alias",
joinClause: "LEFT JOIN departments d ON d.id = users.department_id",
expected: "d",
},
{
name: "INNER JOIN with AS keyword",
joinClause: "INNER JOIN users AS u ON u.id = orders.user_id",
expected: "u",
},
{
name: "JOIN without alias",
joinClause: "JOIN roles ON roles.id = users.role_id",
expected: "",
},
{
name: "Complex join with multiple conditions",
joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true",
expected: "p",
},
{
name: "Invalid join (no ON clause)",
joinClause: "LEFT JOIN departments",
expected: "",
},
{
name: "LATERAL join with alias",
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
expected: "fn",
},
{
name: "LATERAL join with multiline subquery containing inner ON",
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
expected: "fn",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractJoinAlias(tt.joinClause)
if result != tt.expected {
t.Errorf("Expected alias '%s', got '%s'", tt.expected, result)
}
})
}
}
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
// - parseSelectFields
// - parseFieldFilter

View File

@@ -12,6 +12,10 @@ import (
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
// Use this for auth checks that need model rules and user context simultaneously.
BeforeHandle HookType = "before_handle"
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
@@ -42,6 +46,9 @@ type HookContext struct {
Model interface{}
Options ExtendedRequestOptions
// Operation being dispatched (e.g. "read", "create", "update", "delete")
Operation string
// Operation-specific fields
ID string
Data interface{} // For create/update operations
@@ -56,6 +63,14 @@ type HookContext struct {
// Response writer - allows hooks to modify response
Writer common.ResponseWriter
// Request - the original HTTP request
Request common.Request
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
// Tx provides access to the database/transaction for executing additional SQL
// This allows hooks to run custom queries in addition to the main Query chain
Tx common.Database
@@ -110,6 +125,12 @@ func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
// logger.Debug("All hooks for %s executed successfully", hookType)

View File

@@ -0,0 +1,110 @@
package restheadspec
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// TestPreloadOption_TableName verifies that TableName field is properly used
// when provided in PreloadOption for WHERE clause processing
func TestPreloadOption_TableName(t *testing.T) {
tests := []struct {
name string
preload common.PreloadOption
expectedTable string
}{
{
name: "TableName provided explicitly",
preload: common.PreloadOption{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "mastertaskitem",
Where: "rid_parentmastertaskitem is null",
},
expectedTable: "mastertaskitem",
},
{
name: "TableName empty, should use empty string",
preload: common.PreloadOption{
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
TableName: "",
Where: "rid_parentmastertaskitem is null",
},
expectedTable: "",
},
{
name: "Simple relation without nested path",
preload: common.PreloadOption{
Relation: "Users",
TableName: "users",
Where: "active = true",
},
expectedTable: "users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test that the TableName field stores the correct value
if tt.preload.TableName != tt.expectedTable {
t.Errorf("PreloadOption.TableName = %q, want %q", tt.preload.TableName, tt.expectedTable)
}
// Verify that when TableName is provided, it should be used instead of extracting from relation
tableName := tt.preload.TableName
if tableName == "" {
// This simulates the fallback logic in handler.go
// In reality, reflection.ExtractTableNameOnly would be called
tableName = tt.expectedTable
}
if tableName != tt.expectedTable {
t.Errorf("Resolved table name = %q, want %q", tableName, tt.expectedTable)
}
})
}
}
// TestXFilesPreload_StoresTableName verifies that XFiles processing
// stores the table name in PreloadOption and doesn't add table prefixes to WHERE clauses
func TestXFilesPreload_StoresTableName(t *testing.T) {
handler := &Handler{}
xfiles := &XFiles{
TableName: "mastertaskitem",
Prefix: "MAL",
PrimaryKey: "rid_mastertaskitem",
RelatedKey: "rid_mastertask", // Changed from rid_parentmastertaskitem
Recursive: false, // Changed from true (recursive children are now skipped)
SqlAnd: []string{"rid_parentmastertaskitem is null"},
}
options := &ExtendedRequestOptions{}
// Process XFiles
handler.addXFilesPreload(xfiles, options, "MTL")
// Verify that a preload was added
if len(options.Preload) == 0 {
t.Fatal("Expected at least one preload to be added")
}
preload := options.Preload[0]
// Verify the table name is stored
if preload.TableName != "mastertaskitem" {
t.Errorf("PreloadOption.TableName = %q, want %q", preload.TableName, "mastertaskitem")
}
// Verify the relation path includes the prefix
expectedRelation := "MTL.MAL"
if preload.Relation != expectedRelation {
t.Errorf("PreloadOption.Relation = %q, want %q", preload.Relation, expectedRelation)
}
// Verify WHERE clause does NOT have table prefix (prefixes only needed for JOINs)
expectedWhere := "rid_parentmastertaskitem is null"
if preload.Where != expectedWhere {
t.Errorf("PreloadOption.Where = %q, want %q (no table prefix)", preload.Where, expectedWhere)
}
}

View File

@@ -0,0 +1,91 @@
package restheadspec
import (
"testing"
)
// TestPreloadWhereClause_WithJoins verifies that table prefixes are added
// to WHERE clauses when SqlJoins are present
func TestPreloadWhereClause_WithJoins(t *testing.T) {
tests := []struct {
name string
where string
sqlJoins []string
expectedPrefix bool
description string
}{
{
name: "No joins - no prefix needed",
where: "status = 'active'",
sqlJoins: []string{},
expectedPrefix: false,
description: "Without JOINs, Bun knows the table context",
},
{
name: "Has joins - prefix needed",
where: "status = 'active'",
sqlJoins: []string{"LEFT JOIN other_table ot ON ot.id = main.other_id"},
expectedPrefix: true,
description: "With JOINs, table prefix disambiguates columns",
},
{
name: "Already has prefix - no change",
where: "users.status = 'active'",
sqlJoins: []string{"LEFT JOIN roles r ON r.id = users.role_id"},
expectedPrefix: true,
description: "Existing prefix should be preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// This test documents the expected behavior
// The actual logic is in handler.go lines 916-937
hasJoins := len(tt.sqlJoins) > 0
if hasJoins != tt.expectedPrefix {
t.Errorf("Test expectation mismatch: hasJoins=%v, expectedPrefix=%v",
hasJoins, tt.expectedPrefix)
}
t.Logf("%s: %s", tt.name, tt.description)
})
}
}
// TestXFilesWithJoins_AddsTablePrefix verifies that XFiles with SqlJoins
// results in table prefixes being added to WHERE clauses
func TestXFilesWithJoins_AddsTablePrefix(t *testing.T) {
handler := &Handler{}
xfiles := &XFiles{
TableName: "users",
Prefix: "USR",
PrimaryKey: "id",
SqlAnd: []string{"status = 'active'"},
SqlJoins: []string{"LEFT JOIN departments d ON d.id = users.department_id"},
}
options := &ExtendedRequestOptions{}
handler.addXFilesPreload(xfiles, options, "")
if len(options.Preload) == 0 {
t.Fatal("Expected at least one preload to be added")
}
preload := options.Preload[0]
// Verify SqlJoins were stored
if len(preload.SqlJoins) != 1 {
t.Errorf("Expected 1 SqlJoin, got %d", len(preload.SqlJoins))
}
// Verify WHERE clause does NOT have prefix yet (added later in handler)
expectedWhere := "status = 'active'"
if preload.Where != expectedWhere {
t.Errorf("PreloadOption.Where = %q, want %q", preload.Where, expectedWhere)
}
// Note: The handler will add the prefix when it sees SqlJoins
// This is tested in the handler itself, not during XFiles parsing
}

View File

@@ -301,6 +301,163 @@ func TestParseOptionsFromQueryParams(t *testing.T) {
}
},
},
{
name: "Parse custom SQL JOIN from query params",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.CustomSQLJoin) == 0 {
t.Error("Expected CustomSQLJoin to be set")
return
}
if len(options.CustomSQLJoin) != 1 {
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
return
}
expected := `LEFT JOIN departments d ON d.id = employees.department_id`
if options.CustomSQLJoin[0] != expected {
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
}
},
},
{
name: "Parse multiple custom SQL JOINs from query params",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.CustomSQLJoin) != 2 {
t.Errorf("Expected 2 custom SQL joins, got %d", len(options.CustomSQLJoin))
return
}
expected1 := `LEFT JOIN departments d ON d.id = e.dept_id`
expected2 := `INNER JOIN roles r ON r.id = e.role_id`
if options.CustomSQLJoin[0] != expected1 {
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected1, options.CustomSQLJoin[0])
}
if options.CustomSQLJoin[1] != expected2 {
t.Errorf("Expected CustomSQLJoin[1]=%q, got %q", expected2, options.CustomSQLJoin[1])
}
},
},
{
name: "Parse custom SQL JOIN from headers",
headers: map[string]string{
"X-Custom-SQL-Join": `LEFT JOIN users u ON u.id = posts.user_id`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.CustomSQLJoin) == 0 {
t.Error("Expected CustomSQLJoin to be set from header")
return
}
expected := `LEFT JOIN users u ON u.id = posts.user_id`
if options.CustomSQLJoin[0] != expected {
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
}
},
},
{
name: "Extract aliases from custom SQL JOIN",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.JoinAliases) == 0 {
t.Error("Expected JoinAliases to be extracted")
return
}
if len(options.JoinAliases) != 1 {
t.Errorf("Expected 1 join alias, got %d", len(options.JoinAliases))
return
}
if options.JoinAliases[0] != "d" {
t.Errorf("Expected join alias 'd', got %q", options.JoinAliases[0])
}
// Also check that it's in the embedded RequestOptions
if len(options.RequestOptions.JoinAliases) != 1 || options.RequestOptions.JoinAliases[0] != "d" {
t.Error("Expected join alias to also be in RequestOptions.JoinAliases")
}
},
},
{
name: "Extract multiple aliases from multiple custom SQL JOINs",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles AS r ON r.id = e.role_id`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.JoinAliases) != 2 {
t.Errorf("Expected 2 join aliases, got %d", len(options.JoinAliases))
return
}
expectedAliases := []string{"d", "r"}
for i, expected := range expectedAliases {
if options.JoinAliases[i] != expected {
t.Errorf("Expected join alias[%d]=%q, got %q", i, expected, options.JoinAliases[i])
}
}
},
},
{
name: "Custom JOIN with sort on joined table",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
"x-sort": "d.name,employees.id",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
// Verify join was added
if len(options.CustomSQLJoin) != 1 {
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
return
}
// Verify alias was extracted
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
t.Error("Expected join alias 'd' to be extracted")
return
}
// Verify sort was parsed
if len(options.Sort) != 2 {
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
return
}
if options.Sort[0].Column != "d.name" {
t.Errorf("Expected first sort column 'd.name', got %q", options.Sort[0].Column)
}
if options.Sort[1].Column != "employees.id" {
t.Errorf("Expected second sort column 'employees.id', got %q", options.Sort[1].Column)
}
},
},
{
name: "Custom JOIN with filter on joined table",
queryParams: map[string]string{
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
"x-searchop-eq-d.name": "Engineering",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
// Verify join was added
if len(options.CustomSQLJoin) != 1 {
t.Error("Expected 1 custom SQL join")
return
}
// Verify alias was extracted
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
t.Error("Expected join alias 'd' to be extracted")
return
}
// Verify filter was parsed
if len(options.Filters) != 1 {
t.Errorf("Expected 1 filter, got %d", len(options.Filters))
return
}
if options.Filters[0].Column != "d.name" {
t.Errorf("Expected filter column 'd.name', got %q", options.Filters[0].Column)
}
if options.Filters[0].Operator != "eq" {
t.Errorf("Expected filter operator 'eq', got %q", options.Filters[0].Operator)
}
},
},
}
for _, tt := range tests {
@@ -395,6 +552,55 @@ func TestHeadersAndQueryParamsCombined(t *testing.T) {
}
}
// TestCustomJoinAliasExtraction tests the extractJoinAlias helper function
func TestCustomJoinAliasExtraction(t *testing.T) {
tests := []struct {
name string
join string
expected string
}{
{
name: "LEFT JOIN with alias",
join: "LEFT JOIN departments d ON d.id = employees.department_id",
expected: "d",
},
{
name: "INNER JOIN with AS keyword",
join: "INNER JOIN users AS u ON u.id = posts.user_id",
expected: "u",
},
{
name: "Simple JOIN with alias",
join: "JOIN roles r ON r.id = user_roles.role_id",
expected: "r",
},
{
name: "JOIN without alias (just table name)",
join: "JOIN departments ON departments.id = employees.dept_id",
expected: "",
},
{
name: "RIGHT JOIN with alias",
join: "RIGHT JOIN orders o ON o.customer_id = customers.id",
expected: "o",
},
{
name: "FULL OUTER JOIN with AS",
join: "FULL OUTER JOIN products AS p ON p.id = order_items.product_id",
expected: "p",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractJoinAlias(tt.join)
if result != tt.expected {
t.Errorf("extractJoinAlias(%q) = %q, want %q", tt.join, result, tt.expected)
}
})
}
}
// Helper function to check if a string contains a substring
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))

View File

@@ -0,0 +1,391 @@
//go:build !integration
// +build !integration
package restheadspec
import (
"context"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// TestRecursivePreloadClearsWhereClause tests that recursive preloads
// correctly clear the WHERE clause from the parent level to allow
// Bun to use foreign key relationships for loading children
func TestRecursivePreloadClearsWhereClause(t *testing.T) {
// Create a mock handler
handler := &Handler{}
// Create a preload option with a WHERE clause that filters root items
// This simulates the xfiles use case where the first level has a filter
// like "rid_parentmastertaskitem is null" to get root items
preload := common.PreloadOption{
Relation: "MastertaskItems",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
Where: "rid_parentmastertaskitem is null",
Filters: []common.FilterOption{
{
Column: "rid_parentmastertaskitem",
Operator: "is null",
Value: nil,
},
},
}
// Create a mock query that tracks operations
mockQuery := &mockSelectQuery{
operations: []string{},
}
// Apply the recursive preload at depth 0
// This should:
// 1. Apply the initial preload with the WHERE clause
// 2. Create a recursive preload without the WHERE clause
allPreloads := []common.PreloadOption{preload}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
// Verify the mock query received the operations
mock := result.(*mockSelectQuery)
// Check that we have at least 2 PreloadRelation calls:
// 1. The initial "MastertaskItems" with WHERE clause
// 2. The recursive "MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" without WHERE clause
preloadCount := 0
recursivePreloadFound := false
whereAppliedToRecursive := false
for _, op := range mock.operations {
if op == "PreloadRelation:MastertaskItems" {
preloadCount++
}
if op == "PreloadRelation:MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" {
recursivePreloadFound = true
}
// Check if WHERE was applied to the recursive preload (it shouldn't be)
if op == "Where:rid_parentmastertaskitem is null" && recursivePreloadFound {
whereAppliedToRecursive = true
}
}
if preloadCount < 1 {
t.Errorf("Expected at least 1 PreloadRelation call, got %d", preloadCount)
}
if !recursivePreloadFound {
t.Errorf("Expected recursive preload 'MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
}
if whereAppliedToRecursive {
t.Error("WHERE clause should not be applied to recursive preload levels")
}
}
// TestRecursivePreloadWithChildRelations tests that child relations
// (like DEF in MAL.DEF) are properly extended to recursive levels
func TestRecursivePreloadWithChildRelations(t *testing.T) {
handler := &Handler{}
// Create the main recursive preload
recursivePreload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
Where: "rid_parentmastertaskitem is null",
}
// Create a child relation that should be extended
childPreload := common.PreloadOption{
Relation: "MAL.DEF",
}
mockQuery := &mockSelectQuery{
operations: []string{},
}
allPreloads := []common.PreloadOption{recursivePreload, childPreload}
// Apply both preloads - the child preload should be extended when the recursive one processes
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, allPreloads, nil, 0)
// Also need to apply the child preload separately (as would happen in normal flow)
result = handler.applyPreloadWithRecursion(result, childPreload, allPreloads, nil, 0)
mock := result.(*mockSelectQuery)
// Check that the child relation was extended to recursive levels
// We should see:
// - MAL (with WHERE)
// - MAL.DEF
// - MAL.MAL_RID_PARENTMASTERTASKITEM (without WHERE)
// - MAL.MAL_RID_PARENTMASTERTASKITEM.DEF (extended by recursive logic)
foundMALDEF := false
foundRecursiveMAL := false
foundMALMALDEF := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.DEF" {
foundMALDEF = true
}
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundRecursiveMAL = true
}
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
foundMALMALDEF = true
}
}
if !foundMALDEF {
t.Errorf("Expected child preload 'MAL.DEF' to be applied. Operations: %v", mock.operations)
}
if !foundRecursiveMAL {
t.Errorf("Expected recursive preload 'MAL.MAL_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
}
if !foundMALMALDEF {
t.Errorf("Expected child preload to be extended to 'MAL.MAL_RID_PARENTMASTERTASKITEM.DEF' at recursive level. Operations: %v", mock.operations)
}
}
// TestRecursivePreloadGeneratesCorrectRelationName tests that the recursive
// preload generates the correct FK-based relation name using RelatedKey
func TestRecursivePreloadGeneratesCorrectRelationName(t *testing.T) {
handler := &Handler{}
// Test case 1: With RelatedKey - should generate FK-based name
t.Run("WithRelatedKey", func(t *testing.T) {
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
}
mockQuery := &mockSelectQuery{operations: []string{}}
allPreloads := []common.PreloadOption{preload}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
mock := result.(*mockSelectQuery)
// Should generate MAL.MAL_RID_PARENTMASTERTASKITEM
foundCorrectRelation := false
foundIncorrectRelation := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundCorrectRelation = true
}
if op == "PreloadRelation:MAL.MAL" {
foundIncorrectRelation = true
}
}
if !foundCorrectRelation {
t.Errorf("Expected 'MAL.MAL_RID_PARENTMASTERTASKITEM' relation, operations: %v", mock.operations)
}
if foundIncorrectRelation {
t.Error("Should NOT generate 'MAL.MAL' relation when RelatedKey is specified")
}
})
// Test case 2: Without RelatedKey - should fallback to old behavior
t.Run("WithoutRelatedKey", func(t *testing.T) {
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
// No RelatedKey
}
mockQuery := &mockSelectQuery{operations: []string{}}
allPreloads := []common.PreloadOption{preload}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
mock := result.(*mockSelectQuery)
// Should fallback to MAL.MAL
foundFallback := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL" {
foundFallback = true
}
}
if !foundFallback {
t.Errorf("Expected fallback 'MAL.MAL' relation when no RelatedKey, operations: %v", mock.operations)
}
})
// Test case 3: Depth limit of 8
t.Run("DepthLimit", func(t *testing.T) {
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
}
mockQuery := &mockSelectQuery{operations: []string{}}
allPreloads := []common.PreloadOption{preload}
// Start at depth 7 - should create one more level
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
mock := result.(*mockSelectQuery)
foundDepth8 := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth8 = true
}
}
if !foundDepth8 {
t.Error("Expected to create recursive level at depth 8")
}
// Start at depth 8 - should NOT create another level
mockQuery2 := &mockSelectQuery{operations: []string{}}
result2 := handler.applyPreloadWithRecursion(mockQuery2, preload, allPreloads, nil, 8)
mock2 := result2.(*mockSelectQuery)
foundDepth9 := false
for _, op := range mock2.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth9 = true
}
}
if foundDepth9 {
t.Error("Should NOT create recursive level beyond depth 8")
}
})
}
// mockSelectQuery implements common.SelectQuery for testing
type mockSelectQuery struct {
operations []string
}
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
m.operations = append(m.operations, "Model")
return m
}
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
m.operations = append(m.operations, "Table:"+table)
return m
}
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
for _, col := range columns {
m.operations = append(m.operations, "Column:"+col)
}
return m
}
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "ColumnExpr:"+query)
return m
}
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Where:"+query)
return m
}
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereOr:"+query)
return m
}
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereIn:"+column)
return m
}
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
m.operations = append(m.operations, "Order:"+order)
return m
}
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "OrderExpr:"+order)
return m
}
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
m.operations = append(m.operations, "Limit")
return m
}
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
m.operations = append(m.operations, "Offset")
return m
}
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Join:"+join)
return m
}
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "LeftJoin:"+join)
return m
}
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
m.operations = append(m.operations, "Group")
return m
}
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Having:"+query)
return m
}
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Preload:"+relation)
return m
}
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "PreloadRelation:"+relation)
// Apply the preload modifiers
for _, fn := range apply {
fn(m)
}
return m
}
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "JoinRelation:"+relation)
return m
}
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
m.operations = append(m.operations, "Scan")
return nil
}
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
m.operations = append(m.operations, "ScanModel")
return nil
}
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
m.operations = append(m.operations, "Count")
return 0, nil
}
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
m.operations = append(m.operations, "Exists")
return false, nil
}
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
return nil
}
func (m *mockSelectQuery) GetModel() interface{} {
return nil
}

View File

@@ -32,6 +32,7 @@
// - X-Clean-JSON: Boolean to remove null/empty fields
// - X-Custom-SQL-Where: Custom SQL WHERE clause (AND)
// - X-Custom-SQL-Or: Custom SQL WHERE clause (OR)
// - X-Custom-SQL-Join: Custom SQL JOIN clauses (pipe-separated for multiple)
//
// # Usage Example
//
@@ -103,8 +104,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
handler.HandleOpenAPI(respAdapter, reqAdapter)
})
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
@@ -123,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
metadataPath := buildRoutePath(schema, entity) + "/metadata"
// Create handler functions for this specific entity
entityHandler := createMuxHandler(handler, schema, entity, "")
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
// Apply authentication middleware if provided
if authMiddleware != nil {
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
entityHandler = authMiddleware(entityHandler)
entityWithIDHandler = authMiddleware(entityWithIDHandler)
metadataHandler = authMiddleware(metadataHandler)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
}
@@ -161,7 +163,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
@@ -169,7 +172,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.Handle(respAdapter, reqAdapter, vars)
}
}
@@ -180,7 +183,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
@@ -188,7 +192,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}
}
@@ -200,13 +204,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
corsConfig := common.DefaultCORSConfig()
corsConfig.AllowedMethods = allowedMethods
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(r)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
// Return metadata in the OPTIONS response body
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}
}
@@ -275,9 +280,34 @@ type BunRouterHandler interface {
Handle(method, path string, handler bunrouter.HandlerFunc)
}
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
if authMiddleware == nil {
return handler
}
return func(w http.ResponseWriter, req bunrouter.Request) error {
// Create an http.Handler that calls the bunrouter handler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Replace the embedded *http.Request with the middleware-enriched one
// so that auth context (user ID, etc.) is visible to the handler.
enrichedReq := req
enrichedReq.Request = r
_ = handler(w, enrichedReq)
})
// Wrap with auth middleware and execute
wrappedHandler := authMiddleware(httpHandler)
wrappedHandler.ServeHTTP(w, req.Request)
return nil
}
}
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
// Accepts bunrouter.Router or bunrouter.Group
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
// authMiddleware is optional - if provided, routes will be protected with the middleware
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
// CORS config
corsConfig := common.DefaultCORSConfig()
@@ -285,15 +315,16 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
// Add global /openapi route
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
handler.HandleOpenAPI(respAdapter, reqAdapter)
return nil
})
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
return nil
})
@@ -315,135 +346,155 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
currentEntity := entity
// GET and POST for /{schema}/{entity}
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
handler.Handle(respAdapter, reqAdapter, params)
return nil
}
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
putEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("PUT", entityWithIDPath, wrapBunRouterHandler(putEntityWithIDHandler, authMiddleware))
patchEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
handler.Handle(respAdapter, reqAdapter, params)
return nil
}
r.Handle("PATCH", entityWithIDPath, wrapBunRouterHandler(patchEntityWithIDHandler, authMiddleware))
deleteEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
handler.Handle(respAdapter, reqAdapter, params)
return nil
}
r.Handle("DELETE", entityWithIDPath, wrapBunRouterHandler(deleteEntityWithIDHandler, authMiddleware))
// Metadata endpoint
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
metadataHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewBunRouterRequest(req)
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
r.Handle("GET", metadataPath, wrapBunRouterHandler(metadataHandler, authMiddleware))
// OPTIONS route without ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route with ID (returns metadata)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
reqAdapter := router.NewBunRouterRequest(req)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
@@ -458,8 +509,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
// Create bunrouter
bunRouter := bunrouter.New()
// Setup routes
SetupBunRouterRoutes(bunRouter, handler)
// Setup routes without authentication
SetupBunRouterRoutes(bunRouter, handler, nil)
// Start server
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
@@ -479,7 +530,7 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
apiGroup := bunRouter.NewGroup("/api")
// Setup RestHeadSpec routes on the group - routes will be under /api
SetupBunRouterRoutes(apiGroup, handler)
SetupBunRouterRoutes(apiGroup, handler, nil)
// Start server
if err := http.ListenAndServe(":8080", bunRouter); err != nil {

View File

@@ -2,6 +2,7 @@ package restheadspec
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
@@ -9,6 +10,17 @@ import (
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeHandle - enforce auth after model resolution
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
@@ -33,6 +45,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
return security.LogDataAccess(secCtx)
})
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelUpdateAllowed(secCtx)
})
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelDeleteAllowed(secCtx)
})
logger.Info("Security hooks registered for restheadspec handler")
}

View File

@@ -0,0 +1,527 @@
//go:build integration
// +build integration
package restheadspec
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockSelectQuery implements common.SelectQuery for testing (integration version)
type mockSelectQuery struct {
operations []string
}
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
m.operations = append(m.operations, "Model")
return m
}
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
m.operations = append(m.operations, "Table:"+table)
return m
}
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
for _, col := range columns {
m.operations = append(m.operations, "Column:"+col)
}
return m
}
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "ColumnExpr:"+query)
return m
}
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Where:"+query)
return m
}
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereOr:"+query)
return m
}
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
m.operations = append(m.operations, "WhereIn:"+column)
return m
}
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
m.operations = append(m.operations, "Order:"+order)
return m
}
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "OrderExpr:"+order)
return m
}
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
m.operations = append(m.operations, "Limit")
return m
}
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
m.operations = append(m.operations, "Offset")
return m
}
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Join:"+join)
return m
}
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "LeftJoin:"+join)
return m
}
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
m.operations = append(m.operations, "Group")
return m
}
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Having:"+query)
return m
}
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
m.operations = append(m.operations, "Preload:"+relation)
return m
}
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "PreloadRelation:"+relation)
// Apply the preload modifiers
for _, fn := range apply {
fn(m)
}
return m
}
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
m.operations = append(m.operations, "JoinRelation:"+relation)
return m
}
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
m.operations = append(m.operations, "Scan")
return nil
}
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
m.operations = append(m.operations, "ScanModel")
return nil
}
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
m.operations = append(m.operations, "Count")
return 0, nil
}
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
m.operations = append(m.operations, "Exists")
return false, nil
}
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
return nil
}
func (m *mockSelectQuery) GetModel() interface{} {
return nil
}
// TestXFilesRecursivePreload is an integration test that validates the XFiles
// recursive preload functionality using real test data files.
//
// This test ensures:
// 1. XFiles request JSON is correctly parsed into PreloadOptions
// 2. Recursive preload generates correct FK-based relation names (MAL_RID_PARENTMASTERTASKITEM)
// 3. Parent WHERE clauses don't leak to child levels
// 4. Child relations (like DEF) are extended to all recursive levels
// 5. Hierarchical data structure matches expected output
func TestXFilesRecursivePreload(t *testing.T) {
// Load the XFiles request configuration
requestPath := filepath.Join("..", "..", "tests", "data", "xfiles.request.json")
requestData, err := os.ReadFile(requestPath)
require.NoError(t, err, "Failed to read xfiles.request.json")
var xfileConfig XFiles
err = json.Unmarshal(requestData, &xfileConfig)
require.NoError(t, err, "Failed to parse xfiles.request.json")
// Create handler and parse XFiles into PreloadOptions
handler := &Handler{}
options := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Preload: []common.PreloadOption{},
},
}
// Process the XFiles configuration - start with the root table
handler.processXFilesRelations(&xfileConfig, options, "")
// Verify that preload options were created
require.NotEmpty(t, options.Preload, "Expected preload options to be created")
// Test 1: Verify mastertaskitem preload is marked as recursive with correct RelatedKey
t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload *common.PreloadOption
for i := range options.Preload {
preload := &options.Preload[i]
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
break
}
}
require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload MTL.MAL")
// RelatedKey should be the parent relationship key (MTL -> MAL)
assert.Equal(t, "rid_mastertask", recursivePreload.RelatedKey,
"Recursive preload should preserve original RelatedKey for parent relationship")
// RecursiveChildKey should be set from the recursive child config
assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RecursiveChildKey,
"Recursive preload should have RecursiveChildKey set from recursive child config")
assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive")
})
// Test 2: Verify mastertaskitem has WHERE clause for filtering root items
t.Run("RootLevelHasWhereClause", func(t *testing.T) {
var rootPreload *common.PreloadOption
for i := range options.Preload {
preload := &options.Preload[i]
if preload.Relation == "MTL.MAL" {
rootPreload = preload
break
}
}
require.NotNil(t, rootPreload, "Expected to find mastertaskitem preload")
assert.NotEmpty(t, rootPreload.Where, "Mastertaskitem should have WHERE clause")
// The WHERE clause should filter for root items (rid_parentmastertaskitem is null)
assert.True(t, rootPreload.Recursive, "Mastertaskitem preload should be marked as recursive")
})
// Test 3: Verify actiondefinition relation exists for mastertaskitem
t.Run("DEFRelationExists", func(t *testing.T) {
var defPreload *common.PreloadOption
for i := range options.Preload {
preload := &options.Preload[i]
if preload.Relation == "MTL.MAL.DEF" {
defPreload = preload
break
}
}
require.NotNil(t, defPreload, "Expected to find actiondefinition preload for mastertaskitem")
assert.Equal(t, "rid_actiondefinition", defPreload.ForeignKey,
"actiondefinition preload should have ForeignKey set")
})
// Test 4: Verify relation name generation with mock query
t.Run("RelationNameGeneration", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption
found := false
for _, preload := range options.Preload {
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
found = true
break
}
}
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
// Create mock query to track operations
mockQuery := &mockSelectQuery{operations: []string{}}
// Apply the recursive preload
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
mock := result.(*mockSelectQuery)
// Verify the correct FK-based relation name was generated
foundCorrectRelation := false
for _, op := range mock.operations {
// Should generate: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundCorrectRelation = true
}
}
assert.True(t, foundCorrectRelation,
"Expected FK-based relation name 'MTL.MAL.MAL_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v",
mock.operations)
})
// Test 5: Verify WHERE clause is cleared for recursive levels
t.Run("WhereClauseClearedForChildren", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption
found := false
for _, preload := range options.Preload {
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
found = true
break
}
}
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
// The root level has a WHERE clause (rid_parentmastertaskitem is null)
// But when we apply recursion, it should be cleared
assert.NotEmpty(t, recursivePreload.Where, "Root preload should have WHERE clause")
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
mock := result.(*mockSelectQuery)
// After the first level, WHERE clauses should not be reapplied
// We check that the recursive relation was created (which means WHERE was cleared internally)
foundRecursiveRelation := false
for _, op := range mock.operations {
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundRecursiveRelation = true
}
}
assert.True(t, foundRecursiveRelation,
"Recursive relation should be created (WHERE clause should be cleared internally)")
})
// Test 6: Verify child relations are extended to recursive levels
t.Run("ChildRelationsExtended", func(t *testing.T) {
// Find the mastertaskitem preload - it should be marked as recursive
var recursivePreload common.PreloadOption
foundRecursive := false
for _, preload := range options.Preload {
if preload.Relation == "MTL.MAL" && preload.Recursive {
recursivePreload = preload
foundRecursive = true
break
}
}
require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload MTL.MAL")
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
mock := result.(*mockSelectQuery)
// actiondefinition should be extended to the recursive level
// Expected: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF
foundExtendedDEF := false
for _, op := range mock.operations {
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
foundExtendedDEF = true
}
}
assert.True(t, foundExtendedDEF,
"Expected actiondefinition relation to be extended to recursive level. Operations: %v",
mock.operations)
})
}
// TestXFilesRecursivePreloadDepth tests that recursive preloads respect the depth limit of 8
func TestXFilesRecursivePreloadDepth(t *testing.T) {
handler := &Handler{}
preload := common.PreloadOption{
Relation: "MAL",
Recursive: true,
RelatedKey: "rid_parentmastertaskitem",
}
allPreloads := []common.PreloadOption{preload}
t.Run("Depth7CreatesLevel8", func(t *testing.T) {
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
mock := result.(*mockSelectQuery)
foundDepth8 := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth8 = true
}
}
assert.True(t, foundDepth8, "Should create level 8 when starting at depth 7")
})
t.Run("Depth8DoesNotCreateLevel9", func(t *testing.T) {
mockQuery := &mockSelectQuery{operations: []string{}}
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 8)
mock := result.(*mockSelectQuery)
foundDepth9 := false
for _, op := range mock.operations {
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
foundDepth9 = true
}
}
assert.False(t, foundDepth9, "Should NOT create level 9 (depth limit is 8)")
})
}
// TestXFilesResponseStructure validates the actual structure of the response
// This test can be expanded when we have a full database integration test environment
func TestXFilesResponseStructure(t *testing.T) {
// Load the expected correct response
correctResponsePath := filepath.Join("..", "..", "tests", "data", "xfiles.response.correct.json")
correctData, err := os.ReadFile(correctResponsePath)
require.NoError(t, err, "Failed to read xfiles.response.correct.json")
var correctResponse []map[string]interface{}
err = json.Unmarshal(correctData, &correctResponse)
require.NoError(t, err, "Failed to parse xfiles.response.correct.json")
// Test 1: Verify root level has exactly 1 masterprocess
t.Run("RootLevelHasOneItem", func(t *testing.T) {
assert.Len(t, correctResponse, 1, "Root level should have exactly 1 masterprocess record")
})
// Test 2: Verify the root item has MTL relation
t.Run("RootHasMTLRelation", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, exists := rootItem["MTL"]
assert.True(t, exists, "Root item should have MTL relation")
assert.NotNil(t, mtl, "MTL relation should not be null")
})
// Test 3: Verify MTL has MAL items
t.Run("MTLHasMALItems", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, exists := firstMTL["MAL"]
assert.True(t, exists, "MTL item should have MAL relation")
assert.NotNil(t, mal, "MAL relation should not be null")
})
// Test 4: Verify MAL items have MAL_RID_PARENTMASTERTASKITEM relation (recursive)
t.Run("MALHasRecursiveRelation", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, ok := firstMTL["MAL"].([]interface{})
require.True(t, ok, "MAL should be an array")
require.NotEmpty(t, mal, "MAL should have items")
firstMAL, ok := mal[0].(map[string]interface{})
require.True(t, ok, "MAL item should be a map")
// The key assertion: check for FK-based relation name
recursiveRelation, exists := firstMAL["MAL_RID_PARENTMASTERTASKITEM"]
assert.True(t, exists,
"MAL item should have MAL_RID_PARENTMASTERTASKITEM relation (FK-based name)")
// It can be null or an array, depending on whether this item has children
if recursiveRelation != nil {
_, isArray := recursiveRelation.([]interface{})
assert.True(t, isArray,
"MAL_RID_PARENTMASTERTASKITEM should be an array when not null")
}
})
// Test 5: Verify "Receive COB Document for" appears as a child, not at root
t.Run("ChildItemsAreNested", func(t *testing.T) {
// This test verifies that "Receive COB Document for" doesn't appear
// multiple times at the wrong level, but is properly nested
// Count how many times we find this description at the MAL level (should be 0 or 1)
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, ok := firstMTL["MAL"].([]interface{})
require.True(t, ok, "MAL should be an array")
// Count root-level MAL items (before the fix, there were 12; should be 1)
assert.Len(t, mal, 1,
"MAL should have exactly 1 root-level item (before fix: 12 duplicates)")
// Verify the root item has a description
firstMAL, ok := mal[0].(map[string]interface{})
require.True(t, ok, "MAL item should be a map")
description, exists := firstMAL["description"]
assert.True(t, exists, "MAL item should have a description")
assert.Equal(t, "Capture COB Information", description,
"Root MAL item should be 'Capture COB Information'")
})
// Test 6: Verify DEF relation exists at MAL level
t.Run("DEFRelationExists", func(t *testing.T) {
require.NotEmpty(t, correctResponse, "Response should not be empty")
rootItem := correctResponse[0]
mtl, ok := rootItem["MTL"].([]interface{})
require.True(t, ok, "MTL should be an array")
require.NotEmpty(t, mtl, "MTL should have items")
firstMTL, ok := mtl[0].(map[string]interface{})
require.True(t, ok, "MTL item should be a map")
mal, ok := firstMTL["MAL"].([]interface{})
require.True(t, ok, "MAL should be an array")
require.NotEmpty(t, mal, "MAL should have items")
firstMAL, ok := mal[0].(map[string]interface{})
require.True(t, ok, "MAL item should be a map")
// Verify DEF relation exists (child relation extension)
def, exists := firstMAL["DEF"]
assert.True(t, exists, "MAL item should have DEF relation")
// DEF can be null or an object
if def != nil {
_, isMap := def.(map[string]interface{})
assert.True(t, isMap, "DEF should be an object when not null")
}
})
}

153
pkg/security/KEYSTORE.md Normal file
View File

@@ -0,0 +1,153 @@
# Keystore
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
## Key types
| Constant | Value | Use case |
|---|---|---|
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
## Storage backends
### ConfigKeyStore
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
```go
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
store := security.NewConfigKeyStore([]security.UserKey{
{
UserID: 1,
KeyType: security.KeyTypeGenericAPI,
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
Name: "CI deploy",
Scopes: []string{"deploy"},
IsActive: true,
},
})
```
### DatabaseKeyStore
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
```go
db, _ := sql.Open("postgres", dsn)
store := security.NewDatabaseKeyStore(db)
// With options
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
CacheTTL: 5 * time.Minute,
SQLNames: &security.KeyStoreSQLNames{
ValidateKey: "myapp_keystore_validate", // override one procedure name
},
})
```
## Managing keys
```go
ctx := context.Background()
// Create — raw key returned once; store it securely
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
UserID: 42,
KeyType: security.KeyTypeGenericAPI,
Name: "mobile app",
Scopes: []string{"read", "write"},
})
fmt.Println(resp.RawKey) // only shown here; hashed internally
// List
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
// Revoke
err = store.DeleteKey(ctx, 42, resp.Key.ID)
// Validate (used by authenticators internally)
key, err := store.ValidateKey(ctx, rawKey, "")
```
## HTTP authentication
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
Keys are extracted from the request in this order:
1. `Authorization: Bearer <key>`
2. `Authorization: ApiKey <key>`
3. `X-API-Key: <key>`
```go
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
// Restrict to a specific type:
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
```
Plug it into a handler:
```go
handler := resolvespec.NewHandler(db, registry,
resolvespec.WithAuthenticator(auth),
)
```
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
On successful validation the request context receives a `UserContext` where:
- `UserID` — from the key
- `Roles` — the key's `Scopes`
- `Claims["key_type"]` — key type string
- `Claims["key_name"]` — key name
## Database setup
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
```sql
\i pkg/security/keystore_schema.sql
```
This creates:
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
- `resolvespec_keystore_create_key(p_request jsonb)`
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
### Custom procedure names
```go
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
SQLNames: &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
CreateKey: "myschema_create_key",
DeleteKey: "myschema_delete_key",
ValidateKey: "myschema_validate_key",
},
})
// Validate names at startup
names := &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
// ...
}
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
log.Fatal(err)
}
```
## Security notes
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.

527
pkg/security/OAUTH2.md Normal file
View File

@@ -0,0 +1,527 @@
# OAuth2 Authentication Guide
## Overview
The security package provides OAuth2 authentication support for any OAuth2-compliant provider including Google, GitHub, Microsoft, Facebook, and custom providers.
## Features
- **Universal OAuth2 Support**: Works with any OAuth2 provider
- **Pre-configured Providers**: Google, GitHub, Microsoft, Facebook
- **Multi-Provider Support**: Use all OAuth2 providers simultaneously
- **Custom Providers**: Easy configuration for any OAuth2 service
- **Session Management**: Database-backed session storage
- **Token Refresh**: Automatic token refresh support
- **State Validation**: Built-in CSRF protection
- **User Auto-Creation**: Automatically creates users on first login
- **Unified Authentication**: OAuth2 and traditional auth share same session storage
## Quick Start
### 1. Database Setup
```sql
-- Run the schema from database_schema.sql
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
email VARCHAR(255) NOT NULL UNIQUE,
password VARCHAR(255),
user_level INTEGER DEFAULT 0,
roles VARCHAR(500),
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_login_at TIMESTAMP,
remote_id VARCHAR(255),
auth_provider VARCHAR(50)
);
CREATE TABLE IF NOT EXISTS user_sessions (
id SERIAL PRIMARY KEY,
session_token VARCHAR(500) NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address VARCHAR(45),
user_agent TEXT,
access_token TEXT,
refresh_token TEXT,
token_type VARCHAR(50) DEFAULT 'Bearer',
auth_provider VARCHAR(50)
);
-- OAuth2 stored procedures (7 functions)
-- See database_schema.sql for full implementation
```
### 2. Google OAuth2
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
// Create authenticator
oauth2Auth := security.NewGoogleAuthenticator(
"your-google-client-id",
"your-google-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
// Login route - redirects to Google
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL(state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
// Callback route - handles Google response
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
```
### 3. GitHub OAuth2
```go
oauth2Auth := security.NewGitHubAuthenticator(
"your-github-client-id",
"your-github-client-secret",
"http://localhost:8080/auth/github/callback",
db,
)
// Same routes pattern as Google
router.HandleFunc("/auth/github/login", ...)
router.HandleFunc("/auth/github/callback", ...)
```
### 4. Microsoft OAuth2
```go
oauth2Auth := security.NewMicrosoftAuthenticator(
"your-microsoft-client-id",
"your-microsoft-client-secret",
"http://localhost:8080/auth/microsoft/callback",
db,
)
```
### 5. Facebook OAuth2
```go
oauth2Auth := security.NewFacebookAuthenticator(
"your-facebook-client-id",
"your-facebook-client-secret",
"http://localhost:8080/auth/facebook/callback",
db,
)
```
## Custom OAuth2 Provider
```go
oauth2Auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
ClientID: "your-client-id",
ClientSecret: "your-client-secret",
RedirectURL: "http://localhost:8080/auth/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://your-provider.com/oauth/authorize",
TokenURL: "https://your-provider.com/oauth/token",
UserInfoURL: "https://your-provider.com/oauth/userinfo",
DB: db,
ProviderName: "custom",
// Optional: Custom user info parser
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
return &security.UserContext{
UserName: userInfo["username"].(string),
Email: userInfo["email"].(string),
RemoteID: userInfo["id"].(string),
UserLevel: 1,
Roles: []string{"user"},
Claims: userInfo,
}, nil
},
})
```
## Protected Routes
```go
// Create security provider
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
provider, _ := security.NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
// Apply middleware to protected routes
protectedRouter := router.PathPrefix("/api").Subrouter()
protectedRouter.Use(security.NewAuthMiddleware(securityList))
protectedRouter.Use(security.SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
json.NewEncoder(w).Encode(userCtx)
})
```
## Token Refresh
OAuth2 access tokens expire after a period of time. Use the refresh token to obtain a new access token without requiring the user to log in again.
```go
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google", "github", etc.
}
json.NewDecoder(r.Body).Decode(&req)
// Default to google if not specified
if req.Provider == "" {
req.Provider = "google"
}
// Use OAuth2-specific refresh method
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set new session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
})
json.NewEncoder(w).Encode(loginResp)
})
```
**Important Notes:**
- The refresh token is returned in the `LoginResponse.RefreshToken` field after successful OAuth2 callback
- Store the refresh token securely on the client side
- Each provider must be configured with the appropriate scopes to receive a refresh token (e.g., `access_type=offline` for Google)
- The `OAuth2RefreshToken` method requires the provider name to identify which OAuth2 provider to use for refreshing
## Logout
```go
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
oauth2Auth.Logout(r.Context(), security.LogoutRequest{
Token: userCtx.SessionID,
UserID: userCtx.UserID,
})
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: "",
MaxAge: -1,
})
w.WriteHeader(http.StatusOK)
})
```
## Multi-Provider Setup
```go
// Single DatabaseAuthenticator with ALL OAuth2 providers
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ClientID: "google-client-id",
ClientSecret: "google-client-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
}).
WithOAuth2(security.OAuth2Config{
ClientID: "github-client-id",
ClientSecret: "github-client-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
ProviderName: "github",
})
// Get list of configured providers
providers := auth.OAuth2GetProviders() // ["google", "github"]
// Google routes
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google",
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
// ... handle response
})
// GitHub routes
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github",
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
// ... handle response
})
// Use same authenticator for protected routes - works for ALL providers
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
```
## Configuration Options
### OAuth2Config Fields
| Field | Type | Description |
|-------|------|-------------|
| ClientID | string | OAuth2 client ID from provider |
| ClientSecret | string | OAuth2 client secret |
| RedirectURL | string | Callback URL registered with provider |
| Scopes | []string | OAuth2 scopes to request |
| AuthURL | string | Provider's authorization endpoint |
| TokenURL | string | Provider's token endpoint |
| UserInfoURL | string | Provider's user info endpoint |
| DB | *sql.DB | Database connection for sessions |
| UserInfoParser | func | Custom parser for user info (optional) |
| StateValidator | func | Custom state validator (optional) |
| ProviderName | string | Provider name for logging (optional) |
## User Info Parsing
The default parser extracts these standard fields:
- `sub` → RemoteID
- `email` → Email, UserName
- `name` → UserName
- `login` → UserName (GitHub)
Custom parser example:
```go
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
// Extract custom fields
ctx := &security.UserContext{
UserName: userInfo["preferred_username"].(string),
Email: userInfo["email"].(string),
RemoteID: userInfo["sub"].(string),
UserLevel: 1,
Roles: []string{"user"},
Claims: userInfo, // Store all claims
}
// Add custom roles based on provider data
if groups, ok := userInfo["groups"].([]interface{}); ok {
for _, g := range groups {
ctx.Roles = append(ctx.Roles, g.(string))
}
}
return ctx, nil
}
```
## Security Best Practices
1. **Always use HTTPS in production**
```go
http.SetCookie(w, &http.Cookie{
Secure: true, // Only send over HTTPS
HttpOnly: true, // Prevent XSS access
SameSite: http.SameSiteLaxMode, // CSRF protection
})
```
2. **Store secrets securely**
```go
clientID := os.Getenv("GOOGLE_CLIENT_ID")
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
```
3. **Validate redirect URLs**
- Only register trusted redirect URLs with OAuth2 providers
- Never accept redirect URL from request parameters
5. **Session expiration**
- OAuth2 sessions automatically expire based on token expiry
- Clean up expired sessions periodically:
```sql
DELETE FROM user_sessions WHERE expires_at < NOW();
```
4. **State parameter**
- Automatically generated with cryptographic randomness
- One-time use and expires after 10 minutes
- Prevents CSRF attacks
## Implementation Details
All database operations use stored procedures for consistency and security:
- `resolvespec_oauth_getorcreateuser` - Find or create OAuth2 user
- `resolvespec_oauth_createsession` - Create OAuth2 session
- `resolvespec_oauth_getsession` - Validate and retrieve session
- `resolvespec_oauth_deletesession` - Logout/delete session
- `resolvespec_oauth_getrefreshtoken` - Get session by refresh token
- `resolvespec_oauth_updaterefreshtoken` - Update tokens after refresh
- `resolvespec_oauth_getuser` - Get user data by ID
## Provider Setup Guides
### Google
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create a new project or select existing
3. Enable Google+ API
4. Create OAuth 2.0 credentials
5. Add authorized redirect URI: `http://localhost:8080/auth/google/callback`
6. Copy Client ID and Client Secret
### GitHub
1. Go to [GitHub Developer Settings](https://github.com/settings/developers)
2. Click "New OAuth App"
3. Set Homepage URL: `http://localhost:8080`
4. Set Authorization callback URL: `http://localhost:8080/auth/github/callback`
5. Copy Client ID and Client Secret
### Microsoft
1. Go to [Azure Portal](https://portal.azure.com/)
2. Register new application in Azure AD
3. Add redirect URI: `http://localhost:8080/auth/microsoft/callback`
4. Create client secret
5. Copy Application (client) ID and secret value
### Facebook
1. Go to [Facebook Developers](https://developers.facebook.com/)
2. Create new app
3. Add Facebook Login product
4. Set Valid OAuth Redirect URIs: `http://localhost:8080/auth/facebook/callback`
5. Copy App ID and App Secret
## Troubleshooting
### "redirect_uri_mismatch" error
- Ensure the redirect URL in code matches exactly with provider configuration
- Include protocol (http/https), domain, port, and path
### "invalid_client" error
- Verify Client ID and Client Secret are correct
- Check if credentials are for the correct environment (dev/prod)
### "invalid_grant" error during token exchange
- State parameter validation failed
- Token might have expired
- Check server time synchronization
### User not created after successful OAuth2 login
- Check database constraints (username/email unique)
- Verify UserInfoParser is extracting required fields
- Check database logs for constraint violations
## Testing
```go
func TestOAuth2Flow(t *testing.T) {
// Mock database
db, mock, _ := sqlmock.New()
oauth2Auth := security.NewGoogleAuthenticator(
"test-client-id",
"test-client-secret",
"http://localhost/callback",
db,
)
// Test state generation
state, err := oauth2Auth.GenerateState()
assert.NoError(t, err)
assert.NotEmpty(t, state)
// Test auth URL generation
authURL := oauth2Auth.GetAuthURL(state)
assert.Contains(t, authURL, "accounts.google.com")
assert.Contains(t, authURL, state)
}
```
## API Reference
### DatabaseAuthenticator with OAuth2
| Method | Description |
|--------|-------------|
| WithOAuth2(cfg) | Adds OAuth2 provider (can be called multiple times, returns *DatabaseAuthenticator) |
| OAuth2GetAuthURL(provider, state) | Returns OAuth2 authorization URL for specified provider |
| OAuth2GenerateState() | Generates random state for CSRF protection |
| OAuth2HandleCallback(ctx, provider, code, state) | Exchanges code for token and creates session |
| OAuth2RefreshToken(ctx, refreshToken, provider) | Refreshes expired access token using refresh token |
| OAuth2GetProviders() | Returns list of configured OAuth2 provider names |
| Login(ctx, req) | Standard username/password login |
| Logout(ctx, req) | Invalidates session (works for both OAuth2 and regular sessions) |
| Authenticate(r) | Validates session token from request (works for both OAuth2 and regular sessions) |
### Pre-configured Constructors
- `NewGoogleAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewGitHubAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewFacebookAuthenticator(clientID, secret, redirectURL, db)` - Single provider
- `NewMultiProviderAuthenticator(db, configs)` - Multiple providers at once
All return `*DatabaseAuthenticator` with OAuth2 pre-configured.
For multiple providers, use `WithOAuth2()` multiple times or `NewMultiProviderAuthenticator()`.
## Examples
Complete working examples available in `oauth2_examples.go`:
- Basic Google OAuth2
- GitHub OAuth2
- Custom provider
- Multi-provider setup
- Token refresh
- Logout flow
- Complete integration with security middleware

View File

@@ -0,0 +1,281 @@
# OAuth2 Refresh Token - Quick Reference
## Quick Setup (3 Steps)
### 1. Initialize Authenticator
```go
auth := security.NewGoogleAuthenticator(
"client-id",
"client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
```
### 2. OAuth2 Login Flow
```go
// Login - Redirect to Google
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
// Callback - Store tokens
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, _ := auth.OAuth2HandleCallback(
r.Context(),
"google",
r.URL.Query().Get("code"),
r.URL.Query().Get("state"),
)
// Save refresh_token on client
// loginResp.RefreshToken - Store this securely!
// loginResp.Token - Session token for API calls
})
```
### 3. Refresh Endpoint
```go
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
json.NewDecoder(r.Body).Decode(&req)
// Refresh token
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
if err != nil {
http.Error(w, err.Error(), 401)
return
}
json.NewEncoder(w).Encode(loginResp)
})
```
---
## Multi-Provider Example
```go
// Configure multiple providers
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ProviderName: "google",
ClientID: "google-client-id",
ClientSecret: "google-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
}).
WithOAuth2(security.OAuth2Config{
ProviderName: "github",
ClientID: "github-client-id",
ClientSecret: "github-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
})
// Refresh with provider selection
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google" or "github"
}
json.NewDecoder(r.Body).Decode(&req)
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), 401)
return
}
json.NewEncoder(w).Encode(loginResp)
})
```
---
## Client-Side JavaScript
```javascript
// Automatic token refresh on 401
async function apiCall(url) {
let response = await fetch(url, {
headers: {
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
}
});
// Token expired - refresh it
if (response.status === 401) {
await refreshToken();
// Retry request with new token
response = await fetch(url, {
headers: {
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
}
});
}
return response.json();
}
async function refreshToken() {
const response = await fetch('/auth/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
refresh_token: localStorage.getItem('refresh_token'),
provider: localStorage.getItem('provider')
})
});
if (response.ok) {
const data = await response.json();
localStorage.setItem('access_token', data.token);
localStorage.setItem('refresh_token', data.refresh_token);
} else {
// Refresh failed - redirect to login
window.location.href = '/login';
}
}
```
---
## API Methods
| Method | Parameters | Returns |
|--------|-----------|---------|
| `OAuth2RefreshToken` | `ctx, refreshToken, provider` | `*LoginResponse, error` |
| `OAuth2HandleCallback` | `ctx, provider, code, state` | `*LoginResponse, error` |
| `OAuth2GetAuthURL` | `provider, state` | `string, error` |
| `OAuth2GenerateState` | none | `string, error` |
| `OAuth2GetProviders` | none | `[]string` |
---
## LoginResponse Structure
```go
type LoginResponse struct {
Token string // New session token for API calls
RefreshToken string // Refresh token (store securely)
User *UserContext // User information
ExpiresIn int64 // Seconds until token expires
}
```
---
## Database Stored Procedures
- `resolvespec_oauth_getrefreshtoken(refresh_token)` - Get session by refresh token
- `resolvespec_oauth_updaterefreshtoken(update_data)` - Update tokens after refresh
- `resolvespec_oauth_getuser(user_id)` - Get user data
All procedures return: `{p_success bool, p_error text, p_data jsonb}`
---
## Common Errors
| Error | Cause | Solution |
|-------|-------|----------|
| `invalid or expired refresh token` | Token revoked/expired | Re-authenticate user |
| `OAuth2 provider 'xxx' not found` | Provider not configured | Add with `WithOAuth2()` |
| `failed to refresh token with provider` | Provider rejected request | Check credentials, re-auth user |
---
## Security Checklist
- [ ] Use HTTPS for all OAuth2 endpoints
- [ ] Store refresh tokens securely (HttpOnly cookies or encrypted storage)
- [ ] Set cookie flags: `HttpOnly`, `Secure`, `SameSite=Strict`
- [ ] Implement rate limiting on refresh endpoint
- [ ] Log refresh attempts for audit
- [ ] Rotate tokens on refresh
- [ ] Revoke old sessions after successful refresh
---
## Testing
```bash
# 1. Login and get refresh token
curl http://localhost:8080/auth/google/login
# Follow OAuth2 flow, get refresh_token from callback response
# 2. Refresh token
curl -X POST http://localhost:8080/auth/refresh \
-H "Content-Type: application/json" \
-d '{"refresh_token":"ya29.xxx","provider":"google"}'
# 3. Use new token
curl http://localhost:8080/api/protected \
-H "Authorization: Bearer sess_abc123..."
```
---
## Pre-configured Providers
```go
// Google
auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
// GitHub
auth := security.NewGitHubAuthenticator(clientID, secret, redirectURL, db)
// Microsoft
auth := security.NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)
// Facebook
auth := security.NewFacebookAuthenticator(clientID, secret, redirectURL, db)
// All providers at once
auth := security.NewMultiProviderAuthenticator(db, map[string]security.OAuth2Config{
"google": {...},
"github": {...},
})
```
---
## Provider-Specific Notes
### Google
- Add `access_type=offline` to get refresh token
- Add `prompt=consent` to force consent screen
```go
authURL += "&access_type=offline&prompt=consent"
```
### GitHub
- Refresh tokens not always provided
- May need to request `offline_access` scope
### Microsoft
- Use `offline_access` scope for refresh token
### Facebook
- Tokens expire after 60 days by default
- Check app settings for token expiration policy
---
## Complete Example
See `/pkg/security/oauth2_examples.go` line 250 for full working example.
For detailed documentation see `/pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md`.

View File

@@ -0,0 +1,495 @@
# OAuth2 Refresh Token Implementation
## Overview
OAuth2 refresh token functionality is **fully implemented** in the ResolveSpec security package. This allows refreshing expired access tokens without requiring users to re-authenticate.
## Implementation Status: ✅ COMPLETE
### Components Implemented
1. **✅ Database Schema** - Tables and stored procedures
2. **✅ Go Methods** - OAuth2RefreshToken implementation
3. **✅ Thread Safety** - Mutex protection for provider map
4. **✅ Examples** - Working code examples
5. **✅ Documentation** - Complete API reference
---
## 1. Database Schema
### Tables Modified
```sql
-- user_sessions table with OAuth2 token fields
CREATE TABLE IF NOT EXISTS user_sessions (
id SERIAL PRIMARY KEY,
session_token VARCHAR(500) NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address VARCHAR(45),
user_agent TEXT,
access_token TEXT, -- OAuth2 access token
refresh_token TEXT, -- OAuth2 refresh token
token_type VARCHAR(50), -- "Bearer", etc.
auth_provider VARCHAR(50) -- "google", "github", etc.
);
```
### Stored Procedures
**`resolvespec_oauth_getrefreshtoken(p_refresh_token)`**
- Gets OAuth2 session data by refresh token
- Returns: `{user_id, access_token, token_type, expiry}`
- Location: `database_schema.sql:714`
**`resolvespec_oauth_updaterefreshtoken(p_update_data)`**
- Updates session with new tokens after refresh
- Input: `{user_id, old_refresh_token, new_session_token, new_access_token, new_refresh_token, expires_at}`
- Location: `database_schema.sql:752`
**`resolvespec_oauth_getuser(p_user_id)`**
- Gets user data by ID for building UserContext
- Location: `database_schema.sql:791`
---
## 2. Go Implementation
### Method Signature
```go
func (a *DatabaseAuthenticator) OAuth2RefreshToken(
ctx context.Context,
refreshToken string,
providerName string,
) (*LoginResponse, error)
```
**Location:** `pkg/security/oauth2_methods.go:375`
### Implementation Flow
```
1. Validate provider exists
├─ getOAuth2Provider(providerName) with RLock
└─ Return error if provider not configured
2. Get session from database
├─ Call resolvespec_oauth_getrefreshtoken(refreshToken)
└─ Parse session data {user_id, access_token, token_type, expiry}
3. Refresh token with OAuth2 provider
├─ Create oauth2.Token from stored data
├─ Use provider.config.TokenSource(ctx, oldToken)
└─ Call tokenSource.Token() to get new token
4. Generate new session token
└─ Use OAuth2GenerateState() for secure random token
5. Update database
├─ Call resolvespec_oauth_updaterefreshtoken()
└─ Store new session_token, access_token, refresh_token
6. Get user data
├─ Call resolvespec_oauth_getuser(user_id)
└─ Build UserContext
7. Return LoginResponse
└─ {Token, RefreshToken, User, ExpiresIn}
```
### Thread Safety
**Mutex Protection:** All access to `oauth2Providers` map is protected with `sync.RWMutex`
```go
type DatabaseAuthenticator struct {
oauth2Providers map[string]*OAuth2Provider
oauth2ProvidersMutex sync.RWMutex // Thread-safe access
}
// Read operations use RLock
func (a *DatabaseAuthenticator) getOAuth2Provider(name string) {
a.oauth2ProvidersMutex.RLock()
defer a.oauth2ProvidersMutex.RUnlock()
// ... access map
}
// Write operations use Lock
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) {
a.oauth2ProvidersMutex.Lock()
defer a.oauth2ProvidersMutex.Unlock()
// ... modify map
}
```
---
## 3. Usage Examples
### Single Provider (Google)
```go
package main
import (
"database/sql"
"encoding/json"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
"github.com/gorilla/mux"
)
func main() {
db, _ := sql.Open("postgres", "connection-string")
// Create Google OAuth2 authenticator
auth := security.NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
// Token refresh endpoint
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
}
json.NewDecoder(r.Body).Decode(&req)
// Refresh token (provider name defaults to "google")
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set new session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
})
json.NewEncoder(w).Encode(loginResp)
})
http.ListenAndServe(":8080", router)
}
```
### Multi-Provider Setup
```go
// Single authenticator with multiple OAuth2 providers
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ClientID: "google-client-id",
ClientSecret: "google-client-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
}).
WithOAuth2(security.OAuth2Config{
ClientID: "github-client-id",
ClientSecret: "github-client-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
ProviderName: "github",
})
// Refresh endpoint with provider selection
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google" or "github"
}
json.NewDecoder(r.Body).Decode(&req)
// Refresh with specific provider
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
json.NewEncoder(w).Encode(loginResp)
})
```
### Client-Side Usage
```javascript
// JavaScript client example
async function refreshAccessToken() {
const refreshToken = localStorage.getItem('refresh_token');
const provider = localStorage.getItem('auth_provider'); // "google", "github", etc.
const response = await fetch('/auth/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
refresh_token: refreshToken,
provider: provider
})
});
if (response.ok) {
const data = await response.json();
// Store new tokens
localStorage.setItem('access_token', data.token);
localStorage.setItem('refresh_token', data.refresh_token);
console.log('Token refreshed successfully');
return data.token;
} else {
// Refresh failed - redirect to login
window.location.href = '/login';
}
}
// Automatically refresh token when API returns 401
async function apiCall(endpoint) {
let response = await fetch(endpoint, {
headers: {
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
}
});
if (response.status === 401) {
// Token expired - try refresh
const newToken = await refreshAccessToken();
// Retry with new token
response = await fetch(endpoint, {
headers: {
'Authorization': 'Bearer ' + newToken
}
});
}
return response.json();
}
```
---
## 4. API Reference
### DatabaseAuthenticator Methods
| Method | Signature | Description |
|--------|-----------|-------------|
| `OAuth2RefreshToken` | `(ctx, refreshToken, provider) (*LoginResponse, error)` | Refreshes expired OAuth2 access token |
| `WithOAuth2` | `(cfg OAuth2Config) *DatabaseAuthenticator` | Adds OAuth2 provider (chainable) |
| `OAuth2GetAuthURL` | `(provider, state) (string, error)` | Gets authorization URL |
| `OAuth2HandleCallback` | `(ctx, provider, code, state) (*LoginResponse, error)` | Handles OAuth2 callback |
| `OAuth2GenerateState` | `() (string, error)` | Generates CSRF state token |
| `OAuth2GetProviders` | `() []string` | Lists configured providers |
### LoginResponse Structure
```go
type LoginResponse struct {
Token string // New session token
RefreshToken string // New refresh token (may be same as input)
User *UserContext // User information
ExpiresIn int64 // Seconds until expiration
}
type UserContext struct {
UserID int // Database user ID
UserName string // Username
Email string // Email address
UserLevel int // Permission level
SessionID string // Session token
RemoteID string // OAuth2 provider user ID
Roles []string // User roles
Claims map[string]any // Additional claims
}
```
---
## 5. Important Notes
### Provider Configuration
**For Google:** Add `access_type=offline` to get refresh token on first login:
```go
auth := security.NewGoogleAuthenticator(clientID, clientSecret, redirectURL, db)
// When generating auth URL, add access_type parameter
authURL, _ := auth.OAuth2GetAuthURL("google", state)
authURL += "&access_type=offline&prompt=consent"
```
**For GitHub:** Refresh tokens are not always provided. Check provider documentation.
### Token Storage
- Store refresh tokens securely on client (localStorage, secure cookie, etc.)
- Never log refresh tokens
- Refresh tokens are long-lived (days/months depending on provider)
- Access tokens are short-lived (minutes/hours)
### Error Handling
Common errors:
- `"invalid or expired refresh token"` - Token expired or revoked
- `"OAuth2 provider 'xxx' not found"` - Provider not configured
- `"failed to refresh token with provider"` - Provider rejected refresh request
### Security Best Practices
1. **Always use HTTPS** for token transmission
2. **Store refresh tokens securely** on client
3. **Set appropriate cookie flags**: `HttpOnly`, `Secure`, `SameSite`
4. **Implement token rotation** - issue new refresh token on each refresh
5. **Revoke old tokens** after successful refresh
6. **Rate limit** refresh endpoints
7. **Log refresh attempts** for audit trail
---
## 6. Testing
### Manual Test Flow
1. **Initial Login:**
```bash
curl http://localhost:8080/auth/google/login
# Follow redirect to Google
# Returns to callback with LoginResponse containing refresh_token
```
2. **Wait for Token Expiry (or manually expire in DB)**
3. **Refresh Token:**
```bash
curl -X POST http://localhost:8080/auth/refresh \
-H "Content-Type: application/json" \
-d '{
"refresh_token": "ya29.a0AfH6SMB...",
"provider": "google"
}'
# Response:
{
"token": "sess_abc123...",
"refresh_token": "ya29.a0AfH6SMB...",
"user": {
"user_id": 1,
"user_name": "john_doe",
"email": "john@example.com",
"session_id": "sess_abc123..."
},
"expires_in": 3600
}
```
4. **Use New Token:**
```bash
curl http://localhost:8080/api/protected \
-H "Authorization: Bearer sess_abc123..."
```
### Database Verification
```sql
-- Check session with refresh token
SELECT session_token, user_id, expires_at, refresh_token, auth_provider
FROM user_sessions
WHERE refresh_token = 'ya29.a0AfH6SMB...';
-- Verify token was updated after refresh
SELECT session_token, access_token, refresh_token,
expires_at, last_activity_at
FROM user_sessions
WHERE user_id = 1
ORDER BY created_at DESC
LIMIT 1;
```
---
## 7. Troubleshooting
### "Refresh token not found or expired"
**Cause:** Refresh token doesn't exist in database or session expired
**Solution:**
- Check if initial OAuth2 login stored refresh token
- Verify provider returns refresh token (some require `access_type=offline`)
- Check session hasn't been deleted from database
### "Failed to refresh token with provider"
**Cause:** OAuth2 provider rejected the refresh request
**Possible reasons:**
- Refresh token was revoked by user
- OAuth2 app credentials changed
- Network connectivity issues
- Provider rate limiting
**Solution:**
- Re-authenticate user (full OAuth2 flow)
- Check provider dashboard for app status
- Verify client credentials are correct
### "OAuth2 provider 'xxx' not found"
**Cause:** Provider not registered with `WithOAuth2()`
**Solution:**
```go
// Make sure provider is configured
auth := security.NewDatabaseAuthenticator(db).
WithOAuth2(security.OAuth2Config{
ProviderName: "google", // This name must match refresh call
// ... other config
})
// Then use same name in refresh
auth.OAuth2RefreshToken(ctx, token, "google") // Must match ProviderName
```
---
## 8. Complete Working Example
See `pkg/security/oauth2_examples.go:250` for full working example with token refresh.
---
## Summary
OAuth2 refresh token functionality is **production-ready** with:
- ✅ Complete database schema with stored procedures
- ✅ Thread-safe Go implementation with mutex protection
- ✅ Multi-provider support (Google, GitHub, Microsoft, Facebook, custom)
- ✅ Comprehensive error handling
- ✅ Working code examples
- ✅ Full API documentation
- ✅ Security best practices implemented
**No additional implementation needed - feature is complete and functional.**

View File

@@ -0,0 +1,208 @@
# Passkey Authentication Quick Reference
## Overview
Passkey authentication (WebAuthn/FIDO2) is now integrated into the DatabaseAuthenticator. This provides passwordless authentication using biometrics, security keys, or device credentials.
## Setup
### Database Schema
Run the passkey SQL schema (in database_schema.sql):
- Creates `user_passkey_credentials` table
- Adds stored procedures for passkey operations
### Go Code
```go
// Create passkey provider
passkeyProvider := security.NewDatabasePasskeyProvider(db,
security.DatabasePasskeyProviderOptions{
RPID: "example.com",
RPName: "Example App",
RPOrigin: "https://example.com",
Timeout: 60000,
})
// Create authenticator with passkey support
auth := security.NewDatabaseAuthenticatorWithOptions(db,
security.DatabaseAuthenticatorOptions{
PasskeyProvider: passkeyProvider,
})
// Or add passkey to existing authenticator
auth = security.NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
```
## Registration Flow
### Backend - Step 1: Begin Registration
```go
options, err := auth.BeginPasskeyRegistration(ctx,
security.PasskeyBeginRegistrationRequest{
UserID: 1,
Username: "alice",
DisplayName: "Alice Smith",
})
// Send options to client as JSON
```
### Frontend - Step 2: Create Credential
```javascript
// Convert options from server
options.challenge = base64ToArrayBuffer(options.challenge);
options.user.id = base64ToArrayBuffer(options.user.id);
// Create credential
const credential = await navigator.credentials.create({
publicKey: options
});
// Send credential back to server
```
### Backend - Step 3: Complete Registration
```go
credential, err := auth.CompletePasskeyRegistration(ctx,
security.PasskeyRegisterRequest{
UserID: 1,
Response: clientResponse,
ExpectedChallenge: storedChallenge,
CredentialName: "My iPhone",
})
```
## Authentication Flow
### Backend - Step 1: Begin Authentication
```go
options, err := auth.BeginPasskeyAuthentication(ctx,
security.PasskeyBeginAuthenticationRequest{
Username: "alice", // Optional for resident key
})
// Send options to client as JSON
```
### Frontend - Step 2: Get Credential
```javascript
// Convert options from server
options.challenge = base64ToArrayBuffer(options.challenge);
// Get credential
const credential = await navigator.credentials.get({
publicKey: options
});
// Send assertion back to server
```
### Backend - Step 3: Complete Authentication
```go
loginResponse, err := auth.LoginWithPasskey(ctx,
security.PasskeyLoginRequest{
Response: clientAssertion,
ExpectedChallenge: storedChallenge,
Claims: map[string]any{
"ip_address": "192.168.1.1",
"user_agent": "Mozilla/5.0...",
},
})
// Returns session token and user info
```
## Credential Management
### List Credentials
```go
credentials, err := auth.GetPasskeyCredentials(ctx, userID)
```
### Update Credential Name
```go
err := auth.UpdatePasskeyCredentialName(ctx, userID, credentialID, "New Name")
```
### Delete Credential
```go
err := auth.DeletePasskeyCredential(ctx, userID, credentialID)
```
## HTTP Endpoints Example
### POST /api/passkey/register/begin
Request: `{user_id, username, display_name}`
Response: PasskeyRegistrationOptions
### POST /api/passkey/register/complete
Request: `{user_id, response, credential_name}`
Response: PasskeyCredential
### POST /api/passkey/login/begin
Request: `{username}` (optional)
Response: PasskeyAuthenticationOptions
### POST /api/passkey/login/complete
Request: `{response}`
Response: LoginResponse with session token
### GET /api/passkey/credentials
Response: Array of PasskeyCredential
### DELETE /api/passkey/credentials/{id}
Request: `{credential_id}`
Response: 204 No Content
## Database Stored Procedures
- `resolvespec_passkey_store_credential` - Store new credential
- `resolvespec_passkey_get_credential` - Get credential by ID
- `resolvespec_passkey_get_user_credentials` - Get all user credentials
- `resolvespec_passkey_update_counter` - Update sign counter (clone detection)
- `resolvespec_passkey_delete_credential` - Delete credential
- `resolvespec_passkey_update_name` - Update credential name
- `resolvespec_passkey_get_credentials_by_username` - Get credentials for login
## Security Features
- **Clone Detection**: Sign counter validation detects credential cloning
- **Attestation Support**: Stores attestation type (none, indirect, direct)
- **Transport Options**: Tracks authenticator transports (usb, nfc, ble, internal)
- **Backup State**: Tracks if credential is backed up/synced
- **User Verification**: Supports preferred/required user verification
## Important Notes
1. **WebAuthn Library**: Current implementation is simplified. For production, use a proper WebAuthn library like `github.com/go-webauthn/webauthn` for full verification.
2. **Challenge Storage**: Store challenges securely in session/cache. Never expose challenges to client beyond initial request.
3. **HTTPS Required**: Passkeys only work over HTTPS (except localhost).
4. **Browser Support**: Check browser compatibility for WebAuthn API.
5. **Relying Party ID**: Must match your domain exactly.
## Client-Side Helper Functions
```javascript
function base64ToArrayBuffer(base64) {
const binary = atob(base64);
const bytes = new Uint8Array(binary.length);
for (let i = 0; i < binary.length; i++) {
bytes[i] = binary.charCodeAt(i);
}
return bytes.buffer;
}
function arrayBufferToBase64(buffer) {
const bytes = new Uint8Array(buffer);
let binary = '';
for (let i = 0; i < bytes.length; i++) {
binary += String.fromCharCode(bytes[i]);
}
return btoa(binary);
}
```
## Testing
Run tests: `go test -v ./pkg/security -run Passkey`
All passkey functionality includes comprehensive tests using sqlmock.

View File

@@ -7,15 +7,16 @@
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
// OR: auth := security.NewJWTAuthenticator("secret-key", db)
// OR: auth := security.NewHeaderAuthenticator()
// OR: auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db) // OAuth2
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
// Step 2: Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
// Step 3: Setup and apply middleware
securityList := security.SetupSecurityProvider(handler, provider)
securityList, _ := security.SetupSecurityProvider(handler, provider)
router.Use(security.NewAuthMiddleware(securityList))
router.Use(security.SetSecurityMiddleware(securityList))
```
@@ -30,6 +31,7 @@ router.Use(security.SetSecurityMiddleware(securityList))
```go
// DatabaseAuthenticator uses these stored procedures:
resolvespec_login(jsonb) // Login with credentials
resolvespec_register(jsonb) // Register new user
resolvespec_logout(jsonb) // Invalidate session
resolvespec_session(text, text) // Validate session token
resolvespec_session_update(text, jsonb) // Update activity timestamp
@@ -256,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
}
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
// Add to blacklist
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
"token": req.Token,
"user_id": req.UserID,
}).Error
// Invalidate session via stored procedure
return nil
}
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
@@ -403,11 +402,16 @@ assert.Equal(t, "user_id = {UserID}", row.Template)
```
HTTP Request
NewAuthMiddleware → calls provider.Authenticate()
↓ (adds UserContext to context)
NewOptionalAuthMiddleware → calls provider.Authenticate()
↓ (adds UserContext or guest context; never 401)
SetSecurityMiddleware → adds SecurityList to context
Handler.Handle()
Handler.Handle() → resolves model
BeforeHandle Hook → CheckModelAuthAllowed(secCtx, operation)
├─ SecurityDisabled → allow
├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
└─ UserID == 0 → abort 401
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
@@ -502,10 +506,31 @@ func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema,
---
## Login/Logout Endpoints
## Login/Logout/Register Endpoints
```go
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
// Register
router.HandleFunc("/auth/register", func(w http.ResponseWriter, r *http.Request) {
var req security.RegisterRequest
json.NewDecoder(r.Body).Decode(&req)
// Check if provider supports registration
registrable, ok := securityList.Provider().(security.Registrable)
if !ok {
http.Error(w, "Registration not supported", http.StatusNotImplemented)
return
}
resp, err := registrable.Register(r.Context(), req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
json.NewEncoder(w).Encode(resp)
}).Methods("POST")
// Login
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
var req security.LoginRequest
@@ -670,15 +695,30 @@ http.Handle("/api/protected", authHandler)
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
http.Handle("/home", optionalHandler)
// Example handler
func myHandler(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
if userCtx.UserID == 0 {
// Guest user
} else {
// Authenticated user
}
}
// NewOptionalAuthMiddleware - For spec routes; auth enforcement deferred to BeforeHandle
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
apiRouter.Use(security.SetSecurityMiddleware(securityList))
restheadspec.RegisterSecurityHooks(handler, securityList) // includes BeforeHandle
```
---
## Model-Level Access Control
```go
// Register model with rules (pkg/modelregistry)
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
SecurityDisabled: false, // skip all auth when true
CanPublicRead: true, // unauthenticated reads allowed
CanPublicCreate: false, // requires auth
CanPublicUpdate: false, // requires auth
CanPublicDelete: false, // requires auth
CanUpdate: true, // authenticated can update
CanDelete: false, // authenticated cannot delete (enforced in BeforeDelete)
})
// CheckModelAuthAllowed used automatically in BeforeHandle hook
// No code needed — call RegisterSecurityHooks and it's applied
```
---
@@ -707,6 +747,7 @@ meta, ok := security.GetUserMeta(ctx)
| File | Description |
|------|-------------|
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
| `OAUTH2.md` | **OAuth2 Guide** - Google, GitHub, Microsoft, Facebook, custom providers |
| `examples.go` | Working provider implementations to copy |
| `setup_example.go` | 6 complete integration examples |
| `README.md` | Architecture overview and migration guide |

View File

@@ -6,6 +6,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
-**Interface-Based** - Type-safe providers instead of callbacks
-**Login/Logout Support** - Built-in authentication lifecycle
-**Two-Factor Authentication (2FA)** - Optional TOTP support for enhanced security
-**Composable** - Mix and match different providers
-**No Global State** - Each handler has its own security configuration
-**Testable** - Easy to mock and test
@@ -212,6 +213,23 @@ auth := security.NewJWTAuthenticator("secret-key", db)
// Note: Requires JWT library installation for token signing/verification
```
**TwoFactorAuthenticator** - Wraps any authenticator with TOTP 2FA:
```go
baseAuth := security.NewDatabaseAuthenticator(db)
// Use in-memory provider (for testing)
tfaProvider := security.NewMemoryTwoFactorProvider(nil)
// Or use database provider (for production)
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
// Requires: users table with totp fields, user_totp_backup_codes table
// Requires: resolvespec_totp_* stored procedures (see totp_database_schema.sql)
auth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
// Supports: TOTP codes, backup codes, QR code generation
// Compatible with Google Authenticator, Microsoft Authenticator, Authy, etc.
```
### Column Security Providers
**DatabaseColumnSecurityProvider** - Loads rules from database:
@@ -334,7 +352,182 @@ func handleRefresh(securityList *security.SecurityList) http.HandlerFunc {
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
}
```
## Two-Factor Authentication (2FA)
### Overview
- **Optional per-user** - Enable/disable 2FA individually
- **TOTP standard** - Compatible with Google Authenticator, Microsoft Authenticator, Authy, 1Password, etc.
- **Configurable** - SHA1/SHA256/SHA512, 6/8 digits, custom time periods
- **Backup codes** - One-time recovery codes with secure hashing
- **Clock skew** - Handles time differences between client/server
### Setup
```go
// 1. Wrap existing authenticator with 2FA support
baseAuth := security.NewDatabaseAuthenticator(db)
tfaProvider := security.NewMemoryTwoFactorProvider(nil) // Use custom DB implementation in production
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
// 2. Use as normal authenticator
provider := security.NewCompositeSecurityProvider(tfaAuth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
```
### Enable 2FA for User
```go
// 1. Initiate 2FA setup
secret, err := tfaAuth.Setup2FA(userID, "MyApp", "user@example.com")
// Returns: secret.Secret, secret.QRCodeURL, secret.BackupCodes
// 2. User scans QR code with authenticator app
// Display secret.QRCodeURL as QR code image
// 3. User enters verification code from app
code := "123456" // From authenticator app
err = tfaAuth.Enable2FA(userID, secret.Secret, code)
// 2FA is now enabled for this user
// 4. Store backup codes securely and show to user once
// Display: secret.BackupCodes (10 codes)
```
### Login Flow with 2FA
```go
// 1. User provides credentials
req := security.LoginRequest{
Username: "user@example.com",
Password: "password",
}
resp, err := tfaAuth.Login(ctx, req)
// 2. Check if 2FA required
if resp.Requires2FA {
// Prompt user for 2FA code
code := getUserInput() // From authenticator app or backup code
// 3. Login again with 2FA code
req.TwoFactorCode = code
resp, err = tfaAuth.Login(ctx, req)
// 4. Success - token is returned
token := resp.Token
}
```
### Manage 2FA
```go
// Disable 2FA
err := tfaAuth.Disable2FA(userID)
// Regenerate backup codes
newCodes, err := tfaAuth.RegenerateBackupCodes(userID, 10)
// Check status
has2FA, err := tfaProvider.Get2FAStatus(userID)
```
### Custom 2FA Storage
**Option 1: Use DatabaseTwoFactorProvider (Recommended)**
```go
// Uses PostgreSQL stored procedures for all operations
db := setupDatabase()
// Run migrations from totp_database_schema.sql
// - Add totp_secret, totp_enabled, totp_enabled_at to users table
// - Create user_totp_backup_codes table
// - Create resolvespec_totp_* stored procedures
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
```
**Option 2: Implement Custom Provider**
Implement `TwoFactorAuthProvider` for custom storage:
```go
type DBTwoFactorProvider struct {
db *gorm.DB
}
func (p *DBTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
// Store secret and hashed backup codes in database
return p.db.Exec("UPDATE users SET totp_secret = ?, backup_codes = ? WHERE id = ?",
secret, hashCodes(backupCodes), userID).Error
}
func (p *DBTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var secret string
err := p.db.Raw("SELECT totp_secret FROM users WHERE id = ?", userID).Scan(&secret).Error
return secret, err
}
// Implement remaining methods: Generate2FASecret, Validate2FACode, Disable2FA,
// Get2FAStatus, GenerateBackupCodes, ValidateBackupCode
```
### Configuration
```go
config := &security.TwoFactorConfig{
Algorithm: "SHA256", // SHA1, SHA256, SHA512
Digits: 8, // 6 or 8
Period: 30, // Seconds per code
SkewWindow: 2, // Accept codes ±2 periods
}
totp := security.NewTOTPGenerator(config)
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, config)
```
### API Response Structure
```go
// LoginResponse with 2FA
type LoginResponse struct {
Token string `json:"token"`
Requires2FA bool `json:"requires_2fa"`
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"`
User *UserContext `json:"user"`
}
// TwoFactorSecret for setup
type TwoFactorSecret struct {
Secret string `json:"secret"` // Base32 encoded
QRCodeURL string `json:"qr_code_url"` // otpauth://totp/...
BackupCodes []string `json:"backup_codes"` // 10 recovery codes
}
// UserContext includes 2FA status
type UserContext struct {
UserID int `json:"user_id"`
TwoFactorEnabled bool `json:"two_factor_enabled"`
// ... other fields
}
```
### Security Best Practices
- **Store secrets encrypted** - Never store TOTP secrets in plain text
- **Hash backup codes** - Use SHA-256 before storing
- **Rate limit** - Limit 2FA verification attempts
- **Require password** - Always verify password before disabling 2FA
- **Show backup codes once** - Display only during setup/regeneration
- **Log 2FA events** - Track enable/disable/failed attempts
- **Mark codes as used** - Backup codes are single-use only
json.NewEncoder(w).Encode(resp)
} else {
http.Error(w, "Refresh not supported", http.StatusNotImplemented)
@@ -558,14 +751,25 @@ resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
```
HTTP Request
NewAuthMiddleware (security package)
NewOptionalAuthMiddleware (security package) ← recommended for spec routes
├─ Calls provider.Authenticate(request)
Adds UserContext to context
On success: adds authenticated UserContext to context
└─ On failure: adds guest UserContext (UserID=0) to context
SetSecurityMiddleware (security package)
└─ Adds SecurityList to context
Spec Handler (restheadspec/funcspec/resolvespec)
Spec Handler (restheadspec/funcspec/resolvespec/websocketspec/mqttspec)
└─ Resolves schema + entity + model from request
BeforeHandle Hook (registered by spec via RegisterSecurityHooks)
├─ Adapts spec's HookContext → SecurityContext
├─ Calls security.CheckModelAuthAllowed(secCtx, operation)
│ ├─ Loads model rules from context or registry
│ ├─ SecurityDisabled → allow
│ ├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
│ └─ UserID == 0 → 401 unauthorized
└─ On error: aborts with 401
BeforeRead Hook (registered by spec)
├─ Adapts spec's HookContext → SecurityContext
@@ -591,7 +795,8 @@ HTTP Response (secured data)
```
**Key Points:**
- Security package is spec-agnostic and provides core logic
- `NewOptionalAuthMiddleware` never rejects — it sets guest context on auth failure; `BeforeHandle` enforces auth after model resolution
- `BeforeHandle` fires after model resolution, giving access to model rules and user context simultaneously
- Each spec registers its own hooks that adapt to SecurityContext
- Security rules are loaded once and cached for the request
- Row security is applied to the query (database level)
@@ -809,15 +1014,49 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
}
```
## Model-Level Access Control
Use `ModelRules` (from `pkg/modelregistry`) to control per-entity auth behavior:
```go
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
SecurityDisabled: false, // true = skip all auth checks
CanPublicRead: true, // unauthenticated GET allowed
CanPublicCreate: false, // requires auth
CanPublicUpdate: false, // requires auth
CanPublicDelete: false, // requires auth
CanUpdate: true, // authenticated users can update
CanDelete: false, // authenticated users cannot delete
})
```
`CheckModelAuthAllowed(secCtx, operation)` applies these rules in `BeforeHandle`:
1. `SecurityDisabled` → allow all
2. `CanPublicRead/Create/Update/Delete` → allow unauthenticated for that operation
3. Guest (UserID == 0) → return 401
4. Authenticated → allow (operation-specific `CanUpdate`/`CanDelete` checked in `BeforeUpdate`/`BeforeDelete`)
---
## Middleware and Handler API
### NewAuthMiddleware
Standard middleware that authenticates all requests:
Standard middleware that authenticates all requests and returns 401 on failure:
```go
router.Use(security.NewAuthMiddleware(securityList))
```
### NewOptionalAuthMiddleware
Middleware for spec routes — always continues; sets guest context on auth failure:
```go
// Use with RegisterSecurityHooks — auth enforcement is deferred to BeforeHandle
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
apiRouter.Use(security.SetSecurityMiddleware(securityList))
restheadspec.RegisterSecurityHooks(handler, securityList) // registers BeforeHandle
```
Routes can skip authentication using the `SkipAuth` helper:
```go

File diff suppressed because it is too large Load Diff

View File

@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
}
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For JWT, logout could involve token blacklisting
// Add token to blacklist table
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
// "token": req.Token,
// "expires_at": time.Now().Add(24 * time.Hour),
// }).Error
return nil
}

View File

@@ -6,6 +6,7 @@ import (
"reflect"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// SecurityContext is a generic interface that any spec can implement to integrate with security features
@@ -226,6 +227,122 @@ func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
return applyColumnSecurity(secCtx, securityList)
}
// checkModelUpdateAllowed returns an error if CanUpdate is false for the model.
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
func checkModelUpdateAllowed(secCtx SecurityContext) error {
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
if !ok {
schema := secCtx.GetSchema()
entity := secCtx.GetEntity()
var err error
if schema != "" {
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
}
if err != nil || schema == "" {
rules, err = modelregistry.GetModelRulesByName(entity)
}
if err != nil {
return nil // model not registered, allow by default
}
}
if !rules.CanUpdate {
return fmt.Errorf("update not allowed for %s", secCtx.GetEntity())
}
return nil
}
// checkModelDeleteAllowed returns an error if CanDelete is false for the model.
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
func checkModelDeleteAllowed(secCtx SecurityContext) error {
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
if !ok {
schema := secCtx.GetSchema()
entity := secCtx.GetEntity()
var err error
if schema != "" {
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
}
if err != nil || schema == "" {
rules, err = modelregistry.GetModelRulesByName(entity)
}
if err != nil {
return nil // model not registered, allow by default
}
}
if !rules.CanDelete {
return fmt.Errorf("delete not allowed for %s", secCtx.GetEntity())
}
return nil
}
// CheckModelAuthAllowed checks whether the requested operation is permitted based on
// model rules and the current user's authentication state. It is intended for use in
// a BeforeHandle hook, fired after model resolution.
//
// Logic:
// 1. Load model rules from context (set by NewModelAuthMiddleware) or fall back to registry.
// 2. SecurityDisabled → allow.
// 3. operation == "read" && CanPublicRead → allow.
// 4. operation == "create" && CanPublicCreate → allow.
// 5. operation == "update" && CanPublicUpdate → allow.
// 6. operation == "delete" && CanPublicDelete → allow.
// 7. Guest (UserID == 0) → return "authentication required".
// 8. Authenticated user → allow (operation-specific checks remain in BeforeUpdate/BeforeDelete).
func CheckModelAuthAllowed(secCtx SecurityContext, operation string) error {
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
if !ok {
schema := secCtx.GetSchema()
entity := secCtx.GetEntity()
var err error
if schema != "" {
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
}
if err != nil || schema == "" {
rules, err = modelregistry.GetModelRulesByName(entity)
}
if err != nil {
// Model not registered - fall through to auth check
userID, _ := secCtx.GetUserID()
if userID == 0 {
return fmt.Errorf("authentication required")
}
return nil
}
}
if rules.SecurityDisabled {
return nil
}
if operation == "read" && rules.CanPublicRead {
return nil
}
if operation == "create" && rules.CanPublicCreate {
return nil
}
if operation == "update" && rules.CanPublicUpdate {
return nil
}
if operation == "delete" && rules.CanPublicDelete {
return nil
}
userID, _ := secCtx.GetUserID()
if userID == 0 {
return fmt.Errorf("authentication required")
}
return nil
}
// CheckModelUpdateAllowed is the public wrapper for checkModelUpdateAllowed.
func CheckModelUpdateAllowed(secCtx SecurityContext) error {
return checkModelUpdateAllowed(secCtx)
}
// CheckModelDeleteAllowed is the public wrapper for checkModelDeleteAllowed.
func CheckModelDeleteAllowed(secCtx SecurityContext) error {
return checkModelDeleteAllowed(secCtx)
}
// Helper functions
func contains(s, substr string) bool {

View File

@@ -7,33 +7,48 @@ import (
// UserContext holds authenticated user information
type UserContext struct {
UserID int `json:"user_id"`
UserName string `json:"user_name"`
UserLevel int `json:"user_level"`
SessionID string `json:"session_id"`
SessionRID int64 `json:"session_rid"`
RemoteID string `json:"remote_id"`
Roles []string `json:"roles"`
Email string `json:"email"`
Claims map[string]any `json:"claims"`
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
UserID int `json:"user_id"`
UserName string `json:"user_name"`
UserLevel int `json:"user_level"`
SessionID string `json:"session_id"`
SessionRID int64 `json:"session_rid"`
RemoteID string `json:"remote_id"`
Roles []string `json:"roles"`
Email string `json:"email"`
Claims map[string]any `json:"claims"`
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
}
// LoginRequest contains credentials for login
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
Claims map[string]any `json:"claims"` // Additional login data
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
Username string `json:"username"`
Password string `json:"password"`
TwoFactorCode string `json:"two_factor_code,omitempty"` // TOTP or backup code
Claims map[string]any `json:"claims"` // Additional login data
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
}
// RegisterRequest contains information for new user registration
type RegisterRequest struct {
Username string `json:"username"`
Password string `json:"password"`
Email string `json:"email"`
UserLevel int `json:"user_level"`
Roles []string `json:"roles"`
Claims map[string]any `json:"claims"` // Additional registration data
Meta map[string]any `json:"meta"` // Additional metadata
}
// LoginResponse contains the result of a login attempt
type LoginResponse struct {
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
User *UserContext `json:"user"`
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
User *UserContext `json:"user"`
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
Requires2FA bool `json:"requires_2fa"` // True if 2FA code is required
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"` // Present when setting up 2FA
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
}
// LogoutRequest contains information for logout
@@ -55,6 +70,12 @@ type Authenticator interface {
Authenticate(r *http.Request) (*UserContext, error)
}
// Registrable allows providers to support user registration
type Registrable interface {
// Register creates a new user account
Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error)
}
// ColumnSecurityProvider handles column-level security (masking/hiding)
type ColumnSecurityProvider interface {
// GetColumnSecurity loads column security rules for a user and entity

81
pkg/security/keystore.go Normal file
View File

@@ -0,0 +1,81 @@
package security
import (
"context"
"crypto/sha256"
"encoding/hex"
"time"
)
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
// Used by all keystore implementations to hash raw keys before storage or lookup.
func hashSHA256Hex(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
// KeyType identifies the category of an auth key.
type KeyType string
const (
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
KeyTypeJWTSecret KeyType = "jwt_secret"
// KeyTypeHeaderAPI is a static API key sent via a request header.
KeyTypeHeaderAPI KeyType = "header_api"
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
KeyTypeOAuth2 KeyType = "oauth2"
// KeyTypeGenericAPI is a generic application API key.
KeyTypeGenericAPI KeyType = "api"
)
// UserKey represents a single named auth key belonging to a user.
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
type UserKey struct {
ID int64 `json:"id"`
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
IsActive bool `json:"is_active"`
}
// CreateKeyRequest specifies the parameters for a new key.
type CreateKeyRequest struct {
UserID int
KeyType KeyType
Name string
Scopes []string
Meta map[string]any
ExpiresAt *time.Time
}
// CreateKeyResponse is returned exactly once when a key is created.
// The caller is responsible for persisting RawKey; it is not stored anywhere.
type CreateKeyResponse struct {
Key UserKey
RawKey string // crypto/rand 32 bytes, base64url-encoded
}
// KeyStore manages per-user auth keys with pluggable storage backends.
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
type KeyStore interface {
// CreateKey generates a new key, stores its hash, and returns the raw key once.
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
// GetUserKeys returns all active, non-expired keys for a user.
// Pass an empty KeyType to return all types.
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
// DeleteKey soft-deletes a key by ID after verifying ownership.
DeleteKey(ctx context.Context, userID int, keyID int64) error
// ValidateKey checks a raw key, returns the matching UserKey on success.
// The implementation hashes the raw key before any lookup.
// Pass an empty KeyType to accept any type.
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
}

View File

@@ -0,0 +1,97 @@
package security
import (
"context"
"fmt"
"net/http"
"strings"
)
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
// is managed directly through the KeyStore.
//
// Key extraction order:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
type KeyStoreAuthenticator struct {
keyStore KeyStore
keyType KeyType // empty = accept any type
}
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
// Pass an empty keyType to accept keys of any type.
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
}
// Login is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
return nil, fmt.Errorf("keystore authenticator does not support login")
}
// Logout is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
return nil
}
// Authenticate extracts an API key from the request and validates it against the KeyStore.
// Returns a UserContext built from the matching UserKey on success.
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
rawKey := extractAPIKey(r)
if rawKey == "" {
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
}
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return userKeyToUserContext(userKey), nil
}
// extractAPIKey extracts a raw key from the request using the following precedence:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
func extractAPIKey(r *http.Request) string {
if auth := r.Header.Get("Authorization"); auth != "" {
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
return strings.TrimSpace(after)
}
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
return strings.TrimSpace(after)
}
}
return strings.TrimSpace(r.Header.Get("X-API-Key"))
}
// userKeyToUserContext converts a UserKey into a UserContext.
// Scopes are mapped to Roles. Key type and name are stored in Claims.
func userKeyToUserContext(k *UserKey) *UserContext {
claims := map[string]any{
"key_type": string(k.KeyType),
"key_name": k.Name,
}
meta := k.Meta
if meta == nil {
meta = map[string]any{}
}
roles := k.Scopes
if roles == nil {
roles = []string{}
}
return &UserContext{
UserID: k.UserID,
SessionID: fmt.Sprintf("key:%d", k.ID),
Roles: roles,
Claims: claims,
Meta: meta,
}
}

View File

@@ -0,0 +1,149 @@
package security
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"sync"
"sync/atomic"
"time"
)
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
//
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
type ConfigKeyStore struct {
mu sync.RWMutex
keys []UserKey
next int64 // monotonic ID counter for runtime-created keys (atomic)
}
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
// Pass nil or an empty slice to start with no pre-loaded keys.
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
var maxID int64
copied := make([]UserKey, len(keys))
copy(copied, keys)
for i := range copied {
if copied[i].CreatedAt.IsZero() {
copied[i].IsActive = true
copied[i].CreatedAt = time.Now()
}
if copied[i].ID > maxID {
maxID = copied[i].ID
}
}
return &ConfigKeyStore{keys: copied, next: maxID}
}
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
id := atomic.AddInt64(&s.next, 1)
key := UserKey{
ID: id,
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
CreatedAt: time.Now(),
IsActive: true,
}
s.mu.Lock()
s.keys = append(s.keys, key)
s.mu.Unlock()
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
now := time.Now()
s.mu.RLock()
defer s.mu.RUnlock()
var result []UserKey
for i := range s.keys {
k := &s.keys[i]
if k.UserID != userID || !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
result = append(result, *k)
}
return result, nil
}
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
if s.keys[i].ID == keyID {
if s.keys[i].UserID != userID {
return fmt.Errorf("key not found or permission denied")
}
s.keys[i].IsActive = false
return nil
}
}
return fmt.Errorf("key not found")
}
// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
// Uses constant-time comparison to prevent timing side-channels.
// Pass an empty KeyType to accept any type.
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
hashBytes, _ := hex.DecodeString(hash)
now := time.Now()
// Write lock: ValidateKey updates LastUsedAt on the matched entry.
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
k := &s.keys[i]
if !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
stored, _ := hex.DecodeString(k.KeyHash)
if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
continue
}
k.LastUsedAt = &now
result := *k
return &result, nil
}
return nil, fmt.Errorf("invalid or expired key")
}

View File

@@ -0,0 +1,256 @@
package security
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
type DatabaseKeyStoreOptions struct {
// Cache is an optional cache instance. If nil, uses the default cache.
Cache *cache.Cache
// CacheTTL is the duration to cache ValidateKey results.
// Default: 2 minutes.
CacheTTL time.Duration
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
SQLNames *KeyStoreSQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
// All DB operations go through configurable procedure names; the raw key is
// never passed to the database.
//
// See keystore_schema.sql for the required table and procedure definitions.
//
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
// (default 2 minutes) if the cache entry cannot be invalidated.
type DatabaseKeyStore struct {
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
}
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
o := DatabaseKeyStoreOptions{}
if len(opts) > 0 {
o = opts[0]
}
if o.CacheTTL == 0 {
o.CacheTTL = 2 * time.Minute
}
c := o.Cache
if c == nil {
c = cache.GetDefaultCache()
}
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
return &DatabaseKeyStore{
db: db,
dbFactory: o.DBFactory,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
}
}
func (ks *DatabaseKeyStore) getDB() *sql.DB {
ks.dbMu.RLock()
defer ks.dbMu.RUnlock()
return ks.db
}
func (ks *DatabaseKeyStore) reconnectDB() error {
if ks.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := ks.dbFactory()
if err != nil {
return err
}
ks.dbMu.Lock()
ks.db = newDB
ks.dbMu.Unlock()
return nil
}
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
// and returns the raw key once.
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
type createRequest struct {
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"`
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
reqJSON, err := json.Marshal(createRequest{
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
})
if err != nil {
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("create key procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
}
var key UserKey
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse created key: %w", err)
}
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
var success bool
var errorMsg sql.NullString
var keysJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
}
var keys []UserKey
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
return nil, fmt.Errorf("failed to parse user keys: %w", err)
}
}
if keys == nil {
keys = []UserKey{}
}
return keys, nil
}
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
// The delete procedure returns the key_hash so no separate lookup is needed.
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
var success bool
var errorMsg sql.NullString
var keyHash sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
return fmt.Errorf("delete key procedure failed: %w", err)
}
if !success {
return errors.New(nullStringOr(errorMsg, "delete key failed"))
}
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
}
return nil
}
// ValidateKey hashes the raw key and calls the validate procedure.
// Results are cached for CacheTTL to reduce DB load on hot paths.
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
cacheKey := keystoreCacheKey(hash)
if ks.cache != nil {
var cached UserKey
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
if cached.IsActive {
return &cached, nil
}
return nil, errors.New("invalid or expired key")
}
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
}
if err := runQuery(); err != nil {
if isDBClosed(err) {
if reconnErr := ks.reconnectDB(); reconnErr == nil {
err = runQuery()
}
if err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
} else {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
}
var key UserKey
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse validated key: %w", err)
}
if ks.cache != nil {
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
}
return &key, nil
}
func keystoreCacheKey(hash string) string {
return "keystore:validate:" + hash
}
// nullStringOr returns s.String if valid, otherwise the fallback.
func nullStringOr(s sql.NullString, fallback string) string {
if s.Valid && s.String != "" {
return s.String
}
return fallback
}

View File

@@ -0,0 +1,187 @@
-- Keystore schema for per-user auth keys
-- Apply alongside database_schema.sql (requires the users table)
CREATE TABLE IF NOT EXISTS user_keys (
id BIGSERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_type VARCHAR(50) NOT NULL,
key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
name VARCHAR(255) NOT NULL DEFAULT '',
scopes TEXT, -- JSON array, e.g. '["read","write"]'
meta JSONB,
expires_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_used_at TIMESTAMP,
is_active BOOLEAN DEFAULT true
);
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
-- resolvespec_keystore_get_user_keys
-- Returns all active, non-expired keys for a user.
-- Pass empty p_key_type to return all key types.
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
p_user_id INTEGER,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_keys JSONB;
BEGIN
SELECT COALESCE(
jsonb_agg(
jsonb_build_object(
'id', k.id,
'user_id', k.user_id,
'key_type', k.key_type,
'name', k.name,
'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
'meta', COALESCE(k.meta, '{}'::jsonb),
'expires_at', k.expires_at,
'created_at', k.created_at,
'last_used_at', k.last_used_at,
'is_active', k.is_active
)
),
'[]'::jsonb
)
INTO v_keys
FROM user_keys k
WHERE k.user_id = p_user_id
AND k.is_active = true
AND (k.expires_at IS NULL OR k.expires_at > NOW())
AND (p_key_type = '' OR k.key_type = p_key_type);
RETURN QUERY SELECT true, NULL::TEXT, v_keys;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_create_key
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
-- Returns the created key record (without key_hash).
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
p_request JSONB
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_id BIGINT;
v_created_at TIMESTAMP;
v_key JSONB;
BEGIN
INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
VALUES (
(p_request->>'user_id')::INTEGER,
p_request->>'key_type',
p_request->>'key_hash',
COALESCE(p_request->>'name', ''),
p_request->>'scopes',
p_request->'meta',
CASE WHEN p_request->>'expires_at' IS NOT NULL
THEN (p_request->>'expires_at')::TIMESTAMP
ELSE NULL
END
)
RETURNING id, created_at INTO v_id, v_created_at;
v_key := jsonb_build_object(
'id', v_id,
'user_id', (p_request->>'user_id')::INTEGER,
'key_type', p_request->>'key_type',
'name', COALESCE(p_request->>'name', ''),
'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
THEN (p_request->>'scopes')::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(p_request->'meta', '{}'::jsonb),
'expires_at', p_request->>'expires_at',
'created_at', v_created_at,
'is_active', true
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_delete_key
-- Soft-deletes a key (is_active = false) after verifying ownership.
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
p_user_id INTEGER,
p_key_id BIGINT
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
LANGUAGE plpgsql AS $$
DECLARE
v_hash TEXT;
BEGIN
UPDATE user_keys
SET is_active = false
WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
RETURNING key_hash INTO v_hash;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
RETURN;
END IF;
RETURN QUERY SELECT true, NULL::TEXT, v_hash;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
END;
$$;
-- resolvespec_keystore_validate_key
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
-- updates last_used_at, and returns the key record.
-- p_key_type can be empty to accept any key type.
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
p_key_hash TEXT,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_key_rec user_keys%ROWTYPE;
v_key JSONB;
BEGIN
SELECT * INTO v_key_rec
FROM user_keys
WHERE key_hash = p_key_hash
AND is_active = true
AND (expires_at IS NULL OR expires_at > NOW())
AND (p_key_type = '' OR key_type = p_key_type);
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
RETURN;
END IF;
UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
v_key := jsonb_build_object(
'id', v_key_rec.id,
'user_id', v_key_rec.user_id,
'key_type', v_key_rec.key_type,
'name', v_key_rec.name,
'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
THEN v_key_rec.scopes::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
'expires_at', v_key_rec.expires_at,
'created_at', v_key_rec.created_at,
'last_used_at', NOW(),
'is_active', v_key_rec.is_active
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;

View File

@@ -0,0 +1,61 @@
package security
import "fmt"
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
type KeyStoreSQLNames struct {
GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
CreateKey string // default: "resolvespec_keystore_create_key"
DeleteKey string // default: "resolvespec_keystore_delete_key"
ValidateKey string // default: "resolvespec_keystore_validate_key"
}
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
return &KeyStoreSQLNames{
GetUserKeys: "resolvespec_keystore_get_user_keys",
CreateKey: "resolvespec_keystore_create_key",
DeleteKey: "resolvespec_keystore_delete_key",
ValidateKey: "resolvespec_keystore_validate_key",
}
}
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.GetUserKeys != "" {
merged.GetUserKeys = override.GetUserKeys
}
if override.CreateKey != "" {
merged.CreateKey = override.CreateKey
}
if override.DeleteKey != "" {
merged.DeleteKey = override.DeleteKey
}
if override.ValidateKey != "" {
merged.ValidateKey = override.ValidateKey
}
return &merged
}
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
fields := map[string]string{
"GetUserKeys": names.GetUserKeys,
"CreateKey": names.CreateKey,
"DeleteKey": names.DeleteKey,
"ValidateKey": names.ValidateKey,
}
for field, val := range fields {
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
}
}
return nil
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"net/http"
"strconv"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// contextKey is a custom type for context keys to avoid collisions
@@ -23,6 +25,7 @@ const (
UserMetaKey contextKey = "user_meta"
SkipAuthKey contextKey = "skip_auth"
OptionalAuthKey contextKey = "optional_auth"
ModelRulesKey contextKey = "model_rules"
)
// SkipAuth returns a context with skip auth flag set to true
@@ -136,6 +139,31 @@ func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.
})
}
// NewOptionalAuthMiddleware creates authentication middleware that always continues.
// On auth failure, a guest user context is set instead of returning 401.
// Intended for spec routes where auth enforcement is deferred to a BeforeHandle hook
// after model resolution.
func NewOptionalAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
provider := securityList.Provider()
if provider == nil {
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
return
}
userCtx, err := provider.Authenticate(r)
if err != nil {
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
next.ServeHTTP(w, setUserContext(r, userCtx))
})
}
}
// NewAuthMiddleware creates an authentication middleware with the given security list
// This middleware extracts user authentication from the request and adds it to context
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
@@ -182,6 +210,68 @@ func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handl
}
}
// NewModelAuthMiddleware creates authentication middleware that respects ModelRules for the given model name.
// It first checks if ModelRules are set for the model:
// - If SecurityDisabled is true, authentication is skipped and a guest context is set.
// - Otherwise, all checks from NewAuthMiddleware apply (SkipAuthKey, provider check, OptionalAuthKey, Authenticate).
//
// If the model is not found in any registry, the middleware falls back to standard NewAuthMiddleware behaviour.
func NewModelAuthMiddleware(securityList *SecurityList, modelName string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check ModelRules first
if rules, err := modelregistry.GetModelRulesByName(modelName); err == nil {
// Store rules in context for downstream use (e.g., security hooks)
r = r.WithContext(context.WithValue(r.Context(), ModelRulesKey, rules))
if rules.SecurityDisabled {
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
isRead := r.Method == http.MethodGet || r.Method == http.MethodHead
isUpdate := r.Method == http.MethodPut || r.Method == http.MethodPatch
if (isRead && rules.CanPublicRead) || (isUpdate && rules.CanPublicUpdate) {
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
}
// Check if this route should skip authentication
if skip, ok := r.Context().Value(SkipAuthKey).(bool); ok && skip {
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
// Get the security provider
provider := securityList.Provider()
if provider == nil {
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
return
}
// Check if this route has optional authentication
optional, _ := r.Context().Value(OptionalAuthKey).(bool)
// Try to authenticate
userCtx, err := provider.Authenticate(r)
if err != nil {
if optional {
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, setUserContext(r, userCtx))
})
}
}
// SetSecurityMiddleware adds security context to requests
// This middleware should be applied after AuthMiddleware
func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
@@ -366,6 +456,131 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
return meta, ok
}
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
// All fields are optional; sensible secure defaults are applied when omitted.
type SessionCookieOptions struct {
// Name is the cookie name. Defaults to "session_token".
Name string
// Path is the cookie path. Defaults to "/".
Path string
// Domain restricts the cookie to a specific domain. Empty means current host.
Domain string
// Secure sets the Secure flag. Defaults to true.
// Set to false only in local development over HTTP.
Secure *bool
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
SameSite http.SameSite
}
func (o SessionCookieOptions) name() string {
if o.Name != "" {
return o.Name
}
return "session_token"
}
func (o SessionCookieOptions) path() string {
if o.Path != "" {
return o.Path
}
return "/"
}
func (o SessionCookieOptions) secure() bool {
if o.Secure != nil {
return *o.Secure
}
return true
}
func (o SessionCookieOptions) sameSite() http.SameSite {
if o.SameSite != 0 {
return o.SameSite
}
return http.SameSiteLaxMode
}
// SetSessionCookie writes the session_token cookie to the response after a successful login.
// Call this immediately after a successful Authenticator.Login() call.
//
// Example:
//
// resp, err := auth.Login(r.Context(), req)
// if err != nil { ... }
// security.SetSessionCookie(w, resp)
// json.NewEncoder(w).Encode(resp)
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
maxAge := 0
if loginResp.ExpiresIn > 0 {
maxAge = int(loginResp.ExpiresIn)
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: loginResp.Token,
Path: o.path(),
Domain: o.Domain,
MaxAge: maxAge,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
//
// Example:
//
// token := security.GetSessionCookie(r)
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
cookie, err := r.Cookie(o.name())
if err != nil {
return ""
}
return cookie.Value
}
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
// Call this after a successful Authenticator.Logout() call.
//
// Example:
//
// err := auth.Logout(r.Context(), req)
// if err != nil { ... }
// security.ClearSessionCookie(w)
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: "",
Path: o.path(),
Domain: o.Domain,
MaxAge: -1,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)
return rules, ok
}
// // Handler adapters for resolvespec/restheadspec compatibility
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions

View File

@@ -0,0 +1,615 @@
package security
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"github.com/gorilla/mux"
)
// Example: OAuth2 Authentication with Google
func ExampleOAuth2Google() {
db, _ := sql.Open("postgres", "connection-string")
// Create OAuth2 authenticator for Google
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
// Login endpoint - redirects to Google
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
// Callback endpoint - handles Google response
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "google", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
// Return user info as JSON
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: OAuth2 Authentication with GitHub
func ExampleOAuth2GitHub() {
db, _ := sql.Open("postgres", "connection-string")
oauth2Auth := NewGitHubAuthenticator(
"your-github-client-id",
"your-github-client-secret",
"http://localhost:8080/auth/github/callback",
db,
)
router := mux.NewRouter()
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: Custom OAuth2 Provider
func ExampleOAuth2Custom() {
db, _ := sql.Open("postgres", "connection-string")
// Custom OAuth2 provider configuration
oauth2Auth := NewDatabaseAuthenticator(db).WithOAuth2(OAuth2Config{
ClientID: "your-client-id",
ClientSecret: "your-client-secret",
RedirectURL: "http://localhost:8080/auth/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://your-provider.com/oauth/authorize",
TokenURL: "https://your-provider.com/oauth/token",
UserInfoURL: "https://your-provider.com/oauth/userinfo",
ProviderName: "custom-provider",
// Custom user info parser
UserInfoParser: func(userInfo map[string]any) (*UserContext, error) {
// Extract custom fields from your provider
return &UserContext{
UserName: userInfo["username"].(string),
Email: userInfo["email"].(string),
RemoteID: userInfo["id"].(string),
UserLevel: 1,
Roles: []string{"user"},
Claims: userInfo,
}, nil
},
})
router := mux.NewRouter()
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("custom-provider", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "custom-provider", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: Multi-Provider OAuth2 with Security Integration
func ExampleOAuth2MultiProvider() {
db, _ := sql.Open("postgres", "connection-string")
// Create OAuth2 authenticators for multiple providers
googleAuth := NewGoogleAuthenticator(
"google-client-id",
"google-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
githubAuth := NewGitHubAuthenticator(
"github-client-id",
"github-client-secret",
"http://localhost:8080/auth/github/callback",
db,
)
// Create column and row security providers
colSec := NewDatabaseColumnSecurityProvider(db)
rowSec := NewDatabaseRowSecurityProvider(db)
router := mux.NewRouter()
// Google OAuth2 routes
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := googleAuth.OAuth2GenerateState()
authURL, _ := googleAuth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := googleAuth.OAuth2HandleCallback(r.Context(), "google", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
// GitHub OAuth2 routes
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := githubAuth.OAuth2GenerateState()
authURL, _ := githubAuth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := githubAuth.OAuth2HandleCallback(r.Context(), "github", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
// Use Google auth for protected routes (or GitHub - both work)
provider, _ := NewCompositeSecurityProvider(googleAuth, colSec, rowSec)
securityList, _ := NewSecurityList(provider)
// Protected route with authentication
protectedRouter := router.PathPrefix("/api").Subrouter()
protectedRouter.Use(NewAuthMiddleware(securityList))
protectedRouter.Use(SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = json.NewEncoder(w).Encode(userCtx)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: OAuth2 with Token Refresh
func ExampleOAuth2TokenRefresh() {
db, _ := sql.Open("postgres", "connection-string")
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
// Refresh token endpoint
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
var req struct {
RefreshToken string `json:"refresh_token"`
Provider string `json:"provider"` // "google", "github", etc.
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
// Default to google if not specified
if req.Provider == "" {
req.Provider = "google"
}
// Use OAuth2-specific refresh method
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set new session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
_ = json.NewEncoder(w).Encode(loginResp)
})
_ = http.ListenAndServe(":8080", router)
}
// Example: OAuth2 Logout
func ExampleOAuth2Logout() {
db, _ := sql.Open("postgres", "connection-string")
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
router := mux.NewRouter()
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if token == "" {
cookie, err := r.Cookie("session_token")
if err == nil {
token = cookie.Value
}
}
if token != "" {
// Get user ID from session
userCtx, err := oauth2Auth.Authenticate(r)
if err == nil {
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
Token: token,
UserID: userCtx.UserID,
})
}
}
// Clear cookie
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
})
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Logged out successfully"))
})
_ = http.ListenAndServe(":8080", router)
}
// Example: Complete OAuth2 Integration with Database Setup
func ExampleOAuth2Complete() {
db, _ := sql.Open("postgres", "connection-string")
// Create tables (run once)
setupOAuth2Tables(db)
// Create OAuth2 authenticator
oauth2Auth := NewGoogleAuthenticator(
"your-client-id",
"your-client-secret",
"http://localhost:8080/auth/google/callback",
db,
)
// Create security providers
colSec := NewDatabaseColumnSecurityProvider(db)
rowSec := NewDatabaseRowSecurityProvider(db)
provider, _ := NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
securityList, _ := NewSecurityList(provider)
router := mux.NewRouter()
// Public routes
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("Welcome! <a href='/auth/google/login'>Login with Google</a>"))
})
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := oauth2Auth.OAuth2GenerateState()
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: loginResp.Token,
Path: "/",
MaxAge: int(loginResp.ExpiresIn),
HttpOnly: true,
})
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
})
// Protected routes
protectedRouter := router.PathPrefix("/").Subrouter()
protectedRouter.Use(NewAuthMiddleware(securityList))
protectedRouter.Use(SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/dashboard", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_, _ = fmt.Fprintf(w, "Welcome, %s! Your email: %s", userCtx.UserName, userCtx.Email)
})
protectedRouter.HandleFunc("/api/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = json.NewEncoder(w).Encode(userCtx)
})
protectedRouter.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
Token: userCtx.SessionID,
UserID: userCtx.UserID,
})
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
})
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
})
_ = http.ListenAndServe(":8080", router)
}
func setupOAuth2Tables(db *sql.DB) {
// Create tables from database_schema.sql
// This is a helper function - in production, use migrations
ctx := context.Background()
// Create users table if not exists
_, _ = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
email VARCHAR(255) NOT NULL UNIQUE,
password VARCHAR(255),
user_level INTEGER DEFAULT 0,
roles VARCHAR(500),
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_login_at TIMESTAMP,
remote_id VARCHAR(255),
auth_provider VARCHAR(50)
)
`)
// Create user_sessions table (used for both regular and OAuth2 sessions)
_, _ = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_sessions (
id SERIAL PRIMARY KEY,
session_token VARCHAR(500) NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address VARCHAR(45),
user_agent TEXT,
access_token TEXT,
refresh_token TEXT,
token_type VARCHAR(50) DEFAULT 'Bearer',
auth_provider VARCHAR(50)
)
`)
}
// Example: All OAuth2 Providers at Once
func ExampleOAuth2AllProviders() {
db, _ := sql.Open("postgres", "connection-string")
// Create authenticator with ALL OAuth2 providers
auth := NewDatabaseAuthenticator(db).
WithOAuth2(OAuth2Config{
ClientID: "google-client-id",
ClientSecret: "google-client-secret",
RedirectURL: "http://localhost:8080/auth/google/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
}).
WithOAuth2(OAuth2Config{
ClientID: "github-client-id",
ClientSecret: "github-client-secret",
RedirectURL: "http://localhost:8080/auth/github/callback",
Scopes: []string{"user:email"},
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
UserInfoURL: "https://api.github.com/user",
ProviderName: "github",
}).
WithOAuth2(OAuth2Config{
ClientID: "microsoft-client-id",
ClientSecret: "microsoft-client-secret",
RedirectURL: "http://localhost:8080/auth/microsoft/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
ProviderName: "microsoft",
}).
WithOAuth2(OAuth2Config{
ClientID: "facebook-client-id",
ClientSecret: "facebook-client-secret",
RedirectURL: "http://localhost:8080/auth/facebook/callback",
Scopes: []string{"email"},
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
ProviderName: "facebook",
})
// Get list of configured providers
providers := auth.OAuth2GetProviders()
fmt.Printf("Configured OAuth2 providers: %v\n", providers)
router := mux.NewRouter()
// Google routes
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("google", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// GitHub routes
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("github", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// Microsoft routes
router.HandleFunc("/auth/microsoft/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("microsoft", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/microsoft/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "microsoft", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// Facebook routes
router.HandleFunc("/auth/facebook/login", func(w http.ResponseWriter, r *http.Request) {
state, _ := auth.OAuth2GenerateState()
authURL, _ := auth.OAuth2GetAuthURL("facebook", state)
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
})
router.HandleFunc("/auth/facebook/callback", func(w http.ResponseWriter, r *http.Request) {
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "facebook", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(loginResp)
})
// Create security list for protected routes
colSec := NewDatabaseColumnSecurityProvider(db)
rowSec := NewDatabaseRowSecurityProvider(db)
provider, _ := NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := NewSecurityList(provider)
// Protected routes work for ALL OAuth2 providers + regular sessions
protectedRouter := router.PathPrefix("/api").Subrouter()
protectedRouter.Use(NewAuthMiddleware(securityList))
protectedRouter.Use(SetSecurityMiddleware(securityList))
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
userCtx, _ := GetUserContext(r.Context())
_ = json.NewEncoder(w).Encode(userCtx)
})
_ = http.ListenAndServe(":8080", router)
}

Some files were not shown because too many files have changed in this diff Show More