Compare commits

..

83 Commits

Author SHA1 Message Date
Hein
4ee6ef0955 chore: 🙈 update git ignore 2026-01-02 16:20:56 +02:00
Hein
6f05f15ff6 feat(dbmanager): Database connection Manager 2026-01-02 16:19:33 +02:00
Hein
443a672fcb feat(packages): Prometheus publish events 2026-01-02 13:48:31 +02:00
Hein
c2fcc5aaff chore(spectypes): ✏️ edded scope 2026-01-02 12:17:37 +02:00
Hein
6664a4e2d2 fix(spectypes): 🐛 uuid value was not parsed as string and caused sql syntax error 2026-01-02 12:16:29 +02:00
037bd4c05e chore: 🦺 actions updated
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m45s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m23s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m16s
Build , Vet Test, and Lint / Build (push) Successful in -25m28s
Tests / Unit Tests (push) Successful in -25m45s
Tests / Integration Tests (push) Failing after -25m51s
2025-12-31 19:23:09 +02:00
Hein
e77468a239 refactor: ♻️ ach package accepts its configuration as a parameter rather than reading from global config
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m44s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m19s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m18s
Build , Vet Test, and Lint / Build (push) Successful in -25m30s
Tests / Unit Tests (push) Failing after -25m51s
Tests / Integration Tests (push) Failing after -25m53s
2025-12-31 12:39:59 +02:00
Hein
82d84435f2 Fixed version release auto inc
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m36s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m5s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m16s
Build , Vet Test, and Lint / Build (push) Successful in -25m29s
Tests / Unit Tests (push) Failing after -25m48s
Tests / Integration Tests (push) Failing after -25m53s
2025-12-31 09:41:23 +02:00
Bitech Systems
b99b08430e Merge pull request #16 from bitechdev/fix-funcapi
fix: 🔒 Fixed funcapi possible injections
2025-12-31 09:38:44 +02:00
Hein
fae9a082bd refactor: 🚨 linting issues 2025-12-31 09:37:12 +02:00
Hein Puth (Warkanum)
191822b91c Merge pull request #17 from bitechdev/copilot/sub-pr-16
fix: Comprehensive SQL injection protection in funcspec handlers
2025-12-31 09:35:49 +02:00
copilot-swe-agent[bot]
a6a17d019f fix: Optimized regex performance and added backslash escaping for LIKE patterns
- Added backslash escaping to LIKE pattern sanitization
- Optimized ValidSQL select mode with single pre-compiled regex
- All funcspec tests pass (except pre-existing TestReplaceMetaVariables)

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-31 07:28:16 +00:00
copilot-swe-agent[bot]
a7cc42044b fix: Improved SQL injection protection with proper handling
- Fixed IN clause to conditionally quote only string values (not numeric)
- Fixed LIKE pattern sanitization to preserve wildcards while preventing injection
- Improved dangerous pattern removal with case-insensitive regex while preserving case
- All funcspec tests now pass (except pre-existing TestReplaceMetaVariables)

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-31 07:24:56 +00:00
copilot-swe-agent[bot]
8cdc353029 fix: Fixed SQL quoting for empty/zero values and sanitized match filter
- Sanitize colval immediately after extraction in match= filter
- Fixed empty/zero value handling to use proper SQL literals (0 vs '')
- Applied proper quoting for string vs numeric comparisons
- Fixed x-fieldfilter handlers for proper value handling

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-31 07:22:42 +00:00
copilot-swe-agent[bot]
6528e94297 fix: Improved SQL injection protections based on code review
- Fixed backslash escaping order in colvalue mode
- Added proper quoting for IN clause values
- Simplified dangerous pattern matching with case-insensitive approach
- All funcspec tests pass (except pre-existing test failure)

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-31 07:20:56 +00:00
copilot-swe-agent[bot]
f711bf38d2 fix: Enhanced SQL injection protection in funcspec
- Added sanitization for path parameters in mergePathParams
- Added sanitization for query parameters with p- prefix in mergeQueryParams
- Added sanitization for header parameters in mergeHeaderParams
- Fixed IN clause to sanitize all values individually
- Improved ValidSQL function with better escaping and more injection patterns
- Added backslash escaping to colvalue mode
- Extended dangerous keyword list in select mode

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-31 07:19:53 +00:00
copilot-swe-agent[bot]
44356d8750 Initial plan 2025-12-31 07:11:57 +00:00
Hein
caf85cf558 fix: 🔒 Fixed funcapi possible injections 2025-12-31 09:09:16 +02:00
Hein Puth (Warkanum)
2e1547ec65 Merge pull request #15 from bitechdev/feature-staticweb
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -21m25s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -21m19s
Build , Vet Test, and Lint / Build (push) Successful in -25m26s
Build , Vet Test, and Lint / Lint Code (push) Successful in -24m49s
Tests / Unit Tests (push) Failing after -25m46s
Tests / Integration Tests (push) Failing after -25m47s
feature: staticweb
2025-12-30 17:49:33 +02:00
Hein
49cdc6f17b fix: lint issues after merge 2025-12-30 17:46:33 +02:00
Hein
0bd653820c Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into feature-staticweb 2025-12-30 17:45:17 +02:00
Hein
9209193157 fix: lint issues and docs 2025-12-30 17:44:57 +02:00
Hein Puth (Warkanum)
b8c44c5a99 Merge pull request #14 from bitechdev/copilot/fix-api-null-response
Fix API returning null instead of empty array when no records found
2025-12-30 17:35:02 +02:00
Hein
28fd88fff1 staticweb package for easier static web server hosting 2025-12-30 17:31:07 +02:00
copilot-swe-agent[bot]
be38341383 Fix formatting issues with gofmt
- Removed trailing whitespace
- Fixed tab/space alignment in struct definitions
- All tests still passing

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 15:30:23 +00:00
copilot-swe-agent[bot]
fab744b878 Add clarifying comments about X-No-Data-Found header timing
- Added comments explaining why X-No-Data-Found is set before normalization
- Header reflects database query result, not final response format
- Clarifies that normalizeResultArray doesn't affect header logic
- All tests passing

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 14:04:16 +00:00
copilot-swe-agent[bot]
5ad2bd3a78 Improve test robustness - use explicit flag instead of string comparison
- Changed test to use shouldBeEmptyArr flag instead of hardcoded name comparison
- Makes test more maintainable and less fragile
- All tests still passing

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 14:02:42 +00:00
copilot-swe-agent[bot]
333fe158e9 Address code review feedback - improve data length calculation clarity
- Simplified data length calculation logic in sendFormattedResponse
- Simplified data length calculation logic in sendResponseWithOptions
- Calculate dataLen after nil conversion for clarity and consistency
- All tests still passing

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 14:01:20 +00:00
copilot-swe-agent[bot]
2a2d351ad4 Fix API returning null with 200 code when no records found
- Modified handleRead to always return empty array [] instead of null when no ID provided
- Added X-No-Data-Found header when result count is 0
- Updated normalizeResultArray to keep empty arrays as arrays instead of converting to empty objects
- Updated sendFormattedResponse and sendResponseWithOptions to handle empty data properly
- All responses now return 200 OK instead of 206 Partial Content when no data found
- Added comprehensive tests to verify the fix

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 13:57:15 +00:00
copilot-swe-agent[bot]
e918c49b84 Initial plan 2025-12-30 13:51:07 +00:00
Hein Puth (Warkanum)
41e4956510 Merge pull request #12 from bitechdev/copilot/fix-prefix-event-issue
[WIP] Fix prefix addition in where queries and xfiles options
2025-12-30 15:38:35 +02:00
copilot-swe-agent[bot]
8e8c3c6de6 Refactor: Extract common logic from stripOuterParentheses functions
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 13:36:29 +00:00
copilot-swe-agent[bot]
aa9b7312f6 Fix AddTablePrefixToColumns to handle parenthesized AND conditions correctly
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 13:31:18 +00:00
copilot-swe-agent[bot]
dca43b0e05 Initial analysis: identified bug in AddTablePrefixToColumns
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 13:26:37 +00:00
copilot-swe-agent[bot]
6f368bbce5 Initial plan 2025-12-30 13:18:17 +00:00
Hein Puth (Warkanum)
8704cee941 Merge pull request #9 from bitechdev/websocketspec
feature: Websocketspec and mqtt spec
2025-12-30 15:02:59 +02:00
Hein Puth (Warkanum)
4ce5afe0ac Merge pull request #10 from bitechdev/copilot/sub-pr-9
Add WebSocketSpec and MQTTSpec real-time protocol implementations
2025-12-30 14:50:35 +02:00
copilot-swe-agent[bot]
7b98ea2145 Initial plan 2025-12-30 12:41:53 +00:00
Hein
897cb2ae0d fix: liniting issues and events dev 2025-12-30 14:40:45 +02:00
Hein
01420e6b63 Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into websocketspec 2025-12-30 14:13:52 +02:00
Hein Puth (Warkanum)
645907d355 Merge pull request #5 from bitechdev/server
feature: Server Manager
2025-12-30 14:13:23 +02:00
Hein
e81d7b48cc feature: mqtt support 2025-12-30 14:12:36 +02:00
Hein
8f5a725a09 Bugfix with xfiles 2025-12-30 14:12:07 +02:00
Hein Puth (Warkanum)
3d5d7b788e Merge pull request #8 from bitechdev/copilot/sub-pr-5
Fix impossible type assertion in Remove method
2025-12-30 14:04:08 +02:00
copilot-swe-agent[bot]
eaecef686e Fix type assertion error in Remove method
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:44:56 +00:00
copilot-swe-agent[bot]
e0d21b17ec Initial plan 2025-12-30 11:38:31 +00:00
Hein Puth (Warkanum)
7e1718e864 Merge pull request #7 from bitechdev/copilot/sub-pr-5-again
Fix recover() not working in CatchPanic functions
2025-12-30 13:29:36 +02:00
Hein Puth (Warkanum)
16d416030e Merge pull request #6 from bitechdev/copilot/sub-pr-5
Implement persistent certificate storage with reuse for self-signed SSL
2025-12-30 13:27:50 +02:00
Hein
bf8500714a Websocket spec fixes 2025-12-30 13:25:16 +02:00
copilot-swe-agent[bot]
4f8edd6469 Add security improvements and race condition protection
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:14:59 +00:00
copilot-swe-agent[bot]
ccf8522f88 Refactor: Use persistent cert storage with reuse logic
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:12:21 +00:00
copilot-swe-agent[bot]
92a83e9cc6 Final update
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:09:06 +00:00
copilot-swe-agent[bot]
4cb35a78b0 Improve CatchPanicCallback: extract context early and clarify example
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:07:46 +00:00
copilot-swe-agent[bot]
e10e2e1c27 Fix recover() usage in CatchPanic functions by returning deferred function
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:06:43 +00:00
copilot-swe-agent[bot]
64f56325d4 Final verification and cleanup
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:03:01 +00:00
copilot-swe-agent[bot]
5e6032c91d Initial plan 2025-12-30 11:02:05 +00:00
Hein Puth (Warkanum)
bc2fdc143b Update pkg/logger/logger.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-30 13:00:56 +02:00
copilot-swe-agent[bot]
267e84fd84 Implement cleanup for temporary certificate directories
Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-30 11:00:45 +00:00
Hein Puth (Warkanum)
8adc386863 Update pkg/server/manager.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-30 12:58:38 +02:00
Hein Puth (Warkanum)
feb023ec48 Update pkg/server/tls.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-30 12:57:55 +02:00
Hein Puth (Warkanum)
de50141a04 Update pkg/server/manager.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-30 12:57:35 +02:00
copilot-swe-agent[bot]
c226dc349f Initial plan 2025-12-30 10:56:43 +00:00
Hein
d4a6f9c4c2 Better server manager 2025-12-29 17:19:16 +02:00
8f83e8fdc1 Merge branch 'main' of github.com:bitechdev/ResolveSpec into server 2025-12-28 09:07:05 +02:00
Hein
90df4a157c Socket spec tests 2025-12-23 17:27:48 +02:00
Hein
2dd404af96 Updated to websockspec 2025-12-23 17:27:29 +02:00
Hein
17c472b206 Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into websocketspec 2025-12-23 15:23:36 +02:00
Hein
ed67caf055 fix: reasheadspec customsql calls AddTablePrefixToColumns
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m42s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m6s
Build , Vet Test, and Lint / Lint Code (push) Failing after -25m37s
Build , Vet Test, and Lint / Build (push) Successful in -25m35s
Tests / Unit Tests (push) Failing after -25m50s
Tests / Integration Tests (push) Failing after -25m59s
2025-12-23 14:17:02 +02:00
4d1b8b6982 Work on server 2025-12-20 10:42:51 +02:00
Hein
63ed62a9a3 fix: Stupid logic error.
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m39s
Build , Vet Test, and Lint / Build (push) Successful in -25m47s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m6s
Tests / Unit Tests (push) Failing after -26m5s
Tests / Integration Tests (push) Failing after -26m5s
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:52:34 +02:00
Hein
0525323a47 Fixed tests failing due to reponse header status
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:50:16 +02:00
Hein Puth (Warkanum)
c3443f702e Merge pull request #4 from bitechdev/fix-dockers
Fixed Attempt to Fix Docker / Podman
2025-12-19 16:42:38 +02:00
Hein
45c463c117 Fixed Attempt to Fix Docker / Podman
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:42:01 +02:00
Hein
84d673ce14 Added OpenAPI UI Routes
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:32:14 +02:00
Hein
02fbdbd651 Cache package is pure infrastructure. Cache invalidates on create/delete from the API
Some checks failed
Tests / Integration Tests (push) Failing after 9s
Build , Vet Test, and Lint / Lint Code (push) Successful in 8m13s
Build , Vet Test, and Lint / Build (push) Successful in -24m36s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m6s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -24m33s
Tests / Unit Tests (push) Failing after -25m39s
2025-12-18 16:30:38 +02:00
Hein
97988e3b5e Updated bun version 2025-12-18 15:54:00 +02:00
Hein
c9838ad9d2 Bun bugfix 2025-12-18 15:22:58 +02:00
Hein
c5c0608f63 StatusPartialContent is better since we need to result to see. 2025-12-18 14:48:14 +02:00
Hein
39c3f05d21 StatusNoContent for zero length data 2025-12-18 13:34:07 +02:00
Hein
4ecd1ac17e Fixed to StatusNoContent 2025-12-18 13:20:39 +02:00
Hein
2b1aea0338 Fix null interface issue and added partial content response when content is empty 2025-12-18 13:19:57 +02:00
Hein
1e749efeb3 Fixes for not found records 2025-12-18 13:08:26 +02:00
Hein
1b2b0d8f0b Prototype for websockspec 2025-12-12 16:14:47 +02:00
130 changed files with 28615 additions and 2167 deletions

View File

@@ -17,11 +17,13 @@ jobs:
- name: Run unit tests
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
- name: Generate coverage report
continue-on-error: true
run: |
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
- name: Upload coverage
uses: actions/upload-artifact@v5
continue-on-error: true
with:
name: coverage-report
path: coverage.html
@@ -55,27 +57,34 @@ jobs:
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
- name: Run resolvespec integration tests
continue-on-error: true
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
- name: Run restheadspec integration tests
continue-on-error: true
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
- name: Generate integration coverage
continue-on-error: true
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: |
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
- name: Upload resolvespec integration coverage
uses: actions/upload-artifact@v5
continue-on-error: true
with:
name: resolvespec-integration-coverage-report
path: coverage-resolvespec-integration.html
- name: Upload restheadspec integration coverage
uses: actions/upload-artifact@v5
continue-on-error: true
with:
name: integration-coverage-restheadspec-report
path: coverage-restheadspec-integration

1
.gitignore vendored
View File

@@ -25,3 +25,4 @@ go.work.sum
.env
bin/
test.db
/testserver

View File

@@ -52,5 +52,9 @@
"upgrade_dependency": true,
"vendor": true
}
}
}
},
"conventionalCommits.scopes": [
"spectypes",
"dbmanager"
]
}

86
LICENSE
View File

@@ -1,21 +1,73 @@
MIT License
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Copyright (c) 2025
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
1. Definitions.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright 2025 wdevs
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -13,10 +13,63 @@ test-integration:
# Run all tests (unit + integration)
test: test-unit test-integration
release-version: ## Create and push a release with specific version (use: make release-version VERSION=v1.2.3 or make release-version to auto-increment)
@if [ -z "$(VERSION)" ]; then \
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0"); \
echo "No VERSION specified. Last version: $$latest_tag"; \
version_num=$$(echo "$$latest_tag" | sed 's/^v//'); \
major=$$(echo "$$version_num" | cut -d. -f1); \
minor=$$(echo "$$version_num" | cut -d. -f2); \
patch=$$(echo "$$version_num" | cut -d. -f3); \
new_patch=$$((patch + 1)); \
version="v$$major.$$minor.$$new_patch"; \
echo "Auto-incrementing to: $$version"; \
else \
version="$(VERSION)"; \
if ! echo "$$version" | grep -q "^v"; then \
version="v$$version"; \
fi; \
fi; \
echo "Creating release: $$version"; \
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo ""); \
if [ -z "$$latest_tag" ]; then \
commit_logs=$$(git log --pretty=format:"- %s" --no-merges); \
else \
commit_logs=$$(git log "$${latest_tag}..HEAD" --pretty=format:"- %s" --no-merges); \
fi; \
if [ -z "$$commit_logs" ]; then \
tag_message="Release $$version"; \
else \
tag_message="Release $$version\n\n$$commit_logs"; \
fi; \
git tag -a "$$version" -m "$$tag_message"; \
git push origin "$$version"; \
echo "Tag $$version created and pushed to remote repository."
lint: ## Run linter
@echo "Running linter..."
@if command -v golangci-lint > /dev/null; then \
golangci-lint run --config=.golangci.json; \
else \
echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \
exit 1; \
fi
lintfix: ## Run linter
@echo "Running linter..."
@if command -v golangci-lint > /dev/null; then \
golangci-lint run --config=.golangci.json --fix; \
else \
echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \
exit 1; \
fi
# Start PostgreSQL for integration tests
docker-up:
@echo "Starting PostgreSQL container..."
@docker-compose up -d postgres-test
@podman compose up -d postgres-test
@echo "Waiting for PostgreSQL to be ready..."
@sleep 5
@echo "PostgreSQL is ready!"
@@ -24,12 +77,12 @@ docker-up:
# Stop PostgreSQL container
docker-down:
@echo "Stopping PostgreSQL container..."
@docker-compose down
@podman compose down
# Clean up Docker volumes and test data
clean:
@echo "Cleaning up..."
@docker-compose down -v
@podman compose down -v
@echo "Cleanup complete!"
# Run integration tests with Docker (full workflow)

989
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,14 @@
package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"time"
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/server"
@@ -15,7 +17,6 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/gorilla/mux"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
gormlog "gorm.io/gorm/logger"
)
@@ -40,12 +41,14 @@ func main() {
logger.Info("ResolveSpec test server starting")
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
// Initialize database
db, err := initDB(cfg)
// Initialize database manager
ctx := context.Background()
dbMgr, db, err := initDB(ctx, cfg)
if err != nil {
logger.Error("Failed to initialize database: %+v", err)
os.Exit(1)
}
defer dbMgr.Close()
// Create router
r := mux.NewRouter()
@@ -67,9 +70,36 @@ func main() {
// Setup routes using new SetupMuxRoutes function (without authentication)
resolvespec.SetupMuxRoutes(r, handler, nil)
// Create graceful server with configuration
srv := server.NewGracefulServer(server.Config{
Addr: cfg.Server.Addr,
// Create server manager
mgr := server.NewManager()
// Parse host and port from addr
host := ""
port := 8080
if cfg.Server.Addr != "" {
// Parse addr (format: ":8080" or "localhost:8080")
if cfg.Server.Addr[0] == ':' {
// Just port
_, err := fmt.Sscanf(cfg.Server.Addr, ":%d", &port)
if err != nil {
logger.Error("Invalid server address: %s", cfg.Server.Addr)
os.Exit(1)
}
} else {
// Host and port
_, err := fmt.Sscanf(cfg.Server.Addr, "%[^:]:%d", &host, &port)
if err != nil {
logger.Error("Invalid server address: %s", cfg.Server.Addr)
os.Exit(1)
}
}
}
// Add server instance
_, err = mgr.Add(server.Config{
Name: "api",
Host: host,
Port: port,
Handler: r,
ShutdownTimeout: cfg.Server.ShutdownTimeout,
DrainTimeout: cfg.Server.DrainTimeout,
@@ -77,16 +107,20 @@ func main() {
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
})
if err != nil {
logger.Error("Failed to add server: %v", err)
os.Exit(1)
}
// Start server with graceful shutdown
logger.Info("Starting server on %s", cfg.Server.Addr)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("Server failed to start: %v", err)
if err := mgr.ServeWithGracefulShutdown(); err != nil {
logger.Error("Server failed: %v", err)
os.Exit(1)
}
}
func initDB(cfg *config.Config) (*gorm.DB, error) {
func initDB(ctx context.Context, cfg *config.Config) (dbmanager.Manager, *gorm.DB, error) {
// Configure GORM logger based on config
logLevel := gormlog.Info
if !cfg.Logger.Dev {
@@ -104,25 +138,41 @@ func initDB(cfg *config.Config) (*gorm.DB, error) {
},
)
// Use database URL from config if available, otherwise use default SQLite
dbURL := cfg.Database.URL
if dbURL == "" {
dbURL = "test.db"
// Create database manager from config
mgr, err := dbmanager.NewManager(dbmanager.FromConfig(cfg.DBManager))
if err != nil {
return nil, nil, fmt.Errorf("failed to create database manager: %w", err)
}
// Create SQLite database
db, err := gorm.Open(sqlite.Open(dbURL), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
if err != nil {
return nil, err
// Connect all databases
if err := mgr.Connect(ctx); err != nil {
return nil, nil, fmt.Errorf("failed to connect databases: %w", err)
}
// Get default connection
conn, err := mgr.GetDefault()
if err != nil {
mgr.Close()
return nil, nil, fmt.Errorf("failed to get default connection: %w", err)
}
// Get GORM database
gormDB, err := conn.GORM()
if err != nil {
mgr.Close()
return nil, nil, fmt.Errorf("failed to get GORM database: %w", err)
}
// Update GORM logger
gormDB.Logger = newLogger
modelList := testmodels.GetTestModels()
// Auto migrate schemas
err = db.AutoMigrate(modelList...)
if err != nil {
return nil, err
if err := gormDB.AutoMigrate(modelList...); err != nil {
mgr.Close()
return nil, nil, fmt.Errorf("failed to auto migrate: %w", err)
}
return db, nil
return mgr, gormDB, nil
}

View File

@@ -37,5 +37,25 @@ cors:
tracing:
enabled: false
database:
url: "" # Empty means use default SQLite (test.db)
# Database Manager Configuration
dbmanager:
default_connection: "primary"
max_open_conns: 25
max_idle_conns: 5
conn_max_lifetime: 30m
conn_max_idle_time: 5m
retry_attempts: 3
retry_delay: 1s
health_check_interval: 30s
enable_auto_reconnect: true
connections:
primary:
name: "primary"
type: "sqlite"
filepath: "test.db"
default_orm: "gorm"
enable_logging: true
enable_metrics: false
connect_timeout: 10s
query_timeout: 30s

94
go.mod
View File

@@ -7,91 +7,155 @@ toolchain go1.24.6
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
github.com/eclipse/paho.mqtt.golang v1.5.1
github.com/getsentry/sentry-go v0.40.0
github.com/glebarez/sqlite v1.11.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.6.0
github.com/klauspost/compress v1.18.0
github.com/mattn/go-sqlite3 v1.14.32
github.com/microsoft/go-mssqldb v1.9.5
github.com/mochi-mqtt/server/v2 v2.7.9
github.com/nats-io/nats.go v1.48.0
github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go v0.40.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/uptrace/bun v1.2.15
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
github.com/uptrace/bun/driver/sqliteshim v1.2.15
github.com/uptrace/bun v1.2.16
github.com/uptrace/bun/dialect/mssqldialect v1.2.16
github.com/uptrace/bun/dialect/pgdialect v1.2.16
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
github.com/uptrace/bun/driver/sqliteshim v1.2.16
github.com/uptrace/bunrouter v1.0.23
go.mongodb.org/mongo-driver v1.17.6
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.43.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.25.12
gorm.io/driver/sqlite v1.6.0
gorm.io/driver/sqlserver v1.6.3
gorm.io/gorm v1.30.0
)
require (
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
github.com/moby/sys/user v0.4.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/montanaflynn/stats v0.7.1 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/nats-io/nkeys v0.4.11 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rs/xid v1.4.0 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/net v0.45.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.30.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.66.3 // indirect
modernc.org/libc v1.67.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.38.0 // indirect
modernc.org/sqlite v1.40.1 // indirect
)
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17

382
go.sum
View File

@@ -1,5 +1,37 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1/go.mod h1:uE9zaUfEQT/nbQjVi2IblCG9iaLtZsuYZ8ne+PuQ02M=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1/go.mod h1:GpPjLhVR9dnUoJMyHWSPy71xY9/lcmpzIPZXmF0FCVY=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
@@ -8,17 +40,48 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE=
github.com/eclipse/paho.mqtt.golang v1.5.1/go.mod h1:1/yJCneuyOoCOzKSsOTUc0AJfpsItBGWvYpBLimhArU=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
@@ -36,20 +99,43 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -58,6 +144,14 @@ github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo=
github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -65,28 +159,78 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0=
github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/mochi-mqtt/server/v2 v2.7.9 h1:y0g4vrSLAag7T07l2oCzOa/+nKVLoazKEWAArwqBNYI=
github.com/mochi-mqtt/server/v2 v2.7.9/go.mod h1:lZD3j35AVNqJL5cezlnSkuG05c0FCHSsfAKSPBOSbqc=
github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8=
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
@@ -101,10 +245,20 @@ github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5i
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
@@ -116,12 +270,24 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -131,28 +297,53 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
github.com/uptrace/bun/dialect/mssqldialect v1.2.16 h1:rKv0cKPNBviXadB/+2Y/UedA/c1JnwGzUWZkdN5FdSQ=
github.com/uptrace/bun/dialect/mssqldialect v1.2.16/go.mod h1:J5U7tGKWDsx2Q7MwDZF2417jCdpD6yD/ZMFJcCR80bk=
github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
github.com/uptrace/bun/driver/sqliteshim v1.2.16/go.mod h1:iKdJ06P3XS+pwKcONjSIK07bbhksH3lWsw3mpfr0+bY=
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss=
go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
@@ -173,25 +364,128 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
@@ -205,25 +499,37 @@ google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXn
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/driver/sqlserver v1.6.3 h1:UR+nWCuphPnq7UxnL57PSrlYjuvs+sf1N59GgFX7uAI=
gorm.io/driver/sqlserver v1.6.3/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -232,8 +538,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@@ -57,11 +57,31 @@ func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time
return c.provider.Set(ctx, key, value, ttl)
}
// SetWithTags serializes and stores a value in the cache with the specified TTL and tags.
func (c *Cache) SetWithTags(ctx context.Context, key string, value interface{}, ttl time.Duration, tags []string) error {
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize: %w", err)
}
return c.provider.SetWithTags(ctx, key, data, ttl, tags)
}
// SetBytesWithTags stores raw bytes in the cache with the specified TTL and tags.
func (c *Cache) SetBytesWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
return c.provider.SetWithTags(ctx, key, value, ttl, tags)
}
// Delete removes a key from the cache.
func (c *Cache) Delete(ctx context.Context, key string) error {
return c.provider.Delete(ctx, key)
}
// DeleteByTag removes all keys associated with the given tag.
func (c *Cache) DeleteByTag(ctx context.Context, tag string) error {
return c.provider.DeleteByTag(ctx, tag)
}
// DeleteByPattern removes all keys matching the pattern.
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
return c.provider.DeleteByPattern(ctx, pattern)

View File

@@ -15,9 +15,17 @@ type Provider interface {
// If ttl is 0, the item never expires.
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// SetWithTags stores a value in the cache with the specified TTL and tags.
// Tags can be used to invalidate groups of related keys.
// If ttl is 0, the item never expires.
SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error
// Delete removes a key from the cache.
Delete(ctx context.Context, key string) error
// DeleteByTag removes all keys associated with the given tag.
DeleteByTag(ctx context.Context, tag string) error
// DeleteByPattern removes all keys matching the pattern.
// Pattern syntax depends on the provider implementation.
DeleteByPattern(ctx context.Context, pattern string) error

View File

@@ -2,6 +2,7 @@ package cache
import (
"context"
"encoding/json"
"fmt"
"time"
@@ -97,8 +98,115 @@ func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, tt
return m.client.Set(item)
}
// SetWithTags stores a value in the cache with the specified TTL and tags.
// Note: Tag support in Memcache is limited and less efficient than Redis.
func (m *MemcacheProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
if ttl == 0 {
ttl = m.options.DefaultTTL
}
expiration := int32(ttl.Seconds())
// Set the main value
item := &memcache.Item{
Key: key,
Value: value,
Expiration: expiration,
}
if err := m.client.Set(item); err != nil {
return err
}
// Store tags for this key
if len(tags) > 0 {
tagsData, err := json.Marshal(tags)
if err != nil {
return fmt.Errorf("failed to marshal tags: %w", err)
}
tagsItem := &memcache.Item{
Key: fmt.Sprintf("cache:tags:%s", key),
Value: tagsData,
Expiration: expiration,
}
if err := m.client.Set(tagsItem); err != nil {
return err
}
// Add key to each tag's key list
for _, tag := range tags {
tagKey := fmt.Sprintf("cache:tag:%s", tag)
// Get existing keys for this tag
var keys []string
if item, err := m.client.Get(tagKey); err == nil {
_ = json.Unmarshal(item.Value, &keys)
}
// Add current key if not already present
found := false
for _, k := range keys {
if k == key {
found = true
break
}
}
if !found {
keys = append(keys, key)
}
// Store updated key list
keysData, err := json.Marshal(keys)
if err != nil {
continue
}
tagItem := &memcache.Item{
Key: tagKey,
Value: keysData,
Expiration: expiration + 3600, // Give tag lists longer TTL
}
_ = m.client.Set(tagItem)
}
}
return nil
}
// Delete removes a key from the cache.
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
// Get tags for this key
tagsKey := fmt.Sprintf("cache:tags:%s", key)
if item, err := m.client.Get(tagsKey); err == nil {
var tags []string
if err := json.Unmarshal(item.Value, &tags); err == nil {
// Remove key from each tag's key list
for _, tag := range tags {
tagKey := fmt.Sprintf("cache:tag:%s", tag)
if tagItem, err := m.client.Get(tagKey); err == nil {
var keys []string
if err := json.Unmarshal(tagItem.Value, &keys); err == nil {
// Remove current key from the list
newKeys := make([]string, 0, len(keys))
for _, k := range keys {
if k != key {
newKeys = append(newKeys, k)
}
}
// Update the tag's key list
if keysData, err := json.Marshal(newKeys); err == nil {
tagItem.Value = keysData
_ = m.client.Set(tagItem)
}
}
}
}
}
// Delete the tags key
_ = m.client.Delete(tagsKey)
}
// Delete the actual key
err := m.client.Delete(key)
if err == memcache.ErrCacheMiss {
return nil
@@ -106,6 +214,38 @@ func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
return err
}
// DeleteByTag removes all keys associated with the given tag.
func (m *MemcacheProvider) DeleteByTag(ctx context.Context, tag string) error {
tagKey := fmt.Sprintf("cache:tag:%s", tag)
// Get all keys associated with this tag
item, err := m.client.Get(tagKey)
if err == memcache.ErrCacheMiss {
return nil
}
if err != nil {
return err
}
var keys []string
if err := json.Unmarshal(item.Value, &keys); err != nil {
return fmt.Errorf("failed to unmarshal tag keys: %w", err)
}
// Delete all keys
for _, key := range keys {
_ = m.client.Delete(key)
// Also delete the tags key for this cache key
tagsKey := fmt.Sprintf("cache:tags:%s", key)
_ = m.client.Delete(tagsKey)
}
// Delete the tag key itself
_ = m.client.Delete(tagKey)
return nil
}
// DeleteByPattern removes all keys matching the pattern.
// Note: Memcache does not support pattern-based deletion natively.
// This is a no-op for memcache and returns an error.

View File

@@ -15,6 +15,7 @@ type memoryItem struct {
Expiration time.Time
LastAccess time.Time
HitCount int64
Tags []string
}
// isExpired checks if the item has expired.
@@ -27,11 +28,12 @@ func (m *memoryItem) isExpired() bool {
// MemoryProvider is an in-memory implementation of the Provider interface.
type MemoryProvider struct {
mu sync.RWMutex
items map[string]*memoryItem
options *Options
hits atomic.Int64
misses atomic.Int64
mu sync.RWMutex
items map[string]*memoryItem
tagToKeys map[string]map[string]struct{} // tag -> set of keys
options *Options
hits atomic.Int64
misses atomic.Int64
}
// NewMemoryProvider creates a new in-memory cache provider.
@@ -44,8 +46,9 @@ func NewMemoryProvider(opts *Options) *MemoryProvider {
}
return &MemoryProvider{
items: make(map[string]*memoryItem),
options: opts,
items: make(map[string]*memoryItem),
tagToKeys: make(map[string]map[string]struct{}),
options: opts,
}
}
@@ -114,15 +117,116 @@ func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl
return nil
}
// SetWithTags stores a value in the cache with the specified TTL and tags.
func (m *MemoryProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
m.mu.Lock()
defer m.mu.Unlock()
if ttl == 0 {
ttl = m.options.DefaultTTL
}
var expiration time.Time
if ttl > 0 {
expiration = time.Now().Add(ttl)
}
// Check max size and evict if necessary
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
if _, exists := m.items[key]; !exists {
m.evictOne()
}
}
// Remove old tag associations if key exists
if oldItem, exists := m.items[key]; exists {
for _, tag := range oldItem.Tags {
if keySet, ok := m.tagToKeys[tag]; ok {
delete(keySet, key)
if len(keySet) == 0 {
delete(m.tagToKeys, tag)
}
}
}
}
// Store the item
m.items[key] = &memoryItem{
Value: value,
Expiration: expiration,
LastAccess: time.Now(),
Tags: tags,
}
// Add new tag associations
for _, tag := range tags {
if m.tagToKeys[tag] == nil {
m.tagToKeys[tag] = make(map[string]struct{})
}
m.tagToKeys[tag][key] = struct{}{}
}
return nil
}
// Delete removes a key from the cache.
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
// Remove tag associations
if item, exists := m.items[key]; exists {
for _, tag := range item.Tags {
if keySet, ok := m.tagToKeys[tag]; ok {
delete(keySet, key)
if len(keySet) == 0 {
delete(m.tagToKeys, tag)
}
}
}
}
delete(m.items, key)
return nil
}
// DeleteByTag removes all keys associated with the given tag.
func (m *MemoryProvider) DeleteByTag(ctx context.Context, tag string) error {
m.mu.Lock()
defer m.mu.Unlock()
// Get all keys associated with this tag
keySet, exists := m.tagToKeys[tag]
if !exists {
return nil // No keys with this tag
}
// Delete all items with this tag
for key := range keySet {
if item, ok := m.items[key]; ok {
// Remove this tag from the item's tag list
newTags := make([]string, 0, len(item.Tags))
for _, t := range item.Tags {
if t != tag {
newTags = append(newTags, t)
}
}
// If item has no more tags, delete it
// Otherwise update its tags
if len(newTags) == 0 {
delete(m.items, key)
} else {
item.Tags = newTags
}
}
}
// Remove the tag mapping
delete(m.tagToKeys, tag)
return nil
}
// DeleteByPattern removes all keys matching the pattern.
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
m.mu.Lock()

View File

@@ -103,9 +103,93 @@ func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl t
return r.client.Set(ctx, key, value, ttl).Err()
}
// SetWithTags stores a value in the cache with the specified TTL and tags.
func (r *RedisProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
if ttl == 0 {
ttl = r.options.DefaultTTL
}
pipe := r.client.Pipeline()
// Set the value
pipe.Set(ctx, key, value, ttl)
// Add key to each tag's set
for _, tag := range tags {
tagKey := fmt.Sprintf("cache:tag:%s", tag)
pipe.SAdd(ctx, tagKey, key)
// Set expiration on tag set (longer than cache items to ensure cleanup)
if ttl > 0 {
pipe.Expire(ctx, tagKey, ttl+time.Hour)
}
}
// Store tags for this key for later cleanup
if len(tags) > 0 {
tagsKey := fmt.Sprintf("cache:tags:%s", key)
pipe.SAdd(ctx, tagsKey, tags)
if ttl > 0 {
pipe.Expire(ctx, tagsKey, ttl)
}
}
_, err := pipe.Exec(ctx)
return err
}
// Delete removes a key from the cache.
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
return r.client.Del(ctx, key).Err()
pipe := r.client.Pipeline()
// Get tags for this key
tagsKey := fmt.Sprintf("cache:tags:%s", key)
tags, err := r.client.SMembers(ctx, tagsKey).Result()
if err == nil && len(tags) > 0 {
// Remove key from each tag set
for _, tag := range tags {
tagKey := fmt.Sprintf("cache:tag:%s", tag)
pipe.SRem(ctx, tagKey, key)
}
// Delete the tags key
pipe.Del(ctx, tagsKey)
}
// Delete the actual key
pipe.Del(ctx, key)
_, err = pipe.Exec(ctx)
return err
}
// DeleteByTag removes all keys associated with the given tag.
func (r *RedisProvider) DeleteByTag(ctx context.Context, tag string) error {
tagKey := fmt.Sprintf("cache:tag:%s", tag)
// Get all keys associated with this tag
keys, err := r.client.SMembers(ctx, tagKey).Result()
if err != nil {
return err
}
if len(keys) == 0 {
return nil
}
pipe := r.client.Pipeline()
// Delete all keys and their tag associations
for _, key := range keys {
pipe.Del(ctx, key)
// Also delete the tags key for this cache key
tagsKey := fmt.Sprintf("cache:tags:%s", key)
pipe.Del(ctx, tagsKey)
}
// Delete the tag set itself
pipe.Del(ctx, tagKey)
_, err = pipe.Exec(ctx)
return err
}
// DeleteByPattern removes all keys matching the pattern.

View File

@@ -1,151 +0,0 @@
package cache
import (
"context"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestBuildQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "age", Operator: "gt", Value: 25},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
// Generate cache key
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
// Same parameters should generate same key
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
}
// Different parameters should generate different key
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
}
}
func TestBuildExtendedQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
expandOpts := []interface{}{
map[string]interface{}{
"relation": "posts",
"where": "status = 'published'",
},
}
// Generate cache key
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
// Same parameters should generate same key
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters")
}
// Different distinct value should generate different key
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different distinct values")
}
}
func TestGetQueryTotalCacheKey(t *testing.T) {
hash := "abc123"
key := GetQueryTotalCacheKey(hash)
expected := "query_total:abc123"
if key != expected {
t.Errorf("Expected %s, got %s", expected, key)
}
}
func TestCachedTotalIntegration(t *testing.T) {
// Initialize cache with memory provider for testing
UseMemory(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 100,
})
ctx := context.Background()
// Create test data
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
}
sorts := []common.SortOption{
{Column: "created_at", Direction: "desc"},
}
// Build cache key
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
// Store a total count in cache
totalToCache := CachedTotal{Total: 42}
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
if err != nil {
t.Fatalf("Failed to set cache: %v", err)
}
// Retrieve from cache
var cachedTotal CachedTotal
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err != nil {
t.Fatalf("Failed to get from cache: %v", err)
}
if cachedTotal.Total != 42 {
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
}
// Test cache miss
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
var missedTotal CachedTotal
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
if err == nil {
t.Errorf("Expected error for cache miss, got nil")
}
}
func TestHashString(t *testing.T) {
input1 := "test string"
input2 := "test string"
input3 := "different string"
hash1 := hashString(input1)
hash2 := hashString(input2)
hash3 := hashString(input3)
// Same input should produce same hash
if hash1 != hash2 {
t.Errorf("Expected same hash for identical inputs")
}
// Different input should produce different hash
if hash1 == hash3 {
t.Errorf("Expected different hash for different inputs")
}
// Hash should be hex encoded SHA256 (64 characters)
if len(hash1) != 64 {
t.Errorf("Expected hash length of 64, got %d", len(hash1))
}
}

View File

@@ -1,7 +1,16 @@
package database
import (
"database/sql"
"strings"
"github.com/uptrace/bun/dialect/mssqldialect"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/dialect/sqlitedialect"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
)
// parseTableName splits a table name that may contain schema into separate schema and table
@@ -14,3 +23,39 @@ func parseTableName(fullTableName string) (schema, table string) {
}
return "", fullTableName
}
// GetPostgresDialect returns a Bun PostgreSQL dialect
func GetPostgresDialect() *pgdialect.Dialect {
return pgdialect.New()
}
// GetSQLiteDialect returns a Bun SQLite dialect
func GetSQLiteDialect() *sqlitedialect.Dialect {
return sqlitedialect.New()
}
// GetMSSQLDialect returns a Bun MSSQL dialect
func GetMSSQLDialect() *mssqldialect.Dialect {
return mssqldialect.New()
}
// GetPostgresDialector returns a GORM PostgreSQL dialector
func GetPostgresDialector(db *sql.DB) gorm.Dialector {
return postgres.New(postgres.Config{
Conn: db,
})
}
// GetSQLiteDialector returns a GORM SQLite dialector
func GetSQLiteDialector(db *sql.DB) gorm.Dialector {
return sqlite.Dialector{
Conn: db,
}
}
// GetMSSQLDialector returns a GORM MSSQL dialector
func GetMSSQLDialector(db *sql.DB) gorm.Dialector {
return sqlserver.New(sqlserver.Config{
Conn: db,
})
}

View File

@@ -208,21 +208,9 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
}
}
}
} else if tableName != "" && !hasTablePrefix(condToCheck) {
// If tableName is provided and the condition DOESN'T have a table prefix,
// qualify unambiguous column references to prevent "ambiguous column" errors
// when there are multiple joins on the same table (e.g., recursive preloads)
columnName := extractUnqualifiedColumnName(condToCheck)
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
// Qualify the column with the table name
// Be careful to only replace the column name, not other occurrences of the string
oldRef := columnName
newRef := tableName + "." + columnName
// Use word boundary matching to avoid replacing partial matches
cond = qualifyColumnInCondition(cond, oldRef, newRef)
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
}
}
// Note: We no longer add prefixes to unqualified columns here.
// Use AddTablePrefixToColumns() separately if you need to add prefixes.
validConditions = append(validConditions, cond)
}
@@ -246,35 +234,52 @@ func stripOuterParentheses(s string) string {
s = strings.TrimSpace(s)
for {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
stripped, wasStripped := stripOneMatchingOuterParen(s)
if !wasStripped {
return s
}
s = stripped
}
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s
}
// stripOneOuterParentheses removes only one level of matching outer parentheses from a string
// Unlike stripOuterParentheses, this only strips once, preserving nested parentheses
func stripOneOuterParentheses(s string) string {
stripped, _ := stripOneMatchingOuterParen(strings.TrimSpace(s))
return stripped
}
// stripOneMatchingOuterParen is a helper that strips one matching pair of outer parentheses
// Returns the stripped string and a boolean indicating if stripping occurred
func stripOneMatchingOuterParen(s string) (string, bool) {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
return s, false
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s, false
}
}
if !matched {
return s
}
// Strip the outer parentheses and continue
s = strings.TrimSpace(s[1 : len(s)-1])
}
if !matched {
return s, false
}
// Strip the outer parentheses
return strings.TrimSpace(s[1 : len(s)-1]), true
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
@@ -498,9 +503,10 @@ func extractTableAndColumn(cond string) (table string, column string) {
return "", ""
}
// extractUnqualifiedColumnName extracts the column name from an unqualified condition
// Unused: extractUnqualifiedColumnName extracts the column name from an unqualified condition
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
// "status = 'active'" returns "status"
// nolint:unused
func extractUnqualifiedColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
@@ -633,3 +639,173 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool {
}
return validColumns[strings.ToLower(columnName)]
}
// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause.
// This function only prefixes simple column references and skips:
// - Columns already having a table prefix (containing a dot)
// - Columns inside function calls or expressions (inside parentheses)
// - Columns inside subqueries
// - Columns that don't exist in the table (validation via model registry)
//
// Examples:
// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table)
// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function)
// - "users.status = 'active'" -> unchanged (already has prefix)
// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK)
// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table)
//
// Parameters:
// - where: The WHERE clause to process
// - tableName: The table name to use as prefix
//
// Returns:
// - The WHERE clause with table prefixes added to appropriate and valid columns
func AddTablePrefixToColumns(where string, tableName string) string {
if where == "" || tableName == "" {
return where
}
where = strings.TrimSpace(where)
// Get valid columns from the model registry for validation
validColumns := getValidColumnsForTable(tableName)
// Split by AND to handle multiple conditions (parenthesis-aware)
conditions := splitByAND(where)
prefixedConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Process this condition to add table prefix if appropriate
processedCond := addPrefixToSingleCondition(cond, tableName, validColumns)
prefixedConditions = append(prefixedConditions, processedCond)
}
if len(prefixedConditions) == 0 {
return ""
}
return strings.Join(prefixedConditions, " AND ")
}
// addPrefixToSingleCondition adds table prefix to a single condition if appropriate
// Returns the condition unchanged if:
// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.)
// - The column reference is inside a function call
// - The column already has a table prefix
// - No valid column reference is found
// - The column doesn't exist in the table (when validColumns is provided)
func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string {
// Strip one level of outer grouping parentheses to get to the actual condition
strippedCond := stripOneOuterParentheses(cond)
// Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.)
if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) {
logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond)
return cond
}
// After stripping outer parentheses, check if there are multiple AND-separated conditions
// at the top level. If so, split and process each separately to avoid incorrectly
// treating "true AND status" as a single column name.
subConditions := splitByAND(strippedCond)
if len(subConditions) > 1 {
// Multiple conditions found - process each separately
logger.Debug("Found %d sub-conditions after stripping parentheses, processing separately", len(subConditions))
processedConditions := make([]string, 0, len(subConditions))
for _, subCond := range subConditions {
// Recursively process each sub-condition
processed := addPrefixToSingleCondition(subCond, tableName, validColumns)
processedConditions = append(processedConditions, processed)
}
result := strings.Join(processedConditions, " AND ")
// Preserve original outer parentheses if they existed
if cond != strippedCond {
result = "(" + result + ")"
}
return result
}
// If we stripped parentheses and still have more parentheses, recursively process
if cond != strippedCond && strings.HasPrefix(strippedCond, "(") && strings.HasSuffix(strippedCond, ")") {
// Recursively handle nested parentheses
processed := addPrefixToSingleCondition(strippedCond, tableName, validColumns)
return "(" + processed + ")"
}
// Extract the left side of the comparison (before the operator)
columnRef := extractLeftSideOfComparison(strippedCond)
if columnRef == "" {
return cond
}
// Skip if it already has a prefix (contains a dot)
if strings.Contains(columnRef, ".") {
logger.Debug("Skipping column '%s' - already has table prefix", columnRef)
return cond
}
// Skip if it's a function call or expression (contains parentheses)
if strings.Contains(columnRef, "(") {
logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef)
return cond
}
// Validate that the column exists in the table (if we have column info)
if !isValidColumn(columnRef, validColumns) {
logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName)
return cond
}
// It's a simple unqualified column reference that exists in the table - add the table prefix
newRef := tableName + "." + columnRef
result := qualifyColumnInCondition(cond, columnRef, newRef)
logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef)
return result
}
// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition.
// This is used to identify the column reference that may need a table prefix.
//
// Examples:
// - "status = 'active'" returns "status"
// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')"
// - "priority > 5" returns "priority"
//
// Returns empty string if no operator is found.
func extractLeftSideOfComparison(cond string) string {
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
// Find the first operator outside of parentheses and quotes
minIdx := -1
for _, op := range operators {
idx := findOperatorOutsideParentheses(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
if minIdx > 0 {
leftSide := strings.TrimSpace(cond[:minIdx])
// Remove any surrounding quotes
leftSide = strings.Trim(leftSide, "`\"'")
return leftSide
}
// No operator found - might be a boolean column
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef := strings.Trim(parts[0], "`\"'")
// Make sure it's not a SQL keyword
if !IsSQLKeyword(strings.ToLower(columnRef)) {
return columnRef
}
}
return ""
}

View File

@@ -138,7 +138,10 @@ func TestSanitizeWhereClause(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
// First add table prefixes to unqualified columns
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
// Then sanitize the where clause
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
@@ -348,6 +351,7 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
tableName string
options *RequestOptions
expected string
addPrefix bool
}{
{
name: "preload relation prefix is preserved",
@@ -416,15 +420,30 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
options: &RequestOptions{Preload: []PreloadOption{}},
expected: "users.status = 'active'",
},
{
name: "complex where clause with subquery and preload",
where: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (rid_parentmastertaskitem is null)`,
tableName: "mastertaskitem",
options: nil,
expected: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (mastertaskitem.rid_parentmastertaskitem is null)`,
addPrefix: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result string
prefixedWhere := tt.where
if tt.addPrefix {
// First add table prefixes to unqualified columns
prefixedWhere = AddTablePrefixToColumns(tt.where, tt.tableName)
}
// Then sanitize the where clause
if tt.options != nil {
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
result = SanitizeWhereClause(prefixedWhere, tt.tableName, tt.options)
} else {
result = SanitizeWhereClause(tt.where, tt.tableName)
result = SanitizeWhereClause(prefixedWhere, tt.tableName)
}
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
@@ -639,3 +658,76 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
})
}
}
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "Parentheses with true AND condition - should not prefix true",
where: "(true AND status = 'active')",
tableName: "mastertask",
expected: "(true AND mastertask.status = 'active')",
},
{
name: "Parentheses with multiple conditions including true",
where: "(true AND status = 'active' AND id > 5)",
tableName: "mastertask",
expected: "(true AND mastertask.status = 'active' AND mastertask.id > 5)",
},
{
name: "Nested parentheses with true",
where: "((true AND status = 'active'))",
tableName: "mastertask",
expected: "((true AND mastertask.status = 'active'))",
},
{
name: "Mixed: false AND valid conditions",
where: "(false AND name = 'test')",
tableName: "mastertask",
expected: "(false AND mastertask.name = 'test')",
},
{
name: "Mixed: null AND valid conditions",
where: "(null AND status = 'active')",
tableName: "mastertask",
expected: "(null AND mastertask.status = 'active')",
},
{
name: "Multiple true conditions in parentheses",
where: "(true AND true AND status = 'active')",
tableName: "mastertask",
expected: "(true AND true AND mastertask.status = 'active')",
},
{
name: "Simple true without parens - should not prefix",
where: "true",
tableName: "mastertask",
expected: "true",
},
{
name: "Simple condition without parens - should prefix",
where: "status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "Unregistered table with true - should not prefix true",
where: "(true AND status = 'active')",
tableName: "unregistered_table",
expected: "(true AND unregistered_table.status = 'active')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AddTablePrefixToColumns(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}

View File

@@ -11,8 +11,8 @@ type Config struct {
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
Middleware MiddlewareConfig `mapstructure:"middleware"`
CORS CORSConfig `mapstructure:"cors"`
Database DatabaseConfig `mapstructure:"database"`
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
DBManager DBManagerConfig `mapstructure:"dbmanager"`
}
// ServerConfig holds server-related configuration
@@ -76,11 +76,6 @@ type CORSConfig struct {
MaxAge int `mapstructure:"max_age"`
}
// DatabaseConfig holds database configuration (primarily for testing)
type DatabaseConfig struct {
URL string `mapstructure:"url"`
}
// ErrorTrackingConfig holds error tracking configuration
type ErrorTrackingConfig struct {
Enabled bool `mapstructure:"enabled"`

107
pkg/config/dbmanager.go Normal file
View File

@@ -0,0 +1,107 @@
package config
import (
"fmt"
"time"
)
// DBManagerConfig contains configuration for the database connection manager
type DBManagerConfig struct {
// DefaultConnection is the name of the default connection to use
DefaultConnection string `mapstructure:"default_connection"`
// Connections is a map of connection name to connection configuration
Connections map[string]DBConnectionConfig `mapstructure:"connections"`
// Global connection pool defaults
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"`
// Retry policy
RetryAttempts int `mapstructure:"retry_attempts"`
RetryDelay time.Duration `mapstructure:"retry_delay"`
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
// Health checks
HealthCheckInterval time.Duration `mapstructure:"health_check_interval"`
EnableAutoReconnect bool `mapstructure:"enable_auto_reconnect"`
}
// DBConnectionConfig defines configuration for a single database connection
type DBConnectionConfig struct {
// Name is the unique name of this connection
Name string `mapstructure:"name"`
// Type is the database type (postgres, sqlite, mssql, mongodb)
Type string `mapstructure:"type"`
// DSN is the complete Data Source Name / connection string
// If provided, this takes precedence over individual connection parameters
DSN string `mapstructure:"dsn"`
// Connection parameters (used if DSN is not provided)
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Database string `mapstructure:"database"`
// PostgreSQL/MSSQL specific
SSLMode string `mapstructure:"sslmode"` // disable, require, verify-ca, verify-full
Schema string `mapstructure:"schema"` // Default schema
// SQLite specific
FilePath string `mapstructure:"filepath"`
// MongoDB specific
AuthSource string `mapstructure:"auth_source"`
ReplicaSet string `mapstructure:"replica_set"`
ReadPreference string `mapstructure:"read_preference"` // primary, secondary, etc.
// Connection pool settings (overrides global defaults)
MaxOpenConns *int `mapstructure:"max_open_conns"`
MaxIdleConns *int `mapstructure:"max_idle_conns"`
ConnMaxLifetime *time.Duration `mapstructure:"conn_max_lifetime"`
ConnMaxIdleTime *time.Duration `mapstructure:"conn_max_idle_time"`
// Timeouts
ConnectTimeout time.Duration `mapstructure:"connect_timeout"`
QueryTimeout time.Duration `mapstructure:"query_timeout"`
// Features
EnableTracing bool `mapstructure:"enable_tracing"`
EnableMetrics bool `mapstructure:"enable_metrics"`
EnableLogging bool `mapstructure:"enable_logging"`
// DefaultORM specifies which ORM to use for the Database() method
// Options: "bun", "gorm", "native"
DefaultORM string `mapstructure:"default_orm"`
// Tags for organization and filtering
Tags map[string]string `mapstructure:"tags"`
}
// ToManagerConfig converts config.DBManagerConfig to dbmanager.ManagerConfig
// This is used to avoid circular dependencies
func (c *DBManagerConfig) ToManagerConfig() interface{} {
// This will be implemented in the dbmanager package
// to convert from config types to dbmanager types
return c
}
// Validate validates the DBManager configuration
func (c *DBManagerConfig) Validate() error {
if len(c.Connections) == 0 {
return fmt.Errorf("at least one connection must be configured")
}
if c.DefaultConnection != "" {
if _, ok := c.Connections[c.DefaultConnection]; !ok {
return fmt.Errorf("default connection '%s' not found in connections", c.DefaultConnection)
}
}
return nil
}

531
pkg/dbmanager/README.md Normal file
View File

@@ -0,0 +1,531 @@
# Database Connection Manager (dbmanager)
A comprehensive database connection manager for Go that provides centralized management of multiple named database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
## Features
- **Multiple Named Connections**: Manage multiple database connections with names like `primary`, `analytics`, `cache-db`
- **Multi-Database Support**: PostgreSQL, SQLite, Microsoft SQL Server, and MongoDB
- **Multi-ORM Access**: Each SQL connection provides access through:
- **Bun ORM** - Modern, lightweight ORM
- **GORM** - Popular Go ORM
- **Native** - Standard library `*sql.DB`
- All three share the same underlying connection pool
- **Configuration-Driven**: YAML configuration with Viper integration
- **Production-Ready Features**:
- Automatic health checks and reconnection
- Prometheus metrics
- Connection pooling with configurable limits
- Retry logic with exponential backoff
- Graceful shutdown
- OpenTelemetry tracing support
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/dbmanager
```
## Quick Start
### 1. Configuration
Create a configuration file (e.g., `config.yaml`):
```yaml
dbmanager:
default_connection: "primary"
# Global connection pool defaults
max_open_conns: 25
max_idle_conns: 5
conn_max_lifetime: 30m
conn_max_idle_time: 5m
# Retry configuration
retry_attempts: 3
retry_delay: 1s
retry_max_delay: 10s
# Health checks
health_check_interval: 30s
enable_auto_reconnect: true
connections:
# Primary PostgreSQL connection
primary:
type: postgres
host: localhost
port: 5432
user: myuser
password: mypassword
database: myapp
sslmode: disable
default_orm: bun
enable_metrics: true
enable_tracing: true
enable_logging: true
# Read replica for analytics
analytics:
type: postgres
dsn: "postgres://readonly:pass@analytics:5432/analytics"
default_orm: bun
enable_metrics: true
# SQLite cache
cache-db:
type: sqlite
filepath: /var/lib/app/cache.db
max_open_conns: 1
# MongoDB for documents
documents:
type: mongodb
host: localhost
port: 27017
database: documents
user: mongouser
password: mongopass
auth_source: admin
enable_metrics: true
```
### 2. Initialize Manager
```go
package main
import (
"context"
"log"
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
)
func main() {
// Load configuration
cfgMgr := config.NewManager()
if err := cfgMgr.Load(); err != nil {
log.Fatal(err)
}
cfg, _ := cfgMgr.GetConfig()
// Create database manager
mgr, err := dbmanager.NewManager(cfg.DBManager)
if err != nil {
log.Fatal(err)
}
defer mgr.Close()
// Connect all databases
ctx := context.Background()
if err := mgr.Connect(ctx); err != nil {
log.Fatal(err)
}
// Your application code here...
}
```
### 3. Use Database Connections
#### Get Default Database
```go
// Get the default database (as configured common.Database interface)
db, err := mgr.GetDefaultDatabase()
if err != nil {
log.Fatal(err)
}
// Use it with any query
var users []User
err = db.NewSelect().
Model(&users).
Where("active = ?", true).
Scan(ctx, &users)
```
#### Get Named Connection with Specific ORM
```go
// Get primary connection
primary, err := mgr.Get("primary")
if err != nil {
log.Fatal(err)
}
// Use with Bun
bunDB, err := primary.Bun()
if err != nil {
log.Fatal(err)
}
err = bunDB.NewSelect().Model(&users).Scan(ctx)
// Use with GORM (same underlying connection!)
gormDB, err := primary.GORM()
if err != nil {
log.Fatal(err)
}
gormDB.Where("active = ?", true).Find(&users)
// Use native *sql.DB
nativeDB, err := primary.Native()
if err != nil {
log.Fatal(err)
}
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
```
#### Use MongoDB
```go
// Get MongoDB connection
docs, err := mgr.Get("documents")
if err != nil {
log.Fatal(err)
}
mongoClient, err := docs.MongoDB()
if err != nil {
log.Fatal(err)
}
collection := mongoClient.Database("documents").Collection("articles")
// Use MongoDB driver...
```
#### Change Default Database
```go
// Switch to analytics database as default
err := mgr.SetDefaultDatabase("analytics")
if err != nil {
log.Fatal(err)
}
// Now GetDefaultDatabase() returns the analytics connection
db, _ := mgr.GetDefaultDatabase()
```
## Configuration Reference
### Manager Configuration
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `default_connection` | string | "" | Name of the default connection |
| `connections` | map | {} | Map of connection name to ConnectionConfig |
| `max_open_conns` | int | 25 | Global default for max open connections |
| `max_idle_conns` | int | 5 | Global default for max idle connections |
| `conn_max_lifetime` | duration | 30m | Global default for connection max lifetime |
| `conn_max_idle_time` | duration | 5m | Global default for connection max idle time |
| `retry_attempts` | int | 3 | Number of connection retry attempts |
| `retry_delay` | duration | 1s | Initial retry delay |
| `retry_max_delay` | duration | 10s | Maximum retry delay |
| `health_check_interval` | duration | 30s | Interval between health checks |
| `enable_auto_reconnect` | bool | true | Auto-reconnect on health check failure |
### Connection Configuration
| Field | Type | Description |
|-------|------|-------------|
| `name` | string | Unique connection name |
| `type` | string | Database type: `postgres`, `sqlite`, `mssql`, `mongodb` |
| `dsn` | string | Complete connection string (overrides individual params) |
| `host` | string | Database host |
| `port` | int | Database port |
| `user` | string | Username |
| `password` | string | Password |
| `database` | string | Database name |
| `sslmode` | string | SSL mode (postgres/mssql): `disable`, `require`, etc. |
| `schema` | string | Default schema (postgres/mssql) |
| `filepath` | string | File path (sqlite only) |
| `auth_source` | string | Auth source (mongodb) |
| `replica_set` | string | Replica set name (mongodb) |
| `read_preference` | string | Read preference (mongodb): `primary`, `secondary`, etc. |
| `max_open_conns` | int | Override global max open connections |
| `max_idle_conns` | int | Override global max idle connections |
| `conn_max_lifetime` | duration | Override global connection max lifetime |
| `conn_max_idle_time` | duration | Override global connection max idle time |
| `connect_timeout` | duration | Connection timeout (default: 10s) |
| `query_timeout` | duration | Query timeout (default: 30s) |
| `enable_tracing` | bool | Enable OpenTelemetry tracing |
| `enable_metrics` | bool | Enable Prometheus metrics |
| `enable_logging` | bool | Enable connection logging |
| `default_orm` | string | Default ORM for Database(): `bun`, `gorm`, `native` |
| `tags` | map[string]string | Custom tags for filtering/organization |
## Advanced Usage
### Health Checks
```go
// Manual health check
if err := mgr.HealthCheck(ctx); err != nil {
log.Printf("Health check failed: %v", err)
}
// Per-connection health check
primary, _ := mgr.Get("primary")
if err := primary.HealthCheck(ctx); err != nil {
log.Printf("Primary connection unhealthy: %v", err)
// Manual reconnect
if err := primary.Reconnect(ctx); err != nil {
log.Printf("Reconnection failed: %v", err)
}
}
```
### Connection Statistics
```go
// Get overall statistics
stats := mgr.Stats()
fmt.Printf("Total connections: %d\n", stats.TotalConnections)
fmt.Printf("Healthy: %d, Unhealthy: %d\n", stats.HealthyCount, stats.UnhealthyCount)
// Per-connection stats
for name, connStats := range stats.ConnectionStats {
fmt.Printf("%s: %d open, %d in use, %d idle\n",
name,
connStats.OpenConnections,
connStats.InUse,
connStats.Idle)
}
// Individual connection stats
primary, _ := mgr.Get("primary")
stats := primary.Stats()
fmt.Printf("Wait count: %d, Wait duration: %v\n",
stats.WaitCount,
stats.WaitDuration)
```
### Prometheus Metrics
The package automatically exports Prometheus metrics:
- `dbmanager_connections_total` - Total configured connections by type
- `dbmanager_connection_status` - Connection health status (1=healthy, 0=unhealthy)
- `dbmanager_connection_pool_size` - Connection pool statistics by state
- `dbmanager_connection_wait_count` - Times connections waited for availability
- `dbmanager_connection_wait_duration_seconds` - Total wait duration
- `dbmanager_health_check_duration_seconds` - Health check execution time
- `dbmanager_reconnect_attempts_total` - Reconnection attempts and results
- `dbmanager_connection_lifetime_closed_total` - Connections closed due to max lifetime
- `dbmanager_connection_idle_closed_total` - Connections closed due to max idle time
Metrics are automatically updated during health checks. To manually publish metrics:
```go
if mgr, ok := mgr.(*connectionManager); ok {
mgr.PublishMetrics()
}
```
## Architecture
### Single Connection Pool, Multiple ORMs
A key design principle is that Bun, GORM, and Native all wrap the **same underlying `*sql.DB`** connection pool:
```
┌─────────────────────────────────────┐
│ SQL Connection │
├─────────────────────────────────────┤
│ ┌─────────┐ ┌──────┐ ┌────────┐ │
│ │ Bun │ │ GORM │ │ Native │ │
│ └────┬────┘ └───┬──┘ └───┬────┘ │
│ │ │ │ │
│ └───────────┴─────────┘ │
│ *sql.DB │
│ (single pool) │
└─────────────────────────────────────┘
```
**Benefits:**
- No connection duplication
- Consistent pool limits across all ORMs
- Unified connection statistics
- Lower resource usage
### Provider Pattern
Each database type has a dedicated provider:
- **PostgresProvider** - Uses `pgx` driver
- **SQLiteProvider** - Uses `glebarez/sqlite` (pure Go)
- **MSSQLProvider** - Uses `go-mssqldb`
- **MongoProvider** - Uses official `mongo-driver`
Providers handle:
- Connection establishment with retry logic
- Health checking
- Connection statistics
- Connection cleanup
## Best Practices
1. **Use Named Connections**: Be explicit about which database you're accessing
```go
primary, _ := mgr.Get("primary") // Good
db, _ := mgr.GetDefaultDatabase() // Risky if default changes
```
2. **Configure Connection Pools**: Tune based on your workload
```yaml
connections:
primary:
max_open_conns: 100 # High traffic API
max_idle_conns: 25
analytics:
max_open_conns: 10 # Background analytics
max_idle_conns: 2
```
3. **Enable Health Checks**: Catch connection issues early
```yaml
health_check_interval: 30s
enable_auto_reconnect: true
```
4. **Use Appropriate ORM**: Choose based on your needs
- **Bun**: Modern, fast, type-safe - recommended for new code
- **GORM**: Mature, feature-rich - good for existing GORM code
- **Native**: Maximum control - use for performance-critical queries
5. **Monitor Metrics**: Watch connection pool utilization
- If `wait_count` is high, increase `max_open_conns`
- If `idle` is always high, decrease `max_idle_conns`
## Troubleshooting
### Connection Failures
If connections fail to establish:
1. Check configuration:
```bash
# Test connection manually
psql -h localhost -U myuser -d myapp
```
2. Enable logging:
```yaml
connections:
primary:
enable_logging: true
```
3. Check retry attempts:
```yaml
retry_attempts: 5 # Increase retries
retry_max_delay: 30s
```
### Pool Exhaustion
If you see "too many connections" errors:
1. Increase pool size:
```yaml
max_open_conns: 50 # Increase from default 25
```
2. Reduce connection lifetime:
```yaml
conn_max_lifetime: 15m # Recycle faster
```
3. Monitor wait stats:
```go
stats := primary.Stats()
if stats.WaitCount > 1000 {
log.Warn("High connection wait count")
}
```
### MongoDB vs SQL Confusion
MongoDB connections don't support SQL ORMs:
```go
docs, _ := mgr.Get("documents")
// ✓ Correct
mongoClient, _ := docs.MongoDB()
// ✗ Error: ErrNotSQLDatabase
bunDB, err := docs.Bun() // Won't work!
```
SQL connections don't support MongoDB:
```go
primary, _ := mgr.Get("primary")
// ✓ Correct
bunDB, _ := primary.Bun()
// ✗ Error: ErrNotMongoDB
mongoClient, err := primary.MongoDB() // Won't work!
```
## Migration Guide
### From Raw `database/sql`
Before:
```go
db, err := sql.Open("postgres", dsn)
defer db.Close()
rows, err := db.Query("SELECT * FROM users")
```
After:
```go
mgr, _ := dbmanager.NewManager(cfg.DBManager)
mgr.Connect(ctx)
defer mgr.Close()
primary, _ := mgr.Get("primary")
nativeDB, _ := primary.Native()
rows, err := nativeDB.Query("SELECT * FROM users")
```
### From Direct Bun/GORM
Before:
```go
sqldb, _ := sql.Open("pgx", dsn)
bunDB := bun.NewDB(sqldb, pgdialect.New())
var users []User
bunDB.NewSelect().Model(&users).Scan(ctx)
```
After:
```go
mgr, _ := dbmanager.NewManager(cfg.DBManager)
mgr.Connect(ctx)
primary, _ := mgr.Get("primary")
bunDB, _ := primary.Bun()
var users []User
bunDB.NewSelect().Model(&users).Scan(ctx)
```
## License
Same as the parent project.
## Contributing
Contributions are welcome! Please submit issues and pull requests to the main repository.

448
pkg/dbmanager/config.go Normal file
View File

@@ -0,0 +1,448 @@
package dbmanager
import (
"fmt"
"time"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// DatabaseType represents the type of database
type DatabaseType string
const (
// DatabaseTypePostgreSQL represents PostgreSQL database
DatabaseTypePostgreSQL DatabaseType = "postgres"
// DatabaseTypeSQLite represents SQLite database
DatabaseTypeSQLite DatabaseType = "sqlite"
// DatabaseTypeMSSQL represents Microsoft SQL Server database
DatabaseTypeMSSQL DatabaseType = "mssql"
// DatabaseTypeMongoDB represents MongoDB database
DatabaseTypeMongoDB DatabaseType = "mongodb"
)
// ORMType represents the ORM to use for database operations
type ORMType string
const (
// ORMTypeBun represents Bun ORM
ORMTypeBun ORMType = "bun"
// ORMTypeGORM represents GORM
ORMTypeGORM ORMType = "gorm"
// ORMTypeNative represents native database/sql
ORMTypeNative ORMType = "native"
)
// ManagerConfig contains configuration for the database connection manager
type ManagerConfig struct {
// DefaultConnection is the name of the default connection to use
DefaultConnection string `mapstructure:"default_connection"`
// Connections is a map of connection name to connection configuration
Connections map[string]ConnectionConfig `mapstructure:"connections"`
// Global connection pool defaults
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"`
// Retry policy
RetryAttempts int `mapstructure:"retry_attempts"`
RetryDelay time.Duration `mapstructure:"retry_delay"`
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
// Health checks
HealthCheckInterval time.Duration `mapstructure:"health_check_interval"`
EnableAutoReconnect bool `mapstructure:"enable_auto_reconnect"`
}
// ConnectionConfig defines configuration for a single database connection
type ConnectionConfig struct {
// Name is the unique name of this connection
Name string `mapstructure:"name"`
// Type is the database type (postgres, sqlite, mssql, mongodb)
Type DatabaseType `mapstructure:"type"`
// DSN is the complete Data Source Name / connection string
// If provided, this takes precedence over individual connection parameters
DSN string `mapstructure:"dsn"`
// Connection parameters (used if DSN is not provided)
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Database string `mapstructure:"database"`
// PostgreSQL/MSSQL specific
SSLMode string `mapstructure:"sslmode"` // disable, require, verify-ca, verify-full
Schema string `mapstructure:"schema"` // Default schema
// SQLite specific
FilePath string `mapstructure:"filepath"`
// MongoDB specific
AuthSource string `mapstructure:"auth_source"`
ReplicaSet string `mapstructure:"replica_set"`
ReadPreference string `mapstructure:"read_preference"` // primary, secondary, etc.
// Connection pool settings (overrides global defaults)
MaxOpenConns *int `mapstructure:"max_open_conns"`
MaxIdleConns *int `mapstructure:"max_idle_conns"`
ConnMaxLifetime *time.Duration `mapstructure:"conn_max_lifetime"`
ConnMaxIdleTime *time.Duration `mapstructure:"conn_max_idle_time"`
// Timeouts
ConnectTimeout time.Duration `mapstructure:"connect_timeout"`
QueryTimeout time.Duration `mapstructure:"query_timeout"`
// Features
EnableTracing bool `mapstructure:"enable_tracing"`
EnableMetrics bool `mapstructure:"enable_metrics"`
EnableLogging bool `mapstructure:"enable_logging"`
// DefaultORM specifies which ORM to use for the Database() method
// Options: "bun", "gorm", "native"
DefaultORM string `mapstructure:"default_orm"`
// Tags for organization and filtering
Tags map[string]string `mapstructure:"tags"`
}
// DefaultManagerConfig returns a ManagerConfig with sensible defaults
func DefaultManagerConfig() ManagerConfig {
return ManagerConfig{
DefaultConnection: "",
Connections: make(map[string]ConnectionConfig),
MaxOpenConns: 25,
MaxIdleConns: 5,
ConnMaxLifetime: 30 * time.Minute,
ConnMaxIdleTime: 5 * time.Minute,
RetryAttempts: 3,
RetryDelay: 1 * time.Second,
RetryMaxDelay: 10 * time.Second,
HealthCheckInterval: 30 * time.Second,
EnableAutoReconnect: true,
}
}
// ApplyDefaults applies default values to the manager configuration
func (c *ManagerConfig) ApplyDefaults() {
defaults := DefaultManagerConfig()
if c.MaxOpenConns == 0 {
c.MaxOpenConns = defaults.MaxOpenConns
}
if c.MaxIdleConns == 0 {
c.MaxIdleConns = defaults.MaxIdleConns
}
if c.ConnMaxLifetime == 0 {
c.ConnMaxLifetime = defaults.ConnMaxLifetime
}
if c.ConnMaxIdleTime == 0 {
c.ConnMaxIdleTime = defaults.ConnMaxIdleTime
}
if c.RetryAttempts == 0 {
c.RetryAttempts = defaults.RetryAttempts
}
if c.RetryDelay == 0 {
c.RetryDelay = defaults.RetryDelay
}
if c.RetryMaxDelay == 0 {
c.RetryMaxDelay = defaults.RetryMaxDelay
}
if c.HealthCheckInterval == 0 {
c.HealthCheckInterval = defaults.HealthCheckInterval
}
}
// Validate validates the manager configuration
func (c *ManagerConfig) Validate() error {
if len(c.Connections) == 0 {
return NewConfigurationError("connections", fmt.Errorf("at least one connection must be configured"))
}
if c.DefaultConnection != "" {
if _, ok := c.Connections[c.DefaultConnection]; !ok {
return NewConfigurationError("default_connection", fmt.Errorf("default connection '%s' not found in connections", c.DefaultConnection))
}
}
// Validate each connection
for name := range c.Connections {
conn := c.Connections[name]
if err := conn.Validate(); err != nil {
return fmt.Errorf("connection '%s': %w", name, err)
}
}
return nil
}
// ApplyDefaults applies default values and global settings to the connection configuration
func (cc *ConnectionConfig) ApplyDefaults(global *ManagerConfig) {
// Set name if not already set
if cc.Name == "" {
cc.Name = "unnamed"
}
// Apply global pool settings if not overridden
if cc.MaxOpenConns == nil && global != nil {
maxOpen := global.MaxOpenConns
cc.MaxOpenConns = &maxOpen
}
if cc.MaxIdleConns == nil && global != nil {
maxIdle := global.MaxIdleConns
cc.MaxIdleConns = &maxIdle
}
if cc.ConnMaxLifetime == nil && global != nil {
lifetime := global.ConnMaxLifetime
cc.ConnMaxLifetime = &lifetime
}
if cc.ConnMaxIdleTime == nil && global != nil {
idleTime := global.ConnMaxIdleTime
cc.ConnMaxIdleTime = &idleTime
}
// Default timeouts
if cc.ConnectTimeout == 0 {
cc.ConnectTimeout = 10 * time.Second
}
if cc.QueryTimeout == 0 {
cc.QueryTimeout = 30 * time.Second
}
// Default ORM
if cc.DefaultORM == "" {
cc.DefaultORM = string(ORMTypeBun)
}
// Default PostgreSQL port
if cc.Type == DatabaseTypePostgreSQL && cc.Port == 0 && cc.DSN == "" {
cc.Port = 5432
}
// Default MSSQL port
if cc.Type == DatabaseTypeMSSQL && cc.Port == 0 && cc.DSN == "" {
cc.Port = 1433
}
// Default MongoDB port
if cc.Type == DatabaseTypeMongoDB && cc.Port == 0 && cc.DSN == "" {
cc.Port = 27017
}
// Default MongoDB auth source
if cc.Type == DatabaseTypeMongoDB && cc.AuthSource == "" {
cc.AuthSource = "admin"
}
}
// Validate validates the connection configuration
func (cc *ConnectionConfig) Validate() error {
// Validate database type
switch cc.Type {
case DatabaseTypePostgreSQL, DatabaseTypeSQLite, DatabaseTypeMSSQL, DatabaseTypeMongoDB:
// Valid types
default:
return NewConfigurationError("type", fmt.Errorf("unsupported database type: %s", cc.Type))
}
// Validate that either DSN or connection parameters are provided
if cc.DSN == "" {
switch cc.Type {
case DatabaseTypePostgreSQL, DatabaseTypeMSSQL, DatabaseTypeMongoDB:
if cc.Host == "" {
return NewConfigurationError("host", fmt.Errorf("host is required when DSN is not provided"))
}
if cc.Database == "" {
return NewConfigurationError("database", fmt.Errorf("database is required when DSN is not provided"))
}
case DatabaseTypeSQLite:
if cc.FilePath == "" {
return NewConfigurationError("filepath", fmt.Errorf("filepath is required for SQLite when DSN is not provided"))
}
}
}
// Validate ORM type
if cc.DefaultORM != "" {
switch ORMType(cc.DefaultORM) {
case ORMTypeBun, ORMTypeGORM, ORMTypeNative:
// Valid ORM types
default:
return NewConfigurationError("default_orm", fmt.Errorf("unsupported ORM type: %s", cc.DefaultORM))
}
}
return nil
}
// BuildDSN builds a connection string from individual parameters
func (cc *ConnectionConfig) BuildDSN() (string, error) {
// If DSN is already provided, use it
if cc.DSN != "" {
return cc.DSN, nil
}
switch cc.Type {
case DatabaseTypePostgreSQL:
return cc.buildPostgresDSN(), nil
case DatabaseTypeSQLite:
return cc.buildSQLiteDSN(), nil
case DatabaseTypeMSSQL:
return cc.buildMSSQLDSN(), nil
case DatabaseTypeMongoDB:
return cc.buildMongoDSN(), nil
default:
return "", fmt.Errorf("cannot build DSN for database type: %s", cc.Type)
}
}
func (cc *ConnectionConfig) buildPostgresDSN() string {
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
cc.Host, cc.Port, cc.User, cc.Password, cc.Database)
if cc.SSLMode != "" {
dsn += fmt.Sprintf(" sslmode=%s", cc.SSLMode)
} else {
dsn += " sslmode=disable"
}
if cc.Schema != "" {
dsn += fmt.Sprintf(" search_path=%s", cc.Schema)
}
return dsn
}
func (cc *ConnectionConfig) buildSQLiteDSN() string {
if cc.FilePath != "" {
return cc.FilePath
}
return ":memory:"
}
func (cc *ConnectionConfig) buildMSSQLDSN() string {
// Format: sqlserver://username:password@host:port?database=dbname
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s",
cc.User, cc.Password, cc.Host, cc.Port, cc.Database)
if cc.Schema != "" {
dsn += fmt.Sprintf("&schema=%s", cc.Schema)
}
return dsn
}
func (cc *ConnectionConfig) buildMongoDSN() string {
// Format: mongodb://username:password@host:port/database?authSource=admin
var dsn string
if cc.User != "" && cc.Password != "" {
dsn = fmt.Sprintf("mongodb://%s:%s@%s:%d/%s",
cc.User, cc.Password, cc.Host, cc.Port, cc.Database)
} else {
dsn = fmt.Sprintf("mongodb://%s:%d/%s", cc.Host, cc.Port, cc.Database)
}
params := ""
if cc.AuthSource != "" {
params += fmt.Sprintf("authSource=%s", cc.AuthSource)
}
if cc.ReplicaSet != "" {
if params != "" {
params += "&"
}
params += fmt.Sprintf("replicaSet=%s", cc.ReplicaSet)
}
if cc.ReadPreference != "" {
if params != "" {
params += "&"
}
params += fmt.Sprintf("readPreference=%s", cc.ReadPreference)
}
if params != "" {
dsn += "?" + params
}
return dsn
}
// FromConfig converts config.DBManagerConfig to internal ManagerConfig
func FromConfig(cfg config.DBManagerConfig) ManagerConfig {
mgr := ManagerConfig{
DefaultConnection: cfg.DefaultConnection,
Connections: make(map[string]ConnectionConfig),
MaxOpenConns: cfg.MaxOpenConns,
MaxIdleConns: cfg.MaxIdleConns,
ConnMaxLifetime: cfg.ConnMaxLifetime,
ConnMaxIdleTime: cfg.ConnMaxIdleTime,
RetryAttempts: cfg.RetryAttempts,
RetryDelay: cfg.RetryDelay,
RetryMaxDelay: cfg.RetryMaxDelay,
HealthCheckInterval: cfg.HealthCheckInterval,
EnableAutoReconnect: cfg.EnableAutoReconnect,
}
// Convert connections
for name := range cfg.Connections {
connCfg := cfg.Connections[name]
mgr.Connections[name] = ConnectionConfig{
Name: connCfg.Name,
Type: DatabaseType(connCfg.Type),
DSN: connCfg.DSN,
Host: connCfg.Host,
Port: connCfg.Port,
User: connCfg.User,
Password: connCfg.Password,
Database: connCfg.Database,
SSLMode: connCfg.SSLMode,
Schema: connCfg.Schema,
FilePath: connCfg.FilePath,
AuthSource: connCfg.AuthSource,
ReplicaSet: connCfg.ReplicaSet,
ReadPreference: connCfg.ReadPreference,
MaxOpenConns: connCfg.MaxOpenConns,
MaxIdleConns: connCfg.MaxIdleConns,
ConnMaxLifetime: connCfg.ConnMaxLifetime,
ConnMaxIdleTime: connCfg.ConnMaxIdleTime,
ConnectTimeout: connCfg.ConnectTimeout,
QueryTimeout: connCfg.QueryTimeout,
EnableTracing: connCfg.EnableTracing,
EnableMetrics: connCfg.EnableMetrics,
EnableLogging: connCfg.EnableLogging,
DefaultORM: connCfg.DefaultORM,
Tags: connCfg.Tags,
}
}
return mgr
}
// Getter methods to implement providers.ConnectionConfig interface
func (cc *ConnectionConfig) GetName() string { return cc.Name }
func (cc *ConnectionConfig) GetType() string { return string(cc.Type) }
func (cc *ConnectionConfig) GetHost() string { return cc.Host }
func (cc *ConnectionConfig) GetPort() int { return cc.Port }
func (cc *ConnectionConfig) GetUser() string { return cc.User }
func (cc *ConnectionConfig) GetPassword() string { return cc.Password }
func (cc *ConnectionConfig) GetDatabase() string { return cc.Database }
func (cc *ConnectionConfig) GetFilePath() string { return cc.FilePath }
func (cc *ConnectionConfig) GetConnectTimeout() time.Duration { return cc.ConnectTimeout }
func (cc *ConnectionConfig) GetEnableLogging() bool { return cc.EnableLogging }
func (cc *ConnectionConfig) GetMaxOpenConns() *int { return cc.MaxOpenConns }
func (cc *ConnectionConfig) GetMaxIdleConns() *int { return cc.MaxIdleConns }
func (cc *ConnectionConfig) GetConnMaxLifetime() *time.Duration { return cc.ConnMaxLifetime }
func (cc *ConnectionConfig) GetConnMaxIdleTime() *time.Duration { return cc.ConnMaxIdleTime }
func (cc *ConnectionConfig) GetQueryTimeout() time.Duration { return cc.QueryTimeout }
func (cc *ConnectionConfig) GetEnableMetrics() bool { return cc.EnableMetrics }
func (cc *ConnectionConfig) GetReadPreference() string { return cc.ReadPreference }

607
pkg/dbmanager/connection.go Normal file
View File

@@ -0,0 +1,607 @@
package dbmanager
import (
"context"
"database/sql"
"sync"
"time"
"github.com/uptrace/bun"
"github.com/uptrace/bun/schema"
"go.mongodb.org/mongo-driver/mongo"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
)
// Connection represents a single named database connection
type Connection interface {
// Metadata
Name() string
Type() DatabaseType
// ORM Access (SQL databases only)
Bun() (*bun.DB, error)
GORM() (*gorm.DB, error)
Native() (*sql.DB, error)
// Common Database interface (for SQL databases)
Database() (common.Database, error)
// MongoDB Access (MongoDB only)
MongoDB() (*mongo.Client, error)
// Lifecycle
Connect(ctx context.Context) error
Close() error
HealthCheck(ctx context.Context) error
Reconnect(ctx context.Context) error
// Stats
Stats() *ConnectionStats
}
// ConnectionStats contains statistics about a database connection
type ConnectionStats struct {
Name string
Type DatabaseType
Connected bool
LastHealthCheck time.Time
HealthCheckStatus string
// SQL connection pool stats
OpenConnections int
InUse int
Idle int
WaitCount int64
WaitDuration time.Duration
MaxIdleClosed int64
MaxLifetimeClosed int64
}
// sqlConnection implements Connection for SQL databases (PostgreSQL, SQLite, MSSQL)
type sqlConnection struct {
name string
dbType DatabaseType
config ConnectionConfig
provider Provider
// Lazy-initialized ORM instances (all wrap the same sql.DB)
nativeDB *sql.DB
bunDB *bun.DB
gormDB *gorm.DB
// Adapters for common.Database interface
bunAdapter *database.BunAdapter
gormAdapter *database.GormAdapter
nativeAdapter common.Database
// State
connected bool
mu sync.RWMutex
// Health check
lastHealthCheck time.Time
healthCheckStatus string
}
// newSQLConnection creates a new SQL connection
func newSQLConnection(name string, dbType DatabaseType, config ConnectionConfig, provider Provider) *sqlConnection {
return &sqlConnection{
name: name,
dbType: dbType,
config: config,
provider: provider,
}
}
// Name returns the connection name
func (c *sqlConnection) Name() string {
return c.name
}
// Type returns the database type
func (c *sqlConnection) Type() DatabaseType {
return c.dbType
}
// Connect establishes the database connection
func (c *sqlConnection) Connect(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connected {
return ErrAlreadyConnected
}
if err := c.provider.Connect(ctx, &c.config); err != nil {
return NewConnectionError(c.name, "connect", err)
}
c.connected = true
return nil
}
// Close closes the database connection and all ORM instances
func (c *sqlConnection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return nil
}
// Close Bun if initialized
if c.bunDB != nil {
if err := c.bunDB.Close(); err != nil {
return NewConnectionError(c.name, "close bun", err)
}
}
// GORM doesn't have a separate close - it uses the underlying sql.DB
// Close the provider (which closes the underlying sql.DB)
if err := c.provider.Close(); err != nil {
return NewConnectionError(c.name, "close", err)
}
c.connected = false
c.nativeDB = nil
c.bunDB = nil
c.gormDB = nil
c.bunAdapter = nil
c.gormAdapter = nil
c.nativeAdapter = nil
return nil
}
// HealthCheck verifies the connection is alive
func (c *sqlConnection) HealthCheck(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
c.lastHealthCheck = time.Now()
if !c.connected {
c.healthCheckStatus = "disconnected"
return ErrConnectionClosed
}
if err := c.provider.HealthCheck(ctx); err != nil {
c.healthCheckStatus = "unhealthy: " + err.Error()
return NewConnectionError(c.name, "health check", err)
}
c.healthCheckStatus = "healthy"
return nil
}
// Reconnect closes and re-establishes the connection
func (c *sqlConnection) Reconnect(ctx context.Context) error {
if err := c.Close(); err != nil {
return err
}
return c.Connect(ctx)
}
// Native returns the native *sql.DB connection
func (c *sqlConnection) Native() (*sql.DB, error) {
c.mu.RLock()
if c.nativeDB != nil {
defer c.mu.RUnlock()
return c.nativeDB, nil
}
c.mu.RUnlock()
c.mu.Lock()
defer c.mu.Unlock()
// Double-check after acquiring write lock
if c.nativeDB != nil {
return c.nativeDB, nil
}
if !c.connected {
return nil, ErrConnectionClosed
}
// Get native connection from provider
db, err := c.provider.GetNative()
if err != nil {
return nil, NewConnectionError(c.name, "get native", err)
}
c.nativeDB = db
return c.nativeDB, nil
}
// Bun returns a Bun ORM instance wrapping the native connection
func (c *sqlConnection) Bun() (*bun.DB, error) {
c.mu.RLock()
if c.bunDB != nil {
defer c.mu.RUnlock()
return c.bunDB, nil
}
c.mu.RUnlock()
c.mu.Lock()
defer c.mu.Unlock()
// Double-check after acquiring write lock
if c.bunDB != nil {
return c.bunDB, nil
}
// Get native connection first
native, err := c.provider.GetNative()
if err != nil {
return nil, NewConnectionError(c.name, "get bun", err)
}
// Create Bun DB wrapping the same sql.DB
dialect := c.getBunDialect()
c.bunDB = bun.NewDB(native, dialect)
return c.bunDB, nil
}
// GORM returns a GORM instance wrapping the native connection
func (c *sqlConnection) GORM() (*gorm.DB, error) {
c.mu.RLock()
if c.gormDB != nil {
defer c.mu.RUnlock()
return c.gormDB, nil
}
c.mu.RUnlock()
c.mu.Lock()
defer c.mu.Unlock()
// Double-check after acquiring write lock
if c.gormDB != nil {
return c.gormDB, nil
}
// Get native connection first
native, err := c.provider.GetNative()
if err != nil {
return nil, NewConnectionError(c.name, "get gorm", err)
}
// Create GORM DB wrapping the same sql.DB
dialector := c.getGORMDialector(native)
db, err := gorm.Open(dialector, &gorm.Config{})
if err != nil {
return nil, NewConnectionError(c.name, "initialize gorm", err)
}
c.gormDB = db
return c.gormDB, nil
}
// Database returns the common.Database interface using the configured default ORM
func (c *sqlConnection) Database() (common.Database, error) {
c.mu.RLock()
defaultORM := c.config.DefaultORM
c.mu.RUnlock()
switch ORMType(defaultORM) {
case ORMTypeBun:
return c.getBunAdapter()
case ORMTypeGORM:
return c.getGORMAdapter()
case ORMTypeNative:
return c.getNativeAdapter()
default:
// Default to Bun
return c.getBunAdapter()
}
}
// MongoDB returns an error for SQL connections
func (c *sqlConnection) MongoDB() (*mongo.Client, error) {
return nil, ErrNotMongoDB
}
// Stats returns connection statistics
func (c *sqlConnection) Stats() *ConnectionStats {
c.mu.RLock()
defer c.mu.RUnlock()
stats := &ConnectionStats{
Name: c.name,
Type: c.dbType,
Connected: c.connected,
LastHealthCheck: c.lastHealthCheck,
HealthCheckStatus: c.healthCheckStatus,
}
// Get SQL stats if connected
if c.connected && c.provider != nil {
if providerStats := c.provider.Stats(); providerStats != nil {
stats.OpenConnections = providerStats.OpenConnections
stats.InUse = providerStats.InUse
stats.Idle = providerStats.Idle
stats.WaitCount = providerStats.WaitCount
stats.WaitDuration = providerStats.WaitDuration
stats.MaxIdleClosed = providerStats.MaxIdleClosed
stats.MaxLifetimeClosed = providerStats.MaxLifetimeClosed
}
}
return stats
}
// getBunAdapter returns or creates the Bun adapter
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
c.mu.RLock()
if c.bunAdapter != nil {
defer c.mu.RUnlock()
return c.bunAdapter, nil
}
c.mu.RUnlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.bunAdapter != nil {
return c.bunAdapter, nil
}
bunDB, err := c.Bun()
if err != nil {
return nil, err
}
c.bunAdapter = database.NewBunAdapter(bunDB)
return c.bunAdapter, nil
}
// getGORMAdapter returns or creates the GORM adapter
func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
c.mu.RLock()
if c.gormAdapter != nil {
defer c.mu.RUnlock()
return c.gormAdapter, nil
}
c.mu.RUnlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.gormAdapter != nil {
return c.gormAdapter, nil
}
gormDB, err := c.GORM()
if err != nil {
return nil, err
}
c.gormAdapter = database.NewGormAdapter(gormDB)
return c.gormAdapter, nil
}
// getNativeAdapter returns or creates the native adapter
func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
c.mu.RLock()
if c.nativeAdapter != nil {
defer c.mu.RUnlock()
return c.nativeAdapter, nil
}
c.mu.RUnlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.nativeAdapter != nil {
return c.nativeAdapter, nil
}
native, err := c.Native()
if err != nil {
return nil, err
}
// Create a native adapter based on database type
switch c.dbType {
case DatabaseTypePostgreSQL:
c.nativeAdapter = database.NewPgSQLAdapter(native)
case DatabaseTypeSQLite:
// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB
c.nativeAdapter = database.NewPgSQLAdapter(native)
case DatabaseTypeMSSQL:
// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB
c.nativeAdapter = database.NewPgSQLAdapter(native)
default:
return nil, ErrUnsupportedDatabase
}
return c.nativeAdapter, nil
}
// getBunDialect returns the appropriate Bun dialect for the database type
func (c *sqlConnection) getBunDialect() schema.Dialect {
switch c.dbType {
case DatabaseTypePostgreSQL:
return database.GetPostgresDialect()
case DatabaseTypeSQLite:
return database.GetSQLiteDialect()
case DatabaseTypeMSSQL:
return database.GetMSSQLDialect()
default:
// Default to PostgreSQL
return database.GetPostgresDialect()
}
}
// getGORMDialector returns the appropriate GORM dialector for the database type
func (c *sqlConnection) getGORMDialector(db *sql.DB) gorm.Dialector {
switch c.dbType {
case DatabaseTypePostgreSQL:
return database.GetPostgresDialector(db)
case DatabaseTypeSQLite:
return database.GetSQLiteDialector(db)
case DatabaseTypeMSSQL:
return database.GetMSSQLDialector(db)
default:
// Default to PostgreSQL
return database.GetPostgresDialector(db)
}
}
// mongoConnection implements Connection for MongoDB
type mongoConnection struct {
name string
config ConnectionConfig
provider Provider
// MongoDB client
client *mongo.Client
// State
connected bool
mu sync.RWMutex
// Health check
lastHealthCheck time.Time
healthCheckStatus string
}
// newMongoConnection creates a new MongoDB connection
func newMongoConnection(name string, config ConnectionConfig, provider Provider) *mongoConnection {
return &mongoConnection{
name: name,
config: config,
provider: provider,
}
}
// Name returns the connection name
func (c *mongoConnection) Name() string {
return c.name
}
// Type returns the database type (MongoDB)
func (c *mongoConnection) Type() DatabaseType {
return DatabaseTypeMongoDB
}
// Connect establishes the MongoDB connection
func (c *mongoConnection) Connect(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connected {
return ErrAlreadyConnected
}
if err := c.provider.Connect(ctx, &c.config); err != nil {
return NewConnectionError(c.name, "connect", err)
}
// Get the mongo client
client, err := c.provider.GetMongo()
if err != nil {
return NewConnectionError(c.name, "get mongo client", err)
}
c.client = client
c.connected = true
return nil
}
// Close closes the MongoDB connection
func (c *mongoConnection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return nil
}
if err := c.provider.Close(); err != nil {
return NewConnectionError(c.name, "close", err)
}
c.connected = false
c.client = nil
return nil
}
// HealthCheck verifies the MongoDB connection is alive
func (c *mongoConnection) HealthCheck(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
c.lastHealthCheck = time.Now()
if !c.connected {
c.healthCheckStatus = "disconnected"
return ErrConnectionClosed
}
if err := c.provider.HealthCheck(ctx); err != nil {
c.healthCheckStatus = "unhealthy: " + err.Error()
return NewConnectionError(c.name, "health check", err)
}
c.healthCheckStatus = "healthy"
return nil
}
// Reconnect closes and re-establishes the MongoDB connection
func (c *mongoConnection) Reconnect(ctx context.Context) error {
if err := c.Close(); err != nil {
return err
}
return c.Connect(ctx)
}
// MongoDB returns the MongoDB client
func (c *mongoConnection) MongoDB() (*mongo.Client, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if !c.connected || c.client == nil {
return nil, ErrConnectionClosed
}
return c.client, nil
}
// Bun returns an error for MongoDB connections
func (c *mongoConnection) Bun() (*bun.DB, error) {
return nil, ErrNotSQLDatabase
}
// GORM returns an error for MongoDB connections
func (c *mongoConnection) GORM() (*gorm.DB, error) {
return nil, ErrNotSQLDatabase
}
// Native returns an error for MongoDB connections
func (c *mongoConnection) Native() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// Database returns an error for MongoDB connections
func (c *mongoConnection) Database() (common.Database, error) {
return nil, ErrNotSQLDatabase
}
// Stats returns connection statistics for MongoDB
func (c *mongoConnection) Stats() *ConnectionStats {
c.mu.RLock()
defer c.mu.RUnlock()
return &ConnectionStats{
Name: c.name,
Type: DatabaseTypeMongoDB,
Connected: c.connected,
LastHealthCheck: c.lastHealthCheck,
HealthCheckStatus: c.healthCheckStatus,
}
}

82
pkg/dbmanager/errors.go Normal file
View File

@@ -0,0 +1,82 @@
package dbmanager
import (
"errors"
"fmt"
)
// Common errors
var (
// ErrConnectionNotFound is returned when a connection with the given name doesn't exist
ErrConnectionNotFound = errors.New("connection not found")
// ErrInvalidConfiguration is returned when the configuration is invalid
ErrInvalidConfiguration = errors.New("invalid configuration")
// ErrConnectionClosed is returned when attempting to use a closed connection
ErrConnectionClosed = errors.New("connection is closed")
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
ErrNotSQLDatabase = errors.New("not a SQL database")
// ErrNotMongoDB is returned when attempting MongoDB operations on a non-MongoDB connection
ErrNotMongoDB = errors.New("not a MongoDB connection")
// ErrUnsupportedDatabase is returned when the database type is not supported
ErrUnsupportedDatabase = errors.New("unsupported database type")
// ErrNoDefaultConnection is returned when no default connection is configured
ErrNoDefaultConnection = errors.New("no default connection configured")
// ErrAlreadyConnected is returned when attempting to connect an already connected connection
ErrAlreadyConnected = errors.New("already connected")
)
// ConnectionError wraps errors that occur during connection operations
type ConnectionError struct {
Name string
Operation string
Err error
}
func (e *ConnectionError) Error() string {
return fmt.Sprintf("connection '%s' %s: %v", e.Name, e.Operation, e.Err)
}
func (e *ConnectionError) Unwrap() error {
return e.Err
}
// NewConnectionError creates a new ConnectionError
func NewConnectionError(name, operation string, err error) *ConnectionError {
return &ConnectionError{
Name: name,
Operation: operation,
Err: err,
}
}
// ConfigurationError wraps configuration-related errors
type ConfigurationError struct {
Field string
Err error
}
func (e *ConfigurationError) Error() string {
if e.Field != "" {
return fmt.Sprintf("configuration error in field '%s': %v", e.Field, e.Err)
}
return fmt.Sprintf("configuration error: %v", e.Err)
}
func (e *ConfigurationError) Unwrap() error {
return e.Err
}
// NewConfigurationError creates a new ConfigurationError
func NewConfigurationError(field string, err error) *ConfigurationError {
return &ConfigurationError{
Field: field,
Err: err,
}
}

51
pkg/dbmanager/factory.go Normal file
View File

@@ -0,0 +1,51 @@
package dbmanager
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
)
// createConnection creates a database connection based on the configuration
func createConnection(cfg ConnectionConfig) (Connection, error) {
// Validate configuration
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid connection configuration: %w", err)
}
// Create provider based on database type
provider, err := createProvider(cfg.Type)
if err != nil {
return nil, err
}
// Create connection wrapper based on database type
switch cfg.Type {
case DatabaseTypePostgreSQL, DatabaseTypeSQLite, DatabaseTypeMSSQL:
return newSQLConnection(cfg.Name, cfg.Type, cfg, provider), nil
case DatabaseTypeMongoDB:
return newMongoConnection(cfg.Name, cfg, provider), nil
default:
return nil, fmt.Errorf("%w: %s", ErrUnsupportedDatabase, cfg.Type)
}
}
// createProvider creates a database provider based on the database type
func createProvider(dbType DatabaseType) (Provider, error) {
switch dbType {
case DatabaseTypePostgreSQL:
return providers.NewPostgresProvider(), nil
case DatabaseTypeSQLite:
return providers.NewSQLiteProvider(), nil
case DatabaseTypeMSSQL:
return providers.NewMSSQLProvider(), nil
case DatabaseTypeMongoDB:
return providers.NewMongoProvider(), nil
default:
return nil, fmt.Errorf("%w: %s", ErrUnsupportedDatabase, dbType)
}
}
// Provider is an alias to the providers.Provider interface
// This allows dbmanager package consumers to use Provider without importing providers
type Provider = providers.Provider

326
pkg/dbmanager/manager.go Normal file
View File

@@ -0,0 +1,326 @@
package dbmanager
import (
"context"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Manager manages multiple named database connections
type Manager interface {
// Connection retrieval
Get(name string) (Connection, error)
GetDefault() (Connection, error)
GetAll() map[string]Connection
// Default database management
GetDefaultDatabase() (common.Database, error)
SetDefaultDatabase(name string) error
// Lifecycle
Connect(ctx context.Context) error
Close() error
HealthCheck(ctx context.Context) error
// Stats
Stats() *ManagerStats
}
// ManagerStats contains statistics about the connection manager
type ManagerStats struct {
TotalConnections int
HealthyCount int
UnhealthyCount int
ConnectionStats map[string]*ConnectionStats
}
// connectionManager implements Manager
type connectionManager struct {
connections map[string]Connection
config ManagerConfig
mu sync.RWMutex
// Background health check
healthTicker *time.Ticker
stopChan chan struct{}
wg sync.WaitGroup
}
// NewManager creates a new database connection manager
func NewManager(cfg ManagerConfig) (Manager, error) {
// Apply defaults and validate configuration
cfg.ApplyDefaults()
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
mgr := &connectionManager{
connections: make(map[string]Connection),
config: cfg,
stopChan: make(chan struct{}),
}
return mgr, nil
}
// Get retrieves a named connection
func (m *connectionManager) Get(name string) (Connection, error) {
m.mu.RLock()
defer m.mu.RUnlock()
conn, ok := m.connections[name]
if !ok {
return nil, fmt.Errorf("%w: %s", ErrConnectionNotFound, name)
}
return conn, nil
}
// GetDefault retrieves the default connection
func (m *connectionManager) GetDefault() (Connection, error) {
m.mu.RLock()
defaultName := m.config.DefaultConnection
m.mu.RUnlock()
if defaultName == "" {
return nil, ErrNoDefaultConnection
}
return m.Get(defaultName)
}
// GetAll returns all connections
func (m *connectionManager) GetAll() map[string]Connection {
m.mu.RLock()
defer m.mu.RUnlock()
// Create a copy to avoid concurrent access issues
result := make(map[string]Connection, len(m.connections))
for name, conn := range m.connections {
result[name] = conn
}
return result
}
// GetDefaultDatabase returns the common.Database interface from the default connection
func (m *connectionManager) GetDefaultDatabase() (common.Database, error) {
conn, err := m.GetDefault()
if err != nil {
return nil, err
}
db, err := conn.Database()
if err != nil {
return nil, fmt.Errorf("failed to get database from default connection: %w", err)
}
return db, nil
}
// SetDefaultDatabase sets the default database connection by name
func (m *connectionManager) SetDefaultDatabase(name string) error {
m.mu.Lock()
defer m.mu.Unlock()
// Verify the connection exists
if _, ok := m.connections[name]; !ok {
return fmt.Errorf("%w: %s", ErrConnectionNotFound, name)
}
m.config.DefaultConnection = name
logger.Info("Default database connection changed: name=%s", name)
return nil
}
// Connect establishes all configured database connections
func (m *connectionManager) Connect(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
// Create connections from configuration
for name := range m.config.Connections {
// Get a copy of the connection config
connCfg := m.config.Connections[name]
// Apply global defaults to connection config
connCfg.ApplyDefaults(&m.config)
connCfg.Name = name
// Create connection using factory
conn, err := createConnection(connCfg)
if err != nil {
return fmt.Errorf("failed to create connection '%s': %w", name, err)
}
// Connect
if err := conn.Connect(ctx); err != nil {
return fmt.Errorf("failed to connect '%s': %w", name, err)
}
m.connections[name] = conn
logger.Info("Database connection established: name=%s, type=%s", name, connCfg.Type)
}
// Start background health checks if enabled
if m.config.EnableAutoReconnect && m.config.HealthCheckInterval > 0 {
m.startHealthChecker()
}
logger.Info("Database manager initialized: connections=%d", len(m.connections))
return nil
}
// Close closes all database connections
func (m *connectionManager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
// Stop health checker
m.stopHealthChecker()
// Close all connections
var errors []error
for name, conn := range m.connections {
if err := conn.Close(); err != nil {
errors = append(errors, fmt.Errorf("failed to close connection '%s': %w", name, err))
logger.Error("Failed to close connection", "name", name, "error", err)
} else {
logger.Info("Connection closed: name=%s", name)
}
}
m.connections = make(map[string]Connection)
if len(errors) > 0 {
return fmt.Errorf("errors closing connections: %v", errors)
}
logger.Info("Database manager closed")
return nil
}
// HealthCheck performs health checks on all connections
func (m *connectionManager) HealthCheck(ctx context.Context) error {
m.mu.RLock()
connections := make(map[string]Connection, len(m.connections))
for name, conn := range m.connections {
connections[name] = conn
}
m.mu.RUnlock()
var errors []error
for name, conn := range connections {
if err := conn.HealthCheck(ctx); err != nil {
errors = append(errors, fmt.Errorf("connection '%s': %w", name, err))
}
}
if len(errors) > 0 {
return fmt.Errorf("health check failed for %d connections: %v", len(errors), errors)
}
return nil
}
// Stats returns statistics for all connections
func (m *connectionManager) Stats() *ManagerStats {
m.mu.RLock()
defer m.mu.RUnlock()
stats := &ManagerStats{
TotalConnections: len(m.connections),
ConnectionStats: make(map[string]*ConnectionStats),
}
for name, conn := range m.connections {
connStats := conn.Stats()
stats.ConnectionStats[name] = connStats
if connStats.Connected && connStats.HealthCheckStatus == "healthy" {
stats.HealthyCount++
} else {
stats.UnhealthyCount++
}
}
return stats
}
// startHealthChecker starts background health checking
func (m *connectionManager) startHealthChecker() {
if m.healthTicker != nil {
return // Already running
}
m.healthTicker = time.NewTicker(m.config.HealthCheckInterval)
m.wg.Add(1)
go func() {
defer m.wg.Done()
logger.Info("Health checker started: interval=%v", m.config.HealthCheckInterval)
for {
select {
case <-m.healthTicker.C:
m.performHealthCheck()
case <-m.stopChan:
logger.Info("Health checker stopped")
return
}
}
}()
}
// stopHealthChecker stops background health checking
func (m *connectionManager) stopHealthChecker() {
if m.healthTicker != nil {
m.healthTicker.Stop()
close(m.stopChan)
m.wg.Wait()
m.healthTicker = nil
}
}
// performHealthCheck performs a health check on all connections
func (m *connectionManager) performHealthCheck() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
m.mu.RLock()
connections := make([]struct {
name string
conn Connection
}, 0, len(m.connections))
for name, conn := range m.connections {
connections = append(connections, struct {
name string
conn Connection
}{name, conn})
}
m.mu.RUnlock()
for _, item := range connections {
if err := item.conn.HealthCheck(ctx); err != nil {
logger.Warn("Health check failed",
"connection", item.name,
"error", err)
// Attempt reconnection if enabled
if m.config.EnableAutoReconnect {
logger.Info("Attempting reconnection: connection=%s", item.name)
if err := item.conn.Reconnect(ctx); err != nil {
logger.Error("Reconnection failed",
"connection", item.name,
"error", err)
} else {
logger.Info("Reconnection successful: connection=%s", item.name)
}
}
}
}
}

136
pkg/dbmanager/metrics.go Normal file
View File

@@ -0,0 +1,136 @@
package dbmanager
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
// connectionsTotal tracks the total number of configured database connections
connectionsTotal = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "dbmanager_connections_total",
Help: "Total number of configured database connections",
},
[]string{"type"},
)
// connectionStatus tracks connection health status (1=healthy, 0=unhealthy)
connectionStatus = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "dbmanager_connection_status",
Help: "Connection status (1=healthy, 0=unhealthy)",
},
[]string{"name", "type"},
)
// connectionPoolSize tracks connection pool sizes
connectionPoolSize = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "dbmanager_connection_pool_size",
Help: "Current connection pool size",
},
[]string{"name", "type", "state"}, // state: open, idle, in_use
)
// connectionWaitCount tracks how many times connections had to wait for availability
connectionWaitCount = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "dbmanager_connection_wait_count",
Help: "Number of times connections had to wait for availability",
},
[]string{"name", "type"},
)
// connectionWaitDuration tracks total time connections spent waiting
connectionWaitDuration = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "dbmanager_connection_wait_duration_seconds",
Help: "Total time connections spent waiting for availability",
},
[]string{"name", "type"},
)
// reconnectAttempts tracks reconnection attempts and their outcomes
reconnectAttempts = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "dbmanager_reconnect_attempts_total",
Help: "Total number of reconnection attempts",
},
[]string{"name", "type", "result"}, // result: success, failure
)
// connectionLifetimeClosed tracks connections closed due to max lifetime
connectionLifetimeClosed = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "dbmanager_connection_lifetime_closed_total",
Help: "Total connections closed due to exceeding max lifetime",
},
[]string{"name", "type"},
)
// connectionIdleClosed tracks connections closed due to max idle time
connectionIdleClosed = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "dbmanager_connection_idle_closed_total",
Help: "Total connections closed due to exceeding max idle time",
},
[]string{"name", "type"},
)
)
// PublishMetrics publishes current metrics for all connections
func (m *connectionManager) PublishMetrics() {
stats := m.Stats()
// Count connections by type
typeCount := make(map[DatabaseType]int)
for _, connStats := range stats.ConnectionStats {
typeCount[connStats.Type]++
}
// Update total connections gauge
for dbType, count := range typeCount {
connectionsTotal.WithLabelValues(string(dbType)).Set(float64(count))
}
// Update per-connection metrics
for name, connStats := range stats.ConnectionStats {
labels := prometheus.Labels{
"name": name,
"type": string(connStats.Type),
}
// Connection status
status := float64(0)
if connStats.Connected && connStats.HealthCheckStatus == "healthy" {
status = 1
}
connectionStatus.With(labels).Set(status)
// Pool size metrics (SQL databases only)
if connStats.Type != DatabaseTypeMongoDB {
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "open").Set(float64(connStats.OpenConnections))
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "idle").Set(float64(connStats.Idle))
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "in_use").Set(float64(connStats.InUse))
// Wait stats
connectionWaitCount.With(labels).Set(float64(connStats.WaitCount))
connectionWaitDuration.With(labels).Set(connStats.WaitDuration.Seconds())
// Lifetime/idle closed
connectionLifetimeClosed.With(labels).Set(float64(connStats.MaxLifetimeClosed))
connectionIdleClosed.With(labels).Set(float64(connStats.MaxIdleClosed))
}
}
}
// RecordReconnectAttempt records a reconnection attempt
func RecordReconnectAttempt(name string, dbType DatabaseType, success bool) {
result := "failure"
if success {
result = "success"
}
reconnectAttempts.WithLabelValues(name, string(dbType), result).Inc()
}

View File

@@ -0,0 +1,319 @@
# PostgreSQL NOTIFY/LISTEN Support
The `dbmanager` package provides built-in support for PostgreSQL's NOTIFY/LISTEN functionality through the `PostgresListener` type.
## Overview
PostgreSQL NOTIFY/LISTEN is a simple pub/sub mechanism that allows database clients to:
- **LISTEN** on named channels to receive notifications
- **NOTIFY** channels to send messages to all listeners
- Receive asynchronous notifications without polling
## Features
- ✅ Subscribe to multiple channels simultaneously
- ✅ Callback-based notification handling
- ✅ Automatic reconnection on connection loss
- ✅ Automatic resubscription after reconnection
- ✅ Thread-safe operations
- ✅ Panic recovery in notification handlers
- ✅ Dedicated connection for listening (doesn't interfere with queries)
## Usage
### Basic Usage
```go
package main
import (
"context"
"fmt"
"time"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
)
func main() {
// Create PostgreSQL provider
cfg := &providers.Config{
Name: "primary",
Type: "postgres",
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "password",
Database: "myapp",
}
provider := providers.NewPostgresProvider()
ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil {
panic(err)
}
defer provider.Close()
// Get listener
listener, err := provider.GetListener(ctx)
if err != nil {
panic(err)
}
// Subscribe to a channel
err = listener.Listen("events", func(channel, payload string) {
fmt.Printf("Received on %s: %s\n", channel, payload)
})
if err != nil {
panic(err)
}
// Send a notification
err = listener.Notify(ctx, "events", "Hello, World!")
if err != nil {
panic(err)
}
// Keep the program running
time.Sleep(1 * time.Second)
}
```
### Multiple Channels
```go
listener, _ := provider.GetListener(ctx)
// Listen to different channels with different handlers
listener.Listen("user_events", func(channel, payload string) {
fmt.Printf("User event: %s\n", payload)
})
listener.Listen("order_events", func(channel, payload string) {
fmt.Printf("Order event: %s\n", payload)
})
listener.Listen("payment_events", func(channel, payload string) {
fmt.Printf("Payment event: %s\n", payload)
})
```
### Unsubscribing
```go
// Stop listening to a specific channel
err := listener.Unlisten("user_events")
if err != nil {
fmt.Printf("Failed to unlisten: %v\n", err)
}
```
### Checking Active Channels
```go
// Get list of channels currently being listened to
channels := listener.Channels()
fmt.Printf("Listening to: %v\n", channels)
```
### Checking Connection Status
```go
if listener.IsConnected() {
fmt.Println("Listener is connected")
} else {
fmt.Println("Listener is disconnected")
}
```
## Integration with DBManager
When using the DBManager, you can access the listener through the PostgreSQL provider:
```go
// Initialize DBManager
mgr, err := dbmanager.NewManager(dbmanager.FromConfig(cfg.DBManager))
mgr.Connect(ctx)
defer mgr.Close()
// Get PostgreSQL connection
conn, err := mgr.Get("primary")
// Note: You'll need to cast to the underlying provider type
// This requires exposing the provider through the Connection interface
// or providing a helper method
```
## Use Cases
### Cache Invalidation
```go
listener.Listen("cache_invalidation", func(channel, payload string) {
// Parse the payload to determine what to invalidate
cache.Invalidate(payload)
})
```
### Real-time Updates
```go
listener.Listen("data_updates", func(channel, payload string) {
// Broadcast update to WebSocket clients
websocketBroadcast(payload)
})
```
### Configuration Reload
```go
listener.Listen("config_reload", func(channel, payload string) {
// Reload application configuration
config.Reload()
})
```
### Distributed Locking
```go
listener.Listen("lock_released", func(channel, payload string) {
// Attempt to acquire the lock
tryAcquireLock(payload)
})
```
## Automatic Reconnection
The listener automatically handles connection failures:
1. When a connection error is detected, the listener initiates reconnection
2. Once reconnected, it automatically resubscribes to all previous channels
3. Notification handlers remain active throughout the reconnection process
No manual intervention is required for reconnection.
## Error Handling
### Handler Panics
If a notification handler panics, the panic is recovered and logged. The listener continues to function normally:
```go
listener.Listen("events", func(channel, payload string) {
defer func() {
if r := recover(); r != nil {
log.Printf("Handler panic: %v", r)
}
}()
// Your event processing logic
processEvent(payload)
})
```
### Connection Errors
Connection errors trigger automatic reconnection. Check logs for reconnection events when `EnableLogging` is true.
## Thread Safety
All `PostgresListener` methods are thread-safe and can be called concurrently from multiple goroutines.
## Performance Considerations
1. **Dedicated Connection**: The listener uses a dedicated PostgreSQL connection separate from the query connection pool
2. **Asynchronous Handlers**: Notification handlers run in separate goroutines to avoid blocking
3. **Lightweight**: NOTIFY/LISTEN has minimal overhead compared to polling
## Comparison with Polling
| Feature | NOTIFY/LISTEN | Polling |
|---------|---------------|---------|
| Latency | Low (near real-time) | High (depends on poll interval) |
| Database Load | Minimal | High (constant queries) |
| Scalability | Excellent | Poor |
| Complexity | Simple | Moderate |
## Limitations
1. **PostgreSQL Only**: This feature is specific to PostgreSQL and not available for other databases
2. **No Message Persistence**: Notifications are not stored; if no listener is connected, the message is lost
3. **Payload Limit**: Notification payload is limited to 8000 bytes in PostgreSQL
4. **No Guaranteed Delivery**: If a listener disconnects, in-flight notifications may be lost
## Best Practices
1. **Keep Handlers Fast**: Notification handlers should be quick; for heavy processing, send work to a queue
2. **Use JSON Payloads**: Encode structured data as JSON for easy parsing
3. **Handle Errors Gracefully**: Always recover from panics in handlers
4. **Close Properly**: Always close the provider to ensure the listener is properly shut down
5. **Monitor Connection Status**: Use `IsConnected()` for health checks
## Example: Real-World Application
```go
// Subscribe to various application events
listener, _ := provider.GetListener(ctx)
// User registration events
listener.Listen("user_registered", func(channel, payload string) {
var event UserRegisteredEvent
json.Unmarshal([]byte(payload), &event)
// Send welcome email
sendWelcomeEmail(event.UserID)
// Invalidate user count cache
cache.Delete("user_count")
})
// Order placement events
listener.Listen("order_placed", func(channel, payload string) {
var event OrderPlacedEvent
json.Unmarshal([]byte(payload), &event)
// Notify warehouse system
warehouse.ProcessOrder(event.OrderID)
// Update inventory cache
cache.Invalidate("inventory:" + event.ProductID)
})
// Configuration changes
listener.Listen("config_updated", func(channel, payload string) {
// Reload configuration from database
appConfig.Reload()
})
```
## Triggering Notifications from SQL
You can trigger notifications directly from PostgreSQL triggers or functions:
```sql
-- Example trigger to notify on new user
CREATE OR REPLACE FUNCTION notify_user_registered()
RETURNS TRIGGER AS $$
BEGIN
PERFORM pg_notify('user_registered',
json_build_object(
'user_id', NEW.id,
'email', NEW.email,
'timestamp', NOW()
)::text
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER user_registered_trigger
AFTER INSERT ON users
FOR EACH ROW
EXECUTE FUNCTION notify_user_registered();
```
## Additional Resources
- [PostgreSQL NOTIFY Documentation](https://www.postgresql.org/docs/current/sql-notify.html)
- [PostgreSQL LISTEN Documentation](https://www.postgresql.org/docs/current/sql-listen.html)
- [pgx Driver Documentation](https://github.com/jackc/pgx)

View File

@@ -0,0 +1,214 @@
package providers
import (
"context"
"database/sql"
"fmt"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// MongoProvider implements Provider for MongoDB databases
type MongoProvider struct {
client *mongo.Client
config ConnectionConfig
}
// NewMongoProvider creates a new MongoDB provider
func NewMongoProvider() *MongoProvider {
return &MongoProvider{}
}
// Connect establishes a MongoDB connection
func (p *MongoProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
// Build DSN
dsn, err := cfg.BuildDSN()
if err != nil {
return fmt.Errorf("failed to build DSN: %w", err)
}
// Create client options
clientOpts := options.Client().ApplyURI(dsn)
// Set connection pool size
if cfg.GetMaxOpenConns() != nil {
maxPoolSize := uint64(*cfg.GetMaxOpenConns())
clientOpts.SetMaxPoolSize(maxPoolSize)
}
if cfg.GetMaxIdleConns() != nil {
minPoolSize := uint64(*cfg.GetMaxIdleConns())
clientOpts.SetMinPoolSize(minPoolSize)
}
// Set timeouts
clientOpts.SetConnectTimeout(cfg.GetConnectTimeout())
if cfg.GetQueryTimeout() > 0 {
clientOpts.SetTimeout(cfg.GetQueryTimeout())
}
// Set read preference if specified
if cfg.GetReadPreference() != "" {
rp, err := parseReadPreference(cfg.GetReadPreference())
if err != nil {
return fmt.Errorf("invalid read preference: %w", err)
}
clientOpts.SetReadPreference(rp)
}
// Connect with retry logic
var client *mongo.Client
var lastErr error
retryAttempts := 3
retryDelay := 1 * time.Second
for attempt := 0; attempt < retryAttempts; attempt++ {
if attempt > 0 {
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
if cfg.GetEnableLogging() {
logger.Info("Retrying MongoDB connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
}
select {
case <-time.After(delay):
case <-ctx.Done():
return ctx.Err()
}
}
// Create MongoDB client
client, err = mongo.Connect(ctx, clientOpts)
if err != nil {
lastErr = err
if cfg.GetEnableLogging() {
logger.Warn("Failed to connect to MongoDB", "error", err)
}
continue
}
// Ping the database to verify connection
pingCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
err = client.Ping(pingCtx, readpref.Primary())
cancel()
if err != nil {
lastErr = err
_ = client.Disconnect(ctx)
if cfg.GetEnableLogging() {
logger.Warn("Failed to ping MongoDB", "error", err)
}
continue
}
// Connection successful
break
}
if err != nil {
return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
}
p.client = client
p.config = cfg
if cfg.GetEnableLogging() {
logger.Info("MongoDB connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
}
return nil
}
// Close closes the MongoDB connection
func (p *MongoProvider) Close() error {
if p.client == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := p.client.Disconnect(ctx)
if err != nil {
return fmt.Errorf("failed to close MongoDB connection: %w", err)
}
if p.config.GetEnableLogging() {
logger.Info("MongoDB connection closed: name=%s", p.config.GetName())
}
p.client = nil
return nil
}
// HealthCheck verifies the MongoDB connection is alive
func (p *MongoProvider) HealthCheck(ctx context.Context) error {
if p.client == nil {
return fmt.Errorf("MongoDB client is nil")
}
// Use a short timeout for health checks
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := p.client.Ping(healthCtx, readpref.Primary()); err != nil {
return fmt.Errorf("health check failed: %w", err)
}
return nil
}
// GetNative returns an error for MongoDB (not a SQL database)
func (p *MongoProvider) GetNative() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// GetMongo returns the MongoDB client
func (p *MongoProvider) GetMongo() (*mongo.Client, error) {
if p.client == nil {
return nil, fmt.Errorf("MongoDB client is not initialized")
}
return p.client, nil
}
// Stats returns connection statistics for MongoDB
func (p *MongoProvider) Stats() *ConnectionStats {
if p.client == nil {
return &ConnectionStats{
Name: p.config.GetName(),
Type: "mongodb",
Connected: false,
}
}
// MongoDB doesn't expose detailed connection pool stats like sql.DB
// We return basic stats
return &ConnectionStats{
Name: p.config.GetName(),
Type: "mongodb",
Connected: true,
}
}
// parseReadPreference parses a read preference string into a readpref.ReadPref
func parseReadPreference(rp string) (*readpref.ReadPref, error) {
switch rp {
case "primary":
return readpref.Primary(), nil
case "primaryPreferred":
return readpref.PrimaryPreferred(), nil
case "secondary":
return readpref.Secondary(), nil
case "secondaryPreferred":
return readpref.SecondaryPreferred(), nil
case "nearest":
return readpref.Nearest(), nil
default:
return nil, fmt.Errorf("unknown read preference: %s", rp)
}
}

View File

@@ -0,0 +1,184 @@
package providers
import (
"context"
"database/sql"
"fmt"
"time"
_ "github.com/microsoft/go-mssqldb" // MSSQL driver
"go.mongodb.org/mongo-driver/mongo"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// MSSQLProvider implements Provider for Microsoft SQL Server databases
type MSSQLProvider struct {
db *sql.DB
config ConnectionConfig
}
// NewMSSQLProvider creates a new MSSQL provider
func NewMSSQLProvider() *MSSQLProvider {
return &MSSQLProvider{}
}
// Connect establishes a MSSQL connection
func (p *MSSQLProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
// Build DSN
dsn, err := cfg.BuildDSN()
if err != nil {
return fmt.Errorf("failed to build DSN: %w", err)
}
// Connect with retry logic
var db *sql.DB
var lastErr error
retryAttempts := 3 // Default retry attempts
retryDelay := 1 * time.Second
for attempt := 0; attempt < retryAttempts; attempt++ {
if attempt > 0 {
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
if cfg.GetEnableLogging() {
logger.Info("Retrying MSSQL connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
}
select {
case <-time.After(delay):
case <-ctx.Done():
return ctx.Err()
}
}
// Open database connection
db, err = sql.Open("sqlserver", dsn)
if err != nil {
lastErr = err
if cfg.GetEnableLogging() {
logger.Warn("Failed to open MSSQL connection", "error", err)
}
continue
}
// Test the connection with context timeout
connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
err = db.PingContext(connectCtx)
cancel()
if err != nil {
lastErr = err
db.Close()
if cfg.GetEnableLogging() {
logger.Warn("Failed to ping MSSQL database", "error", err)
}
continue
}
// Connection successful
break
}
if err != nil {
return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
}
// Configure connection pool
if cfg.GetMaxOpenConns() != nil {
db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
}
if cfg.GetMaxIdleConns() != nil {
db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
}
if cfg.GetConnMaxLifetime() != nil {
db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
}
if cfg.GetConnMaxIdleTime() != nil {
db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
}
p.db = db
p.config = cfg
if cfg.GetEnableLogging() {
logger.Info("MSSQL connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
}
return nil
}
// Close closes the MSSQL connection
func (p *MSSQLProvider) Close() error {
if p.db == nil {
return nil
}
err := p.db.Close()
if err != nil {
return fmt.Errorf("failed to close MSSQL connection: %w", err)
}
if p.config.GetEnableLogging() {
logger.Info("MSSQL connection closed: name=%s", p.config.GetName())
}
p.db = nil
return nil
}
// HealthCheck verifies the MSSQL connection is alive
func (p *MSSQLProvider) HealthCheck(ctx context.Context) error {
if p.db == nil {
return fmt.Errorf("database connection is nil")
}
// Use a short timeout for health checks
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := p.db.PingContext(healthCtx); err != nil {
return fmt.Errorf("health check failed: %w", err)
}
return nil
}
// GetNative returns the native *sql.DB connection
func (p *MSSQLProvider) GetNative() (*sql.DB, error) {
if p.db == nil {
return nil, fmt.Errorf("database connection is not initialized")
}
return p.db, nil
}
// GetMongo returns an error for MSSQL (not a MongoDB connection)
func (p *MSSQLProvider) GetMongo() (*mongo.Client, error) {
return nil, ErrNotMongoDB
}
// Stats returns connection pool statistics
func (p *MSSQLProvider) Stats() *ConnectionStats {
if p.db == nil {
return &ConnectionStats{
Name: p.config.GetName(),
Type: "mssql",
Connected: false,
}
}
stats := p.db.Stats()
return &ConnectionStats{
Name: p.config.GetName(),
Type: "mssql",
Connected: true,
OpenConnections: stats.OpenConnections,
InUse: stats.InUse,
Idle: stats.Idle,
WaitCount: stats.WaitCount,
WaitDuration: stats.WaitDuration,
MaxIdleClosed: stats.MaxIdleClosed,
MaxLifetimeClosed: stats.MaxLifetimeClosed,
}
}

View File

@@ -0,0 +1,231 @@
package providers
import (
"context"
"database/sql"
"fmt"
"math"
"sync"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"go.mongodb.org/mongo-driver/mongo"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// PostgresProvider implements Provider for PostgreSQL databases
type PostgresProvider struct {
db *sql.DB
config ConnectionConfig
listener *PostgresListener
mu sync.Mutex
}
// NewPostgresProvider creates a new PostgreSQL provider
func NewPostgresProvider() *PostgresProvider {
return &PostgresProvider{}
}
// Connect establishes a PostgreSQL connection
func (p *PostgresProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
// Build DSN
dsn, err := cfg.BuildDSN()
if err != nil {
return fmt.Errorf("failed to build DSN: %w", err)
}
// Connect with retry logic
var db *sql.DB
var lastErr error
retryAttempts := 3 // Default retry attempts
retryDelay := 1 * time.Second
for attempt := 0; attempt < retryAttempts; attempt++ {
if attempt > 0 {
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
if cfg.GetEnableLogging() {
logger.Info("Retrying PostgreSQL connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
}
select {
case <-time.After(delay):
case <-ctx.Done():
return ctx.Err()
}
}
// Open database connection
db, err = sql.Open("pgx", dsn)
if err != nil {
lastErr = err
if cfg.GetEnableLogging() {
logger.Warn("Failed to open PostgreSQL connection", "error", err)
}
continue
}
// Test the connection with context timeout
connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
err = db.PingContext(connectCtx)
cancel()
if err != nil {
lastErr = err
db.Close()
if cfg.GetEnableLogging() {
logger.Warn("Failed to ping PostgreSQL database", "error", err)
}
continue
}
// Connection successful
break
}
if err != nil {
return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
}
// Configure connection pool
if cfg.GetMaxOpenConns() != nil {
db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
}
if cfg.GetMaxIdleConns() != nil {
db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
}
if cfg.GetConnMaxLifetime() != nil {
db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
}
if cfg.GetConnMaxIdleTime() != nil {
db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
}
p.db = db
p.config = cfg
if cfg.GetEnableLogging() {
logger.Info("PostgreSQL connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
}
return nil
}
// Close closes the PostgreSQL connection
func (p *PostgresProvider) Close() error {
// Close listener if it exists
p.mu.Lock()
if p.listener != nil {
if err := p.listener.Close(); err != nil {
p.mu.Unlock()
return fmt.Errorf("failed to close listener: %w", err)
}
p.listener = nil
}
p.mu.Unlock()
if p.db == nil {
return nil
}
err := p.db.Close()
if err != nil {
return fmt.Errorf("failed to close PostgreSQL connection: %w", err)
}
if p.config.GetEnableLogging() {
logger.Info("PostgreSQL connection closed: name=%s", p.config.GetName())
}
p.db = nil
return nil
}
// HealthCheck verifies the PostgreSQL connection is alive
func (p *PostgresProvider) HealthCheck(ctx context.Context) error {
if p.db == nil {
return fmt.Errorf("database connection is nil")
}
// Use a short timeout for health checks
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := p.db.PingContext(healthCtx); err != nil {
return fmt.Errorf("health check failed: %w", err)
}
return nil
}
// GetNative returns the native *sql.DB connection
func (p *PostgresProvider) GetNative() (*sql.DB, error) {
if p.db == nil {
return nil, fmt.Errorf("database connection is not initialized")
}
return p.db, nil
}
// GetMongo returns an error for PostgreSQL (not a MongoDB connection)
func (p *PostgresProvider) GetMongo() (*mongo.Client, error) {
return nil, ErrNotMongoDB
}
// Stats returns connection pool statistics
func (p *PostgresProvider) Stats() *ConnectionStats {
if p.db == nil {
return &ConnectionStats{
Name: p.config.GetName(),
Type: "postgres",
Connected: false,
}
}
stats := p.db.Stats()
return &ConnectionStats{
Name: p.config.GetName(),
Type: "postgres",
Connected: true,
OpenConnections: stats.OpenConnections,
InUse: stats.InUse,
Idle: stats.Idle,
WaitCount: stats.WaitCount,
WaitDuration: stats.WaitDuration,
MaxIdleClosed: stats.MaxIdleClosed,
MaxLifetimeClosed: stats.MaxLifetimeClosed,
}
}
// GetListener returns a PostgreSQL listener for NOTIFY/LISTEN functionality
// The listener is lazily initialized on first call and reused for subsequent calls
func (p *PostgresProvider) GetListener(ctx context.Context) (*PostgresListener, error) {
p.mu.Lock()
defer p.mu.Unlock()
// Return existing listener if already created
if p.listener != nil {
return p.listener, nil
}
// Create new listener
listener := NewPostgresListener(p.config)
// Connect the listener
if err := listener.Connect(ctx); err != nil {
return nil, fmt.Errorf("failed to connect listener: %w", err)
}
p.listener = listener
return p.listener, nil
}
// calculateBackoff calculates exponential backoff delay
func calculateBackoff(attempt int, initial, maxDelay time.Duration) time.Duration {
delay := initial * time.Duration(math.Pow(2, float64(attempt)))
if delay > maxDelay {
delay = maxDelay
}
return delay
}

View File

@@ -0,0 +1,401 @@
package providers
import (
"context"
"fmt"
"sync"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// NotificationHandler is called when a notification is received
type NotificationHandler func(channel string, payload string)
// PostgresListener manages PostgreSQL LISTEN/NOTIFY functionality
type PostgresListener struct {
config ConnectionConfig
conn *pgx.Conn
// Channel subscriptions
channels map[string]NotificationHandler
mu sync.RWMutex
// Lifecycle management
ctx context.Context
cancel context.CancelFunc
closed bool
closeMu sync.Mutex
reconnectC chan struct{}
}
// NewPostgresListener creates a new PostgreSQL listener
func NewPostgresListener(cfg ConnectionConfig) *PostgresListener {
ctx, cancel := context.WithCancel(context.Background())
return &PostgresListener{
config: cfg,
channels: make(map[string]NotificationHandler),
ctx: ctx,
cancel: cancel,
reconnectC: make(chan struct{}, 1),
}
}
// Connect establishes a dedicated connection for listening
func (l *PostgresListener) Connect(ctx context.Context) error {
dsn, err := l.config.BuildDSN()
if err != nil {
return fmt.Errorf("failed to build DSN: %w", err)
}
// Parse connection config
connConfig, err := pgx.ParseConfig(dsn)
if err != nil {
return fmt.Errorf("failed to parse connection config: %w", err)
}
// Connect with retry logic
var conn *pgx.Conn
var lastErr error
retryAttempts := 3
retryDelay := 1 * time.Second
for attempt := 0; attempt < retryAttempts; attempt++ {
if attempt > 0 {
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
if l.config.GetEnableLogging() {
logger.Info("Retrying PostgreSQL listener connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
}
select {
case <-time.After(delay):
case <-ctx.Done():
return ctx.Err()
}
}
conn, err = pgx.ConnectConfig(ctx, connConfig)
if err != nil {
lastErr = err
if l.config.GetEnableLogging() {
logger.Warn("Failed to connect PostgreSQL listener", "error", err)
}
continue
}
// Test the connection
if err = conn.Ping(ctx); err != nil {
lastErr = err
conn.Close(ctx)
if l.config.GetEnableLogging() {
logger.Warn("Failed to ping PostgreSQL listener", "error", err)
}
continue
}
// Connection successful
break
}
if err != nil {
return fmt.Errorf("failed to connect listener after %d attempts: %w", retryAttempts, lastErr)
}
l.mu.Lock()
l.conn = conn
l.mu.Unlock()
// Start notification handler
go l.handleNotifications()
// Start reconnection handler
go l.handleReconnection()
if l.config.GetEnableLogging() {
logger.Info("PostgreSQL listener connected: name=%s", l.config.GetName())
}
return nil
}
// Listen subscribes to a PostgreSQL notification channel
func (l *PostgresListener) Listen(channel string, handler NotificationHandler) error {
l.closeMu.Lock()
if l.closed {
l.closeMu.Unlock()
return fmt.Errorf("listener is closed")
}
l.closeMu.Unlock()
l.mu.Lock()
defer l.mu.Unlock()
if l.conn == nil {
return fmt.Errorf("listener connection is not initialized")
}
// Execute LISTEN command
_, err := l.conn.Exec(l.ctx, fmt.Sprintf("LISTEN %s", pgx.Identifier{channel}.Sanitize()))
if err != nil {
return fmt.Errorf("failed to listen on channel %s: %w", channel, err)
}
// Store the handler
l.channels[channel] = handler
if l.config.GetEnableLogging() {
logger.Info("Listening on channel: name=%s, channel=%s", l.config.GetName(), channel)
}
return nil
}
// Unlisten unsubscribes from a PostgreSQL notification channel
func (l *PostgresListener) Unlisten(channel string) error {
l.closeMu.Lock()
if l.closed {
l.closeMu.Unlock()
return fmt.Errorf("listener is closed")
}
l.closeMu.Unlock()
l.mu.Lock()
defer l.mu.Unlock()
if l.conn == nil {
return fmt.Errorf("listener connection is not initialized")
}
// Execute UNLISTEN command
_, err := l.conn.Exec(l.ctx, fmt.Sprintf("UNLISTEN %s", pgx.Identifier{channel}.Sanitize()))
if err != nil {
return fmt.Errorf("failed to unlisten from channel %s: %w", channel, err)
}
// Remove the handler
delete(l.channels, channel)
if l.config.GetEnableLogging() {
logger.Info("Unlistened from channel: name=%s, channel=%s", l.config.GetName(), channel)
}
return nil
}
// Notify sends a notification to a PostgreSQL channel
func (l *PostgresListener) Notify(ctx context.Context, channel string, payload string) error {
l.closeMu.Lock()
if l.closed {
l.closeMu.Unlock()
return fmt.Errorf("listener is closed")
}
l.closeMu.Unlock()
l.mu.RLock()
conn := l.conn
l.mu.RUnlock()
if conn == nil {
return fmt.Errorf("listener connection is not initialized")
}
// Execute NOTIFY command
_, err := conn.Exec(ctx, "SELECT pg_notify($1, $2)", channel, payload)
if err != nil {
return fmt.Errorf("failed to notify channel %s: %w", channel, err)
}
return nil
}
// Close closes the listener and all subscriptions
func (l *PostgresListener) Close() error {
l.closeMu.Lock()
if l.closed {
l.closeMu.Unlock()
return nil
}
l.closed = true
l.closeMu.Unlock()
// Cancel context to stop background goroutines
l.cancel()
l.mu.Lock()
defer l.mu.Unlock()
if l.conn == nil {
return nil
}
// Unlisten from all channels
for channel := range l.channels {
_, _ = l.conn.Exec(context.Background(), fmt.Sprintf("UNLISTEN %s", pgx.Identifier{channel}.Sanitize()))
}
// Close connection
err := l.conn.Close(context.Background())
if err != nil {
return fmt.Errorf("failed to close listener connection: %w", err)
}
l.conn = nil
l.channels = make(map[string]NotificationHandler)
if l.config.GetEnableLogging() {
logger.Info("PostgreSQL listener closed: name=%s", l.config.GetName())
}
return nil
}
// handleNotifications processes incoming notifications
func (l *PostgresListener) handleNotifications() {
for {
select {
case <-l.ctx.Done():
return
default:
}
l.mu.RLock()
conn := l.conn
l.mu.RUnlock()
if conn == nil {
// Connection not available, wait for reconnection
time.Sleep(100 * time.Millisecond)
continue
}
// Wait for notification with timeout
ctx, cancel := context.WithTimeout(l.ctx, 5*time.Second)
notification, err := conn.WaitForNotification(ctx)
cancel()
if err != nil {
// Check if context was cancelled
if l.ctx.Err() != nil {
return
}
// Check if it's a connection error
if pgconn.Timeout(err) {
// Timeout is normal, continue waiting
continue
}
// Connection error, trigger reconnection
if l.config.GetEnableLogging() {
logger.Warn("Notification error, triggering reconnection", "error", err)
}
select {
case l.reconnectC <- struct{}{}:
default:
}
time.Sleep(1 * time.Second)
continue
}
// Process notification
l.mu.RLock()
handler, exists := l.channels[notification.Channel]
l.mu.RUnlock()
if exists && handler != nil {
// Call handler in a goroutine to avoid blocking
go func(ch, payload string) {
defer func() {
if r := recover(); r != nil {
if l.config.GetEnableLogging() {
logger.Error("Notification handler panic: channel=%s, error=%v", ch, r)
}
}
}()
handler(ch, payload)
}(notification.Channel, notification.Payload)
}
}
}
// handleReconnection manages automatic reconnection
func (l *PostgresListener) handleReconnection() {
for {
select {
case <-l.ctx.Done():
return
case <-l.reconnectC:
if l.config.GetEnableLogging() {
logger.Info("Attempting to reconnect listener: name=%s", l.config.GetName())
}
// Close existing connection
l.mu.Lock()
if l.conn != nil {
l.conn.Close(context.Background())
l.conn = nil
}
// Save current subscriptions
channels := make(map[string]NotificationHandler)
for ch, handler := range l.channels {
channels[ch] = handler
}
l.mu.Unlock()
// Attempt reconnection
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
err := l.Connect(ctx)
cancel()
if err != nil {
if l.config.GetEnableLogging() {
logger.Error("Failed to reconnect listener: name=%s, error=%v", l.config.GetName(), err)
}
// Retry after delay
time.Sleep(5 * time.Second)
select {
case l.reconnectC <- struct{}{}:
default:
}
continue
}
// Resubscribe to all channels
for channel, handler := range channels {
if err := l.Listen(channel, handler); err != nil {
if l.config.GetEnableLogging() {
logger.Error("Failed to resubscribe to channel: name=%s, channel=%s, error=%v", l.config.GetName(), channel, err)
}
}
}
if l.config.GetEnableLogging() {
logger.Info("Listener reconnected successfully: name=%s", l.config.GetName())
}
}
}
}
// IsConnected returns true if the listener is connected
func (l *PostgresListener) IsConnected() bool {
l.mu.RLock()
defer l.mu.RUnlock()
return l.conn != nil
}
// Channels returns the list of channels currently being listened to
func (l *PostgresListener) Channels() []string {
l.mu.RLock()
defer l.mu.RUnlock()
channels := make([]string, 0, len(l.channels))
for ch := range l.channels {
channels = append(channels, ch)
}
return channels
}

View File

@@ -0,0 +1,228 @@
package providers_test
import (
"context"
"fmt"
"time"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
)
// ExamplePostgresListener_basic demonstrates basic LISTEN/NOTIFY usage
func ExamplePostgresListener_basic() {
// Create a connection config
cfg := &dbmanager.ConnectionConfig{
Name: "example",
Type: dbmanager.DatabaseTypePostgreSQL,
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "password",
Database: "testdb",
ConnectTimeout: 10 * time.Second,
EnableLogging: true,
}
// Create and connect PostgreSQL provider
provider := providers.NewPostgresProvider()
ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err))
}
defer provider.Close()
// Get listener
listener, err := provider.GetListener(ctx)
if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err))
}
// Subscribe to a channel with a handler
err = listener.Listen("user_events", func(channel, payload string) {
fmt.Printf("Received notification on %s: %s\n", channel, payload)
})
if err != nil {
panic(fmt.Sprintf("Failed to listen: %v", err))
}
// Send a notification
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
if err != nil {
panic(fmt.Sprintf("Failed to notify: %v", err))
}
// Wait for notification to be processed
time.Sleep(100 * time.Millisecond)
// Unsubscribe from the channel
if err := listener.Unlisten("user_events"); err != nil {
panic(fmt.Sprintf("Failed to unlisten: %v", err))
}
}
// ExamplePostgresListener_multipleChannels demonstrates listening to multiple channels
func ExamplePostgresListener_multipleChannels() {
cfg := &dbmanager.ConnectionConfig{
Name: "example",
Type: dbmanager.DatabaseTypePostgreSQL,
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "password",
Database: "testdb",
ConnectTimeout: 10 * time.Second,
EnableLogging: false,
}
provider := providers.NewPostgresProvider()
ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err))
}
defer provider.Close()
listener, err := provider.GetListener(ctx)
if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err))
}
// Listen to multiple channels
channels := []string{"orders", "payments", "notifications"}
for _, ch := range channels {
channel := ch // Capture for closure
err := listener.Listen(channel, func(ch, payload string) {
fmt.Printf("[%s] %s\n", ch, payload)
})
if err != nil {
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
}
}
// Send notifications to different channels
listener.Notify(ctx, "orders", "New order #12345")
listener.Notify(ctx, "payments", "Payment received $99.99")
listener.Notify(ctx, "notifications", "Welcome email sent")
// Wait for notifications
time.Sleep(200 * time.Millisecond)
// Check active channels
activeChannels := listener.Channels()
fmt.Printf("Listening to %d channels: %v\n", len(activeChannels), activeChannels)
}
// ExamplePostgresListener_withDBManager demonstrates usage with DBManager
func ExamplePostgresListener_withDBManager() {
// This example shows how to use the listener with the full DBManager
// Assume we have a DBManager instance and get a connection
// conn, _ := dbMgr.Get("primary")
// Get the underlying provider (this would need to be exposed via the Connection interface)
// For now, this is a conceptual example
ctx := context.Background()
// Create provider directly for demonstration
cfg := &dbmanager.ConnectionConfig{
Name: "primary",
Type: dbmanager.DatabaseTypePostgreSQL,
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "password",
Database: "myapp",
ConnectTimeout: 10 * time.Second,
}
provider := providers.NewPostgresProvider()
if err := provider.Connect(ctx, cfg); err != nil {
panic(err)
}
defer provider.Close()
// Get listener
listener, err := provider.GetListener(ctx)
if err != nil {
panic(err)
}
// Subscribe to application events
listener.Listen("cache_invalidation", func(channel, payload string) {
fmt.Printf("Cache invalidation request: %s\n", payload)
// Handle cache invalidation logic here
})
listener.Listen("config_reload", func(channel, payload string) {
fmt.Printf("Configuration reload request: %s\n", payload)
// Handle configuration reload logic here
})
// Simulate receiving notifications
listener.Notify(ctx, "cache_invalidation", "user:123")
listener.Notify(ctx, "config_reload", "database")
time.Sleep(100 * time.Millisecond)
}
// ExamplePostgresListener_errorHandling demonstrates error handling and reconnection
func ExamplePostgresListener_errorHandling() {
cfg := &dbmanager.ConnectionConfig{
Name: "example",
Type: dbmanager.DatabaseTypePostgreSQL,
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "password",
Database: "testdb",
ConnectTimeout: 10 * time.Second,
EnableLogging: true,
}
provider := providers.NewPostgresProvider()
ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err))
}
defer provider.Close()
listener, err := provider.GetListener(ctx)
if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err))
}
// The listener automatically reconnects if the connection is lost
// Subscribe with error handling in the callback
err = listener.Listen("critical_events", func(channel, payload string) {
defer func() {
if r := recover(); r != nil {
fmt.Printf("Handler panic recovered: %v\n", r)
}
}()
// Process the event
fmt.Printf("Processing critical event: %s\n", payload)
// If processing fails, the panic will be caught by the defer above
// The listener will continue to function normally
})
if err != nil {
fmt.Printf("Failed to listen: %v\n", err)
return
}
// Check if listener is connected
if listener.IsConnected() {
fmt.Println("Listener is connected and ready")
}
// Send a notification
listener.Notify(ctx, "critical_events", "system_alert")
time.Sleep(100 * time.Millisecond)
}

View File

@@ -0,0 +1,83 @@
package providers
import (
"context"
"database/sql"
"errors"
"time"
"go.mongodb.org/mongo-driver/mongo"
)
// Common errors
var (
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
ErrNotSQLDatabase = errors.New("not a SQL database")
// ErrNotMongoDB is returned when attempting MongoDB operations on a non-MongoDB connection
ErrNotMongoDB = errors.New("not a MongoDB connection")
)
// ConnectionStats contains statistics about a database connection
type ConnectionStats struct {
Name string
Type string // Database type as string to avoid circular dependency
Connected bool
LastHealthCheck time.Time
HealthCheckStatus string
// SQL connection pool stats
OpenConnections int
InUse int
Idle int
WaitCount int64
WaitDuration time.Duration
MaxIdleClosed int64
MaxLifetimeClosed int64
}
// ConnectionConfig is a minimal interface for configuration
// The actual implementation is in dbmanager package
type ConnectionConfig interface {
BuildDSN() (string, error)
GetName() string
GetType() string
GetHost() string
GetPort() int
GetUser() string
GetPassword() string
GetDatabase() string
GetFilePath() string
GetConnectTimeout() time.Duration
GetQueryTimeout() time.Duration
GetEnableLogging() bool
GetEnableMetrics() bool
GetMaxOpenConns() *int
GetMaxIdleConns() *int
GetConnMaxLifetime() *time.Duration
GetConnMaxIdleTime() *time.Duration
GetReadPreference() string
}
// Provider creates and manages the underlying database connection
type Provider interface {
// Connect establishes the database connection
Connect(ctx context.Context, cfg ConnectionConfig) error
// Close closes the connection
Close() error
// HealthCheck verifies the connection is alive
HealthCheck(ctx context.Context) error
// GetNative returns the native *sql.DB (SQL databases only)
// Returns an error for non-SQL databases
GetNative() (*sql.DB, error)
// GetMongo returns the MongoDB client (MongoDB only)
// Returns an error for non-MongoDB databases
GetMongo() (*mongo.Client, error)
// Stats returns connection statistics
Stats() *ConnectionStats
}

View File

@@ -0,0 +1,177 @@
package providers
import (
"context"
"database/sql"
"fmt"
"time"
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
"go.mongodb.org/mongo-driver/mongo"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// SQLiteProvider implements Provider for SQLite databases
type SQLiteProvider struct {
db *sql.DB
config ConnectionConfig
}
// NewSQLiteProvider creates a new SQLite provider
func NewSQLiteProvider() *SQLiteProvider {
return &SQLiteProvider{}
}
// Connect establishes a SQLite connection
func (p *SQLiteProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
// Build DSN
dsn, err := cfg.BuildDSN()
if err != nil {
return fmt.Errorf("failed to build DSN: %w", err)
}
// Open database connection
db, err := sql.Open("sqlite", dsn)
if err != nil {
return fmt.Errorf("failed to open SQLite connection: %w", err)
}
// Test the connection with context timeout
connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
err = db.PingContext(connectCtx)
cancel()
if err != nil {
db.Close()
return fmt.Errorf("failed to ping SQLite database: %w", err)
}
// Configure connection pool
// Note: SQLite works best with MaxOpenConns=1 for write operations
// but can handle multiple readers
if cfg.GetMaxOpenConns() != nil {
db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
} else {
// Default to 1 for SQLite to avoid "database is locked" errors
db.SetMaxOpenConns(1)
}
if cfg.GetMaxIdleConns() != nil {
db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
}
if cfg.GetConnMaxLifetime() != nil {
db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
}
if cfg.GetConnMaxIdleTime() != nil {
db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
}
// Enable WAL mode for better concurrent access
_, err = db.ExecContext(ctx, "PRAGMA journal_mode=WAL")
if err != nil {
if cfg.GetEnableLogging() {
logger.Warn("Failed to enable WAL mode for SQLite", "error", err)
}
// Don't fail connection if WAL mode cannot be enabled
}
// Set busy timeout to handle locked database
_, err = db.ExecContext(ctx, "PRAGMA busy_timeout=5000")
if err != nil {
if cfg.GetEnableLogging() {
logger.Warn("Failed to set busy timeout for SQLite", "error", err)
}
}
p.db = db
p.config = cfg
if cfg.GetEnableLogging() {
logger.Info("SQLite connection established: name=%s, filepath=%s", cfg.GetName(), cfg.GetFilePath())
}
return nil
}
// Close closes the SQLite connection
func (p *SQLiteProvider) Close() error {
if p.db == nil {
return nil
}
err := p.db.Close()
if err != nil {
return fmt.Errorf("failed to close SQLite connection: %w", err)
}
if p.config.GetEnableLogging() {
logger.Info("SQLite connection closed: name=%s", p.config.GetName())
}
p.db = nil
return nil
}
// HealthCheck verifies the SQLite connection is alive
func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
if p.db == nil {
return fmt.Errorf("database connection is nil")
}
// Use a short timeout for health checks
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Execute a simple query to verify the database is accessible
var result int
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
if err != nil {
return fmt.Errorf("health check failed: %w", err)
}
if result != 1 {
return fmt.Errorf("health check returned unexpected result: %d", result)
}
return nil
}
// GetNative returns the native *sql.DB connection
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
if p.db == nil {
return nil, fmt.Errorf("database connection is not initialized")
}
return p.db, nil
}
// GetMongo returns an error for SQLite (not a MongoDB connection)
func (p *SQLiteProvider) GetMongo() (*mongo.Client, error) {
return nil, ErrNotMongoDB
}
// Stats returns connection pool statistics
func (p *SQLiteProvider) Stats() *ConnectionStats {
if p.db == nil {
return &ConnectionStats{
Name: p.config.GetName(),
Type: "sqlite",
Connected: false,
}
}
stats := p.db.Stats()
return &ConnectionStats{
Name: p.config.GetName(),
Type: "sqlite",
Connected: true,
OpenConnections: stats.OpenConnections,
InUse: stats.InUse,
Idle: stats.Idle,
WaitCount: stats.WaitCount,
WaitDuration: stats.WaitDuration,
MaxIdleClosed: stats.MaxIdleClosed,
MaxLifetimeClosed: stats.MaxLifetimeClosed,
}
}

View File

@@ -90,12 +90,12 @@ Panics are automatically captured when using the logger's panic handlers:
```go
// Using CatchPanic
defer logger.CatchPanic("MyFunction")
defer logger.CatchPanic("MyFunction")()
// Using CatchPanicCallback
defer logger.CatchPanicCallback("MyFunction", func(err any) {
// Custom cleanup
})
})()
// Using HandlePanic
defer func() {

View File

@@ -0,0 +1,353 @@
# Event Broker System Implementation Plan
## Overview
Implement a comprehensive event handler/broker system for ResolveSpec that follows existing architectural patterns (Provider interface, Hook system, Config management, Graceful shutdown).
## Requirements Met
- ✅ Events with sources (database, websocket, frontend, system)
- ✅ Event statuses (pending, processing, completed, failed)
- ✅ Timestamps, JSON payloads, user IDs, session IDs
- ✅ Program instance IDs for tracking server instances
- ✅ Both sync and async processing modes
- ✅ Multiple provider backends (in-memory, Redis, NATS, database)
- ✅ Cross-instance pub/sub support
## Architecture
### Core Components
**Event Structure** (with full metadata):
```go
type Event struct {
ID string // UUID
Source EventSource // database, websocket, system, frontend
Type string // Pattern: schema.entity.operation
Status EventStatus // pending, processing, completed, failed
Payload json.RawMessage // JSON payload
UserID int
SessionID string
InstanceID string // Server instance identifier
Schema string
Entity string
Operation string // create, update, delete, read
CreatedAt time.Time
ProcessedAt *time.Time
CompletedAt *time.Time
Error string
Metadata map[string]interface{}
RetryCount int
}
```
**Provider Pattern** (like cache.Provider):
```go
type Provider interface {
Store(ctx context.Context, event *Event) error
Get(ctx context.Context, id string) (*Event, error)
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
UpdateStatus(ctx context.Context, id string, status EventStatus, error string) error
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
Publish(ctx context.Context, event *Event) error
Close() error
Stats(ctx context.Context) (*ProviderStats, error)
}
```
**Broker Interface**:
```go
type Broker interface {
Publish(ctx context.Context, event *Event) error // Mode-dependent
PublishSync(ctx context.Context, event *Event) error // Blocks
PublishAsync(ctx context.Context, event *Event) error // Non-blocking
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
Unsubscribe(id SubscriptionID) error
Start(ctx context.Context) error
Stop(ctx context.Context) error
Stats(ctx context.Context) (*BrokerStats, error)
}
```
## Implementation Steps
### Phase 1: Core Foundation (Files: 1-5)
**1. Create `pkg/eventbroker/event.go`**
- Event struct with all required fields (status, timestamps, user, instance ID, etc.)
- EventSource enum (database, websocket, frontend, system, internal)
- EventStatus enum (pending, processing, completed, failed)
- Helper: `EventType(schema, entity, operation string) string`
- Helper: `NewEvent()` constructor with UUID generation
**2. Create `pkg/eventbroker/provider.go`**
- Provider interface definition
- EventFilter struct for queries
- ProviderStats struct
**3. Create `pkg/eventbroker/handler.go`**
- EventHandler interface
- EventHandlerFunc adapter type
**4. Create `pkg/eventbroker/broker.go`**
- Broker interface definition
- EventBroker struct implementation
- ProcessingMode enum (sync, async)
- Options struct with functional options (WithProvider, WithMode, WithWorkerCount, etc.)
- NewBroker() constructor
- Sync processing implementation
**5. Create `pkg/eventbroker/subscription.go`**
- Pattern matching using glob syntax (e.g., "public.users.*", "*.*.create")
- subscriptionManager struct
- SubscriptionID type
- Subscribe/Unsubscribe logic
### Phase 2: Configuration & Integration (Files: 6-8)
**6. Create `pkg/eventbroker/config.go`**
- EventBrokerConfig struct
- RedisConfig, NATSConfig, DatabaseConfig structs
- RetryPolicyConfig struct
**7. Update `pkg/config/config.go`**
- Add `EventBroker EventBrokerConfig` field to Config struct
**8. Update `pkg/config/manager.go`**
- Add event broker defaults to `setDefaults()`:
```go
v.SetDefault("event_broker.enabled", false)
v.SetDefault("event_broker.provider", "memory")
v.SetDefault("event_broker.mode", "async")
v.SetDefault("event_broker.worker_count", 10)
v.SetDefault("event_broker.buffer_size", 1000)
```
### Phase 3: Memory Provider (Files: 9)
**9. Create `pkg/eventbroker/provider_memory.go`**
- MemoryProvider struct with mutex-protected map
- In-memory event storage
- Pattern matching for subscriptions
- Channel-based streaming for real-time events
- LRU eviction when max size reached
- Cleanup goroutine for old completed events
- **Note**: Single-instance only (no cross-instance pub/sub)
### Phase 4: Async Processing (Update File: 4)
**10. Update `pkg/eventbroker/broker.go`** (add async support)
- workerPool struct with configurable worker count
- Buffered channel for event queue
- Worker goroutines that process events
- PublishAsync() queues to channel
- Graceful shutdown: stop accepting events, drain queue, wait for workers
- Retry logic with exponential backoff
### Phase 5: Hook Integration (Files: 11)
**11. Create `pkg/eventbroker/hooks.go`**
- `RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry)`
- Registers AfterCreate, AfterUpdate, AfterDelete, AfterRead hooks
- Extracts UserContext from hook context
- Creates Event with proper metadata
- Calls `broker.PublishAsync()` to not block CRUD operations
### Phase 6: Global Singleton & Factory (Files: 12-13)
**12. Create `pkg/eventbroker/eventbroker.go`**
- Global `defaultBroker` variable
- `Initialize(config *config.Config) error` - creates broker from config
- `SetDefaultBroker(broker Broker)`
- `GetDefaultBroker() Broker`
- Helper functions: `Publish()`, `PublishAsync()`, `PublishSync()`, `Subscribe()`
- `RegisterShutdown(broker Broker)` - registers with server.RegisterShutdownCallback()
**13. Create `pkg/eventbroker/factory.go`**
- `NewProviderFromConfig(config EventBrokerConfig) (Provider, error)`
- Provider selection logic (memory, redis, nats, database)
- Returns appropriate provider based on config
### Phase 7: Redis Provider (Files: 14)
**14. Create `pkg/eventbroker/provider_redis.go`**
- RedisProvider using Redis Streams (XADD, XREAD, XGROUP)
- Consumer group for distributed processing
- Cross-instance pub/sub support
- Stream(pattern) subscribes to consumer group
- Publish() uses XADD to append to stream
- Graceful shutdown: acknowledge pending messages
**Dependencies**: `github.com/redis/go-redis/v9`
### Phase 8: NATS Provider (Files: 15)
**15. Create `pkg/eventbroker/provider_nats.go`**
- NATSProvider using NATS JetStream
- Subject-based routing: `events.{source}.{type}`
- Wildcard subscriptions support
- Durable consumers for replay
- At-least-once delivery semantics
**Dependencies**: `github.com/nats-io/nats.go`
### Phase 9: Database Provider (Files: 16)
**16. Create `pkg/eventbroker/provider_database.go`**
- DatabaseProvider using `common.Database` interface
- Table schema creation (events table with indexes)
- Polling-based event consumption (configurable interval)
- Full SQL query support via List(filter)
- Transaction support for atomic operations
- Good for audit trails and debugging
### Phase 10: Metrics Integration (Files: 17)
**17. Create `pkg/eventbroker/metrics.go`**
- Integrate with existing `metrics.Provider`
- Record metrics:
- `eventbroker_events_published_total{source, type}`
- `eventbroker_events_processed_total{source, type, status}`
- `eventbroker_event_processing_duration_seconds{source, type}`
- `eventbroker_queue_size`
- `eventbroker_workers_active`
**18. Update `pkg/metrics/interfaces.go`**
- Add methods to Provider interface:
```go
RecordEventPublished(source, eventType string)
RecordEventProcessed(source, eventType, status string, duration time.Duration)
UpdateEventQueueSize(size int64)
```
### Phase 11: Testing & Examples (Files: 19-20)
**19. Create `pkg/eventbroker/eventbroker_test.go`**
- Unit tests for Event marshaling
- Pattern matching tests
- MemoryProvider tests
- Sync vs async mode tests
- Concurrent publish/subscribe tests
- Graceful shutdown tests
**20. Create `pkg/eventbroker/example_usage.go`**
- Basic publish example
- Subscribe with patterns example
- Hook integration example
- Multiple handlers example
- Error handling example
## Integration Points
### Hook System Integration
```go
// In application initialization (e.g., main.go)
eventbroker.RegisterCRUDHooks(broker, handler.Hooks())
```
This automatically publishes events for all CRUD operations:
- `schema.entity.create` after inserts
- `schema.entity.update` after updates
- `schema.entity.delete` after deletes
- `schema.entity.read` after reads
### Shutdown Integration
```go
// In application initialization
eventbroker.RegisterShutdown(broker)
```
Ensures event broker flushes pending events before shutdown.
### Configuration Example
```yaml
event_broker:
enabled: true
provider: redis # memory, redis, nats, database
mode: async # sync, async
worker_count: 10
buffer_size: 1000
instance_id: "${HOSTNAME}"
redis:
stream_name: "resolvespec:events"
consumer_group: "resolvespec-workers"
host: "localhost"
port: 6379
```
## Usage Examples
### Publishing Custom Events
```go
// WebSocket event
event := &eventbroker.Event{
Source: eventbroker.EventSourceWebSocket,
Type: "chat.message",
Payload: json.RawMessage(`{"room": "lobby", "msg": "Hello"}`),
UserID: userID,
SessionID: sessionID,
}
eventbroker.PublishAsync(ctx, event)
```
### Subscribing to Events
```go
// Subscribe to all user creation events
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
log.Printf("New user created: %s", event.Payload)
// Send welcome email, update cache, etc.
return nil
},
))
// Subscribe to all events from database
eventbroker.Subscribe("*", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
if event.Source == eventbroker.EventSourceDatabase {
// Audit logging
}
return nil
},
))
```
## Critical Files Reference
**Patterns to Follow**:
- `pkg/cache/provider.go` - Provider interface pattern
- `pkg/restheadspec/hooks.go` - Hook system integration
- `pkg/config/manager.go` - Configuration pattern
- `pkg/server/shutdown.go` - Shutdown callbacks
**Files to Modify**:
- `pkg/config/config.go` - Add EventBroker field
- `pkg/config/manager.go` - Add defaults
- `pkg/metrics/interfaces.go` - Add event broker metrics
**New Package**:
- `pkg/eventbroker/` (20 files total)
## Provider Feature Matrix
| Feature | Memory | Redis | NATS | Database |
|---------|--------|-------|------|----------|
| Persistence | ❌ | ✅ | ✅ | ✅ |
| Cross-instance | ❌ | ✅ | ✅ | ✅ |
| Real-time | ✅ | ✅ | ✅ | ⚠️ (polling) |
| Query history | Limited | Limited | ✅ (replay) | ✅ (SQL) |
| External deps | None | Redis | NATS | None |
| Complexity | Low | Medium | Medium | Low |
## Implementation Order Priority
1. **Core + Memory Provider** (Phase 1-3) - Functional in-process event system
2. **Async + Hooks** (Phase 4-5) - Non-blocking event dispatch integrated with CRUD
3. **Config + Singleton** (Phase 6) - Easy initialization and usage
4. **Redis Provider** (Phase 7) - Production-ready distributed events
5. **Metrics** (Phase 10) - Observability
6. **NATS/Database** (Phase 8-9) - Alternative backends
7. **Tests + Examples** (Phase 11) - Documentation and reliability
## Next Steps
After approval, implement in order of phases. Each phase builds on previous phases and can be tested independently.

View File

@@ -172,12 +172,13 @@ event_broker:
provider: memory
```
### Redis Provider (Future)
### Redis Provider
Best for: Production, multi-instance deployments
- **Pros**: Persistent, cross-instance pub/sub, reliable
- **Cons**: Requires Redis
- **Pros**: Persistent, cross-instance pub/sub, reliable, Redis Streams support
- **Cons**: Requires Redis server
- **Status**: ✅ Implemented
```yaml
event_broker:
@@ -185,16 +186,20 @@ event_broker:
redis:
stream_name: "resolvespec:events"
consumer_group: "resolvespec-workers"
max_len: 10000
host: "localhost"
port: 6379
password: ""
db: 0
```
### NATS Provider (Future)
### NATS Provider
Best for: High-performance, low-latency requirements
- **Pros**: Very fast, built-in clustering, durable
- **Pros**: Very fast, built-in clustering, durable, JetStream support
- **Cons**: Requires NATS server
- **Status**: ✅ Implemented
```yaml
event_broker:
@@ -202,14 +207,17 @@ event_broker:
nats:
url: "nats://localhost:4222"
stream_name: "RESOLVESPEC_EVENTS"
storage: "file" # or "memory"
max_age: "24h"
```
### Database Provider (Future)
### Database Provider
Best for: Audit trails, event replay, SQL queries
- **Pros**: No additional infrastructure, full SQL query support, PostgreSQL NOTIFY for real-time
- **Cons**: Slower than Redis/NATS
- **Cons**: Slower than Redis/NATS, requires database connection
- **Status**: ✅ Implemented
```yaml
event_broker:
@@ -217,6 +225,7 @@ event_broker:
database:
table_name: "events"
channel: "resolvespec_events"
poll_interval: "1s"
```
## Processing Modes
@@ -314,14 +323,25 @@ See `example_usage.go` for comprehensive examples including:
└─────────────────┘
```
## Implemented Features
- [x] Memory Provider (in-process, single-instance)
- [x] Redis Streams Provider (distributed, persistent)
- [x] NATS JetStream Provider (distributed, high-performance)
- [x] Database Provider with PostgreSQL NOTIFY (SQL-queryable, audit-friendly)
- [x] Sync and Async processing modes
- [x] Pattern-based subscriptions
- [x] Hook integration for automatic CRUD events
- [x] Retry policy with exponential backoff
- [x] Graceful shutdown
## Future Enhancements
- [ ] Database Provider with PostgreSQL NOTIFY
- [ ] Redis Streams Provider
- [ ] NATS JetStream Provider
- [ ] Event replay functionality
- [ ] Dead letter queue
- [ ] Event filtering at provider level
- [ ] Batch publishing
- [ ] Event compression
- [ ] Schema versioning
- [ ] Event replay functionality from specific timestamp
- [ ] Dead letter queue for failed events
- [ ] Event filtering at provider level for performance
- [ ] Batch publishing support
- [ ] Event compression for large payloads
- [ ] Schema versioning and migration
- [ ] Event streaming to external systems (Kafka, RabbitMQ)
- [ ] Event aggregation and analytics

View File

@@ -7,7 +7,6 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/server"
)
var (
@@ -69,9 +68,6 @@ func Initialize(cfg config.EventBrokerConfig) error {
// Set as default
SetDefaultBroker(broker)
// Register shutdown callback
RegisterShutdown(broker)
logger.Info("Event broker initialized successfully (provider: %s, mode: %s, instance: %s)",
cfg.Provider, cfg.Mode, opts.InstanceID)
@@ -151,10 +147,12 @@ func Stats(ctx context.Context) (*BrokerStats, error) {
return broker.Stats(ctx)
}
// RegisterShutdown registers the broker's shutdown with the server shutdown callbacks
func RegisterShutdown(broker Broker) {
server.RegisterShutdownCallback(func(ctx context.Context) error {
// RegisterShutdown registers the broker's shutdown with a server manager
// Call this from your application initialization code
// Example: serverMgr.RegisterShutdownCallback(eventbroker.MakeShutdownCallback(broker))
func MakeShutdownCallback(broker Broker) func(context.Context) error {
return func(ctx context.Context) error {
logger.Info("Shutting down event broker...")
return broker.Stop(ctx)
})
}
}

View File

@@ -24,16 +24,34 @@ func NewProviderFromConfig(cfg config.EventBrokerConfig) (Provider, error) {
}), nil
case "redis":
// Redis provider will be implemented in Phase 8
return nil, fmt.Errorf("redis provider not yet implemented")
return NewRedisProvider(RedisProviderConfig{
Host: cfg.Redis.Host,
Port: cfg.Redis.Port,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
StreamName: cfg.Redis.StreamName,
ConsumerGroup: cfg.Redis.ConsumerGroup,
ConsumerName: getInstanceID(cfg.InstanceID),
InstanceID: getInstanceID(cfg.InstanceID),
MaxLen: cfg.Redis.MaxLen,
})
case "nats":
// NATS provider will be implemented in Phase 9
return nil, fmt.Errorf("nats provider not yet implemented")
// NATS provider initialization
// Note: Requires github.com/nats-io/nats.go dependency
return NewNATSProvider(NATSProviderConfig{
URL: cfg.NATS.URL,
StreamName: cfg.NATS.StreamName,
SubjectPrefix: "events",
InstanceID: getInstanceID(cfg.InstanceID),
MaxAge: cfg.NATS.MaxAge,
Storage: cfg.NATS.Storage, // "file" or "memory"
})
case "database":
// Database provider will be implemented in Phase 7
return nil, fmt.Errorf("database provider not yet implemented")
// Database provider requires a database connection
// This should be provided externally
return nil, fmt.Errorf("database provider requires a database connection to be configured separately")
default:
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)

View File

@@ -0,0 +1,653 @@
package eventbroker
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// DatabaseProvider implements Provider interface using SQL database
// Features:
// - Persistent event storage in database table
// - Full SQL query support for event history
// - PostgreSQL NOTIFY/LISTEN for real-time updates (optional)
// - Polling-based consumption with configurable interval
// - Good for audit trails and event replay
type DatabaseProvider struct {
db common.Database
tableName string
channel string // PostgreSQL NOTIFY channel name
pollInterval time.Duration
instanceID string
useNotify bool // Whether to use PostgreSQL NOTIFY
// Subscriptions
mu sync.RWMutex
subscribers map[string]*dbSubscription
// Statistics
stats DatabaseProviderStats
// Lifecycle
stopPolling chan struct{}
wg sync.WaitGroup
isRunning atomic.Bool
}
// DatabaseProviderStats contains statistics for the database provider
type DatabaseProviderStats struct {
TotalEvents atomic.Int64
EventsPublished atomic.Int64
EventsConsumed atomic.Int64
ActiveSubscribers atomic.Int32
PollErrors atomic.Int64
}
// dbSubscription represents a single database subscription
type dbSubscription struct {
pattern string
ch chan *Event
lastSeenID string
ctx context.Context
cancel context.CancelFunc
}
// DatabaseProviderConfig configures the database provider
type DatabaseProviderConfig struct {
DB common.Database
TableName string
Channel string // PostgreSQL NOTIFY channel (optional)
PollInterval time.Duration
InstanceID string
UseNotify bool // Enable PostgreSQL NOTIFY/LISTEN
}
// NewDatabaseProvider creates a new database event provider
func NewDatabaseProvider(cfg DatabaseProviderConfig) (*DatabaseProvider, error) {
// Apply defaults
if cfg.TableName == "" {
cfg.TableName = "events"
}
if cfg.Channel == "" {
cfg.Channel = "resolvespec_events"
}
if cfg.PollInterval == 0 {
cfg.PollInterval = 1 * time.Second
}
dp := &DatabaseProvider{
db: cfg.DB,
tableName: cfg.TableName,
channel: cfg.Channel,
pollInterval: cfg.PollInterval,
instanceID: cfg.InstanceID,
useNotify: cfg.UseNotify,
subscribers: make(map[string]*dbSubscription),
stopPolling: make(chan struct{}),
}
dp.isRunning.Store(true)
// Create table if it doesn't exist
ctx := context.Background()
if err := dp.createTable(ctx); err != nil {
return nil, fmt.Errorf("failed to create events table: %w", err)
}
// Start polling goroutine for subscriptions
dp.wg.Add(1)
go dp.pollLoop()
logger.Info("Database provider initialized (table: %s, poll_interval: %v, notify: %v)",
cfg.TableName, cfg.PollInterval, cfg.UseNotify)
return dp, nil
}
// Store stores an event
func (dp *DatabaseProvider) Store(ctx context.Context, event *Event) error {
// Marshal metadata to JSON
metadataJSON, err := json.Marshal(event.Metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
// Insert event
query := fmt.Sprintf(`
INSERT INTO %s (
id, source, type, status, retry_count, error,
payload, user_id, session_id, instance_id,
schema, entity, operation,
created_at, processed_at, completed_at, metadata
) VALUES (
$1, $2, $3, $4, $5, $6,
$7, $8, $9, $10,
$11, $12, $13,
$14, $15, $16, $17
)
`, dp.tableName)
_, err = dp.db.Exec(ctx, query,
event.ID, event.Source, event.Type, event.Status, event.RetryCount, event.Error,
event.Payload, event.UserID, event.SessionID, event.InstanceID,
event.Schema, event.Entity, event.Operation,
event.CreatedAt, event.ProcessedAt, event.CompletedAt, metadataJSON,
)
if err != nil {
return fmt.Errorf("failed to insert event: %w", err)
}
dp.stats.TotalEvents.Add(1)
return nil
}
// Get retrieves an event by ID
func (dp *DatabaseProvider) Get(ctx context.Context, id string) (*Event, error) {
event := &Event{}
var metadataJSON []byte
var processedAt, completedAt sql.NullTime
// Query into individual fields
query := fmt.Sprintf(`
SELECT id, source, type, status, retry_count, error,
payload, user_id, session_id, instance_id,
schema, entity, operation,
created_at, processed_at, completed_at, metadata
FROM %s
WHERE id = $1
`, dp.tableName)
var source, eventType, status, operation string
// Execute raw query
rows, err := dp.db.GetUnderlyingDB().(interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}).QueryContext(ctx, query, id)
if err != nil {
return nil, fmt.Errorf("failed to query event: %w", err)
}
defer rows.Close()
if !rows.Next() {
return nil, fmt.Errorf("event not found: %s", id)
}
if err := rows.Scan(
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
&event.Schema, &event.Entity, &operation,
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
); err != nil {
return nil, fmt.Errorf("failed to scan event: %w", err)
}
// Set enum values
event.Source = EventSource(source)
event.Type = eventType
event.Status = EventStatus(status)
event.Operation = operation
// Handle nullable timestamps
if processedAt.Valid {
event.ProcessedAt = &processedAt.Time
}
if completedAt.Valid {
event.CompletedAt = &completedAt.Time
}
// Unmarshal metadata
if len(metadataJSON) > 0 {
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
logger.Warn("Failed to unmarshal metadata: %v", err)
}
}
return event, nil
}
// List lists events with optional filters
func (dp *DatabaseProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
query := fmt.Sprintf("SELECT id, source, type, status, retry_count, error, "+
"payload, user_id, session_id, instance_id, "+
"schema, entity, operation, "+
"created_at, processed_at, completed_at, metadata "+
"FROM %s WHERE 1=1", dp.tableName)
args := []interface{}{}
argNum := 1
// Build WHERE clause
if filter != nil {
if filter.Source != nil {
query += fmt.Sprintf(" AND source = $%d", argNum)
args = append(args, string(*filter.Source))
argNum++
}
if filter.Status != nil {
query += fmt.Sprintf(" AND status = $%d", argNum)
args = append(args, string(*filter.Status))
argNum++
}
if filter.UserID != nil {
query += fmt.Sprintf(" AND user_id = $%d", argNum)
args = append(args, *filter.UserID)
argNum++
}
if filter.Schema != "" {
query += fmt.Sprintf(" AND schema = $%d", argNum)
args = append(args, filter.Schema)
argNum++
}
if filter.Entity != "" {
query += fmt.Sprintf(" AND entity = $%d", argNum)
args = append(args, filter.Entity)
argNum++
}
if filter.Operation != "" {
query += fmt.Sprintf(" AND operation = $%d", argNum)
args = append(args, filter.Operation)
argNum++
}
if filter.InstanceID != "" {
query += fmt.Sprintf(" AND instance_id = $%d", argNum)
args = append(args, filter.InstanceID)
argNum++
}
if filter.StartTime != nil {
query += fmt.Sprintf(" AND created_at >= $%d", argNum)
args = append(args, *filter.StartTime)
argNum++
}
if filter.EndTime != nil {
query += fmt.Sprintf(" AND created_at <= $%d", argNum)
args = append(args, *filter.EndTime)
argNum++
}
}
// Add ORDER BY
query += " ORDER BY created_at DESC"
// Add LIMIT and OFFSET
if filter != nil {
if filter.Limit > 0 {
query += fmt.Sprintf(" LIMIT $%d", argNum)
args = append(args, filter.Limit)
argNum++
}
if filter.Offset > 0 {
query += fmt.Sprintf(" OFFSET $%d", argNum)
args = append(args, filter.Offset)
}
}
// Execute query
rows, err := dp.db.GetUnderlyingDB().(interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}).QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query events: %w", err)
}
defer rows.Close()
var results []*Event
for rows.Next() {
event := &Event{}
var source, eventType, status, operation string
var metadataJSON []byte
var processedAt, completedAt sql.NullTime
err := rows.Scan(
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
&event.Schema, &event.Entity, &operation,
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
)
if err != nil {
logger.Warn("Failed to scan event: %v", err)
continue
}
// Set enum values
event.Source = EventSource(source)
event.Type = eventType
event.Status = EventStatus(status)
event.Operation = operation
// Handle nullable timestamps
if processedAt.Valid {
event.ProcessedAt = &processedAt.Time
}
if completedAt.Valid {
event.CompletedAt = &completedAt.Time
}
// Unmarshal metadata
if len(metadataJSON) > 0 {
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
logger.Warn("Failed to unmarshal metadata: %v", err)
}
}
results = append(results, event)
}
return results, nil
}
// UpdateStatus updates the status of an event
func (dp *DatabaseProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
query := fmt.Sprintf(`
UPDATE %s
SET status = $1, error = $2
WHERE id = $3
`, dp.tableName)
_, err := dp.db.Exec(ctx, query, string(status), errorMsg, id)
if err != nil {
return fmt.Errorf("failed to update status: %w", err)
}
return nil
}
// Delete deletes an event by ID
func (dp *DatabaseProvider) Delete(ctx context.Context, id string) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = $1", dp.tableName)
_, err := dp.db.Exec(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete event: %w", err)
}
dp.stats.TotalEvents.Add(-1)
return nil
}
// Stream returns a channel of events for real-time consumption
func (dp *DatabaseProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
ch := make(chan *Event, 100)
subCtx, cancel := context.WithCancel(ctx)
sub := &dbSubscription{
pattern: pattern,
ch: ch,
lastSeenID: "",
ctx: subCtx,
cancel: cancel,
}
dp.mu.Lock()
dp.subscribers[pattern] = sub
dp.stats.ActiveSubscribers.Add(1)
dp.mu.Unlock()
return ch, nil
}
// Publish publishes an event to all subscribers
func (dp *DatabaseProvider) Publish(ctx context.Context, event *Event) error {
// Store the event first
if err := dp.Store(ctx, event); err != nil {
return err
}
dp.stats.EventsPublished.Add(1)
// If using PostgreSQL NOTIFY, send notification
if dp.useNotify {
if err := dp.notify(ctx, event.ID); err != nil {
logger.Warn("Failed to send NOTIFY: %v", err)
}
}
return nil
}
// Close closes the provider and releases resources
func (dp *DatabaseProvider) Close() error {
if !dp.isRunning.Load() {
return nil
}
dp.isRunning.Store(false)
// Cancel all subscriptions
dp.mu.Lock()
for _, sub := range dp.subscribers {
sub.cancel()
}
dp.mu.Unlock()
// Stop polling
close(dp.stopPolling)
// Wait for goroutines
dp.wg.Wait()
logger.Info("Database provider closed")
return nil
}
// Stats returns provider statistics
func (dp *DatabaseProvider) Stats(ctx context.Context) (*ProviderStats, error) {
// Get counts by status
query := fmt.Sprintf(`
SELECT
COUNT(*) FILTER (WHERE status = 'pending') as pending,
COUNT(*) FILTER (WHERE status = 'processing') as processing,
COUNT(*) FILTER (WHERE status = 'completed') as completed,
COUNT(*) FILTER (WHERE status = 'failed') as failed,
COUNT(*) as total
FROM %s
`, dp.tableName)
var pending, processing, completed, failed, total int64
rows, err := dp.db.GetUnderlyingDB().(interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}).QueryContext(ctx, query)
if err != nil {
logger.Warn("Failed to get stats: %v", err)
} else {
defer rows.Close()
if rows.Next() {
if err := rows.Scan(&pending, &processing, &completed, &failed, &total); err != nil {
logger.Warn("Failed to scan stats: %v", err)
}
}
}
return &ProviderStats{
ProviderType: "database",
TotalEvents: total,
PendingEvents: pending,
ProcessingEvents: processing,
CompletedEvents: completed,
FailedEvents: failed,
EventsPublished: dp.stats.EventsPublished.Load(),
EventsConsumed: dp.stats.EventsConsumed.Load(),
ActiveSubscribers: int(dp.stats.ActiveSubscribers.Load()),
ProviderSpecific: map[string]interface{}{
"table_name": dp.tableName,
"poll_interval": dp.pollInterval.String(),
"use_notify": dp.useNotify,
"poll_errors": dp.stats.PollErrors.Load(),
},
}, nil
}
// pollLoop periodically polls for new events
func (dp *DatabaseProvider) pollLoop() {
defer dp.wg.Done()
ticker := time.NewTicker(dp.pollInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
dp.pollEvents()
case <-dp.stopPolling:
return
}
}
}
// pollEvents polls for new events and delivers to subscribers
func (dp *DatabaseProvider) pollEvents() {
dp.mu.RLock()
subscribers := make([]*dbSubscription, 0, len(dp.subscribers))
for _, sub := range dp.subscribers {
subscribers = append(subscribers, sub)
}
dp.mu.RUnlock()
for _, sub := range subscribers {
// Query for new events since last seen
query := fmt.Sprintf(`
SELECT id, source, type, status, retry_count, error,
payload, user_id, session_id, instance_id,
schema, entity, operation,
created_at, processed_at, completed_at, metadata
FROM %s
WHERE id > $1
ORDER BY created_at ASC
LIMIT 100
`, dp.tableName)
lastSeenID := sub.lastSeenID
if lastSeenID == "" {
lastSeenID = "00000000-0000-0000-0000-000000000000"
}
rows, err := dp.db.GetUnderlyingDB().(interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}).QueryContext(sub.ctx, query, lastSeenID)
if err != nil {
dp.stats.PollErrors.Add(1)
logger.Warn("Failed to poll events: %v", err)
continue
}
for rows.Next() {
event := &Event{}
var source, eventType, status, operation string
var metadataJSON []byte
var processedAt, completedAt sql.NullTime
err := rows.Scan(
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
&event.Schema, &event.Entity, &operation,
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
)
if err != nil {
logger.Warn("Failed to scan event: %v", err)
continue
}
// Set enum values
event.Source = EventSource(source)
event.Type = eventType
event.Status = EventStatus(status)
event.Operation = operation
// Handle nullable timestamps
if processedAt.Valid {
event.ProcessedAt = &processedAt.Time
}
if completedAt.Valid {
event.CompletedAt = &completedAt.Time
}
// Unmarshal metadata
if len(metadataJSON) > 0 {
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
logger.Warn("Failed to unmarshal metadata: %v", err)
}
}
// Check if event matches pattern
if matchPattern(sub.pattern, event.Type) {
select {
case sub.ch <- event:
dp.stats.EventsConsumed.Add(1)
sub.lastSeenID = event.ID
case <-sub.ctx.Done():
rows.Close()
return
default:
// Channel full, skip
logger.Warn("Subscriber channel full for pattern: %s", sub.pattern)
}
}
sub.lastSeenID = event.ID
}
rows.Close()
}
}
// notify sends a PostgreSQL NOTIFY message
func (dp *DatabaseProvider) notify(ctx context.Context, eventID string) error {
query := fmt.Sprintf("NOTIFY %s, '%s'", dp.channel, eventID)
_, err := dp.db.Exec(ctx, query)
return err
}
// createTable creates the events table if it doesn't exist
func (dp *DatabaseProvider) createTable(ctx context.Context) error {
query := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id VARCHAR(255) PRIMARY KEY,
source VARCHAR(50) NOT NULL,
type VARCHAR(255) NOT NULL,
status VARCHAR(50) NOT NULL,
retry_count INTEGER DEFAULT 0,
error TEXT,
payload JSONB,
user_id INTEGER,
session_id VARCHAR(255),
instance_id VARCHAR(255),
schema VARCHAR(255),
entity VARCHAR(255),
operation VARCHAR(50),
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
processed_at TIMESTAMP,
completed_at TIMESTAMP,
metadata JSONB
)
`, dp.tableName)
if _, err := dp.db.Exec(ctx, query); err != nil {
return fmt.Errorf("failed to create table: %w", err)
}
// Create indexes
indexes := []string{
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_source ON %s(source)", dp.tableName, dp.tableName),
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_type ON %s(type)", dp.tableName, dp.tableName),
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_status ON %s(status)", dp.tableName, dp.tableName),
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_created_at ON %s(created_at)", dp.tableName, dp.tableName),
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_instance_id ON %s(instance_id)", dp.tableName, dp.tableName),
}
for _, indexQuery := range indexes {
if _, err := dp.db.Exec(ctx, indexQuery); err != nil {
logger.Warn("Failed to create index: %v", err)
}
}
return nil
}

View File

@@ -0,0 +1,565 @@
package eventbroker
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// NATSProvider implements Provider interface using NATS JetStream
// Features:
// - Persistent event storage using JetStream
// - Cross-instance pub/sub using NATS subjects
// - Wildcard subscription support
// - Durable consumers for event replay
// - At-least-once delivery semantics
type NATSProvider struct {
nc *nats.Conn
js jetstream.JetStream
stream jetstream.Stream
streamName string
subjectPrefix string
instanceID string
maxAge time.Duration
// Subscriptions
mu sync.RWMutex
subscribers map[string]*natsSubscription
// Statistics
stats NATSProviderStats
// Lifecycle
wg sync.WaitGroup
isRunning atomic.Bool
}
// NATSProviderStats contains statistics for the NATS provider
type NATSProviderStats struct {
TotalEvents atomic.Int64
EventsPublished atomic.Int64
EventsConsumed atomic.Int64
ActiveSubscribers atomic.Int32
ConsumerErrors atomic.Int64
}
// natsSubscription represents a single NATS subscription
type natsSubscription struct {
pattern string
consumer jetstream.Consumer
ch chan *Event
ctx context.Context
cancel context.CancelFunc
}
// NATSProviderConfig configures the NATS provider
type NATSProviderConfig struct {
URL string
StreamName string
SubjectPrefix string // e.g., "events"
InstanceID string
MaxAge time.Duration // How long to keep events
Storage string // "file" or "memory"
}
// NewNATSProvider creates a new NATS event provider
func NewNATSProvider(cfg NATSProviderConfig) (*NATSProvider, error) {
// Apply defaults
if cfg.URL == "" {
cfg.URL = nats.DefaultURL
}
if cfg.StreamName == "" {
cfg.StreamName = "RESOLVESPEC_EVENTS"
}
if cfg.SubjectPrefix == "" {
cfg.SubjectPrefix = "events"
}
if cfg.MaxAge == 0 {
cfg.MaxAge = 7 * 24 * time.Hour // 7 days
}
if cfg.Storage == "" {
cfg.Storage = "file"
}
// Connect to NATS
nc, err := nats.Connect(cfg.URL,
nats.Name("resolvespec-eventbroker-"+cfg.InstanceID),
nats.Timeout(5*time.Second),
)
if err != nil {
return nil, fmt.Errorf("failed to connect to NATS: %w", err)
}
// Create JetStream context
js, err := jetstream.New(nc)
if err != nil {
nc.Close()
return nil, fmt.Errorf("failed to create JetStream context: %w", err)
}
np := &NATSProvider{
nc: nc,
js: js,
streamName: cfg.StreamName,
subjectPrefix: cfg.SubjectPrefix,
instanceID: cfg.InstanceID,
maxAge: cfg.MaxAge,
subscribers: make(map[string]*natsSubscription),
}
np.isRunning.Store(true)
// Create or update stream
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Determine storage type
var storage jetstream.StorageType
if cfg.Storage == "memory" {
storage = jetstream.MemoryStorage
} else {
storage = jetstream.FileStorage
}
if err := np.ensureStream(ctx, storage); err != nil {
nc.Close()
return nil, fmt.Errorf("failed to create stream: %w", err)
}
logger.Info("NATS provider initialized (stream: %s, subject: %s.*, url: %s)",
cfg.StreamName, cfg.SubjectPrefix, cfg.URL)
return np, nil
}
// Store stores an event
func (np *NATSProvider) Store(ctx context.Context, event *Event) error {
// Marshal event to JSON
data, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("failed to marshal event: %w", err)
}
// Publish to NATS subject
// Subject format: events.{source}.{schema}.{entity}.{operation}
subject := np.buildSubject(event)
msg := &nats.Msg{
Subject: subject,
Data: data,
Header: nats.Header{
"Event-ID": []string{event.ID},
"Event-Type": []string{event.Type},
"Event-Source": []string{string(event.Source)},
"Event-Status": []string{string(event.Status)},
"Instance-ID": []string{event.InstanceID},
},
}
if _, err := np.js.PublishMsg(ctx, msg); err != nil {
return fmt.Errorf("failed to publish event: %w", err)
}
np.stats.TotalEvents.Add(1)
return nil
}
// Get retrieves an event by ID
// Note: This is inefficient with JetStream - consider using a separate KV store for lookups
func (np *NATSProvider) Get(ctx context.Context, id string) (*Event, error) {
// We need to scan messages which is not ideal
// For production, consider using NATS KV store for fast lookups
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
Name: "get-" + id,
FilterSubject: np.subjectPrefix + ".>",
DeliverPolicy: jetstream.DeliverAllPolicy,
AckPolicy: jetstream.AckExplicitPolicy,
})
if err != nil {
return nil, fmt.Errorf("failed to create consumer: %w", err)
}
// Fetch messages in batches
msgs, err := consumer.Fetch(1000, jetstream.FetchMaxWait(5*time.Second))
if err != nil {
return nil, fmt.Errorf("failed to fetch messages: %w", err)
}
for msg := range msgs.Messages() {
if msg.Headers().Get("Event-ID") == id {
var event Event
if err := json.Unmarshal(msg.Data(), &event); err != nil {
_ = msg.Nak()
continue
}
_ = msg.Ack()
// Delete temporary consumer
_ = np.stream.DeleteConsumer(ctx, "get-"+id)
return &event, nil
}
_ = msg.Ack()
}
// Delete temporary consumer
_ = np.stream.DeleteConsumer(ctx, "get-"+id)
return nil, fmt.Errorf("event not found: %s", id)
}
// List lists events with optional filters
func (np *NATSProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
var results []*Event
// Create temporary consumer
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
Name: fmt.Sprintf("list-%d", time.Now().UnixNano()),
FilterSubject: np.subjectPrefix + ".>",
DeliverPolicy: jetstream.DeliverAllPolicy,
AckPolicy: jetstream.AckExplicitPolicy,
})
if err != nil {
return nil, fmt.Errorf("failed to create consumer: %w", err)
}
defer func() { _ = np.stream.DeleteConsumer(ctx, consumer.CachedInfo().Name) }()
// Fetch messages in batches
msgs, err := consumer.Fetch(1000, jetstream.FetchMaxWait(5*time.Second))
if err != nil {
return nil, fmt.Errorf("failed to fetch messages: %w", err)
}
for msg := range msgs.Messages() {
var event Event
if err := json.Unmarshal(msg.Data(), &event); err != nil {
logger.Warn("Failed to unmarshal event: %v", err)
_ = msg.Nak()
continue
}
if np.matchesFilter(&event, filter) {
results = append(results, &event)
}
_ = msg.Ack()
}
// Apply limit and offset
if filter != nil {
if filter.Offset > 0 && filter.Offset < len(results) {
results = results[filter.Offset:]
}
if filter.Limit > 0 && filter.Limit < len(results) {
results = results[:filter.Limit]
}
}
return results, nil
}
// UpdateStatus updates the status of an event
// Note: NATS streams are append-only, so we publish a status update event
func (np *NATSProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
// Publish a status update message
subject := fmt.Sprintf("%s.status.%s", np.subjectPrefix, id)
statusUpdate := map[string]interface{}{
"event_id": id,
"status": string(status),
"error": errorMsg,
"updated_at": time.Now(),
}
data, err := json.Marshal(statusUpdate)
if err != nil {
return fmt.Errorf("failed to marshal status update: %w", err)
}
if _, err := np.js.Publish(ctx, subject, data); err != nil {
return fmt.Errorf("failed to publish status update: %w", err)
}
return nil
}
// Delete deletes an event by ID
// Note: NATS streams don't support deletion - this just marks it in a separate subject
func (np *NATSProvider) Delete(ctx context.Context, id string) error {
subject := fmt.Sprintf("%s.deleted.%s", np.subjectPrefix, id)
deleteMsg := map[string]interface{}{
"event_id": id,
"deleted_at": time.Now(),
}
data, err := json.Marshal(deleteMsg)
if err != nil {
return fmt.Errorf("failed to marshal delete message: %w", err)
}
if _, err := np.js.Publish(ctx, subject, data); err != nil {
return fmt.Errorf("failed to publish delete message: %w", err)
}
return nil
}
// Stream returns a channel of events for real-time consumption
func (np *NATSProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
ch := make(chan *Event, 100)
// Convert glob pattern to NATS subject pattern
natsSubject := np.patternToSubject(pattern)
// Create durable consumer
consumerName := fmt.Sprintf("consumer-%s-%d", np.instanceID, time.Now().UnixNano())
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
Name: consumerName,
FilterSubject: natsSubject,
DeliverPolicy: jetstream.DeliverNewPolicy,
AckPolicy: jetstream.AckExplicitPolicy,
AckWait: 30 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("failed to create consumer: %w", err)
}
subCtx, cancel := context.WithCancel(ctx)
sub := &natsSubscription{
pattern: pattern,
consumer: consumer,
ch: ch,
ctx: subCtx,
cancel: cancel,
}
np.mu.Lock()
np.subscribers[pattern] = sub
np.stats.ActiveSubscribers.Add(1)
np.mu.Unlock()
// Start consumer goroutine
np.wg.Add(1)
go np.consumeMessages(sub)
return ch, nil
}
// Publish publishes an event to all subscribers
func (np *NATSProvider) Publish(ctx context.Context, event *Event) error {
// Store the event first
if err := np.Store(ctx, event); err != nil {
return err
}
np.stats.EventsPublished.Add(1)
return nil
}
// Close closes the provider and releases resources
func (np *NATSProvider) Close() error {
if !np.isRunning.Load() {
return nil
}
np.isRunning.Store(false)
// Cancel all subscriptions
np.mu.Lock()
for _, sub := range np.subscribers {
sub.cancel()
}
np.mu.Unlock()
// Wait for goroutines
np.wg.Wait()
// Close NATS connection
np.nc.Close()
logger.Info("NATS provider closed")
return nil
}
// Stats returns provider statistics
func (np *NATSProvider) Stats(ctx context.Context) (*ProviderStats, error) {
streamInfo, err := np.stream.Info(ctx)
if err != nil {
logger.Warn("Failed to get stream info: %v", err)
}
stats := &ProviderStats{
ProviderType: "nats",
TotalEvents: np.stats.TotalEvents.Load(),
EventsPublished: np.stats.EventsPublished.Load(),
EventsConsumed: np.stats.EventsConsumed.Load(),
ActiveSubscribers: int(np.stats.ActiveSubscribers.Load()),
ProviderSpecific: map[string]interface{}{
"stream_name": np.streamName,
"subject_prefix": np.subjectPrefix,
"max_age": np.maxAge.String(),
"consumer_errors": np.stats.ConsumerErrors.Load(),
},
}
if streamInfo != nil {
stats.ProviderSpecific["messages"] = streamInfo.State.Msgs
stats.ProviderSpecific["bytes"] = streamInfo.State.Bytes
stats.ProviderSpecific["consumers"] = streamInfo.State.Consumers
}
return stats, nil
}
// ensureStream creates or updates the JetStream stream
func (np *NATSProvider) ensureStream(ctx context.Context, storage jetstream.StorageType) error {
streamConfig := jetstream.StreamConfig{
Name: np.streamName,
Subjects: []string{np.subjectPrefix + ".>"},
MaxAge: np.maxAge,
Storage: storage,
Retention: jetstream.LimitsPolicy,
Discard: jetstream.DiscardOld,
}
stream, err := np.js.CreateStream(ctx, streamConfig)
if err != nil {
// Try to update if already exists
stream, err = np.js.UpdateStream(ctx, streamConfig)
if err != nil {
return fmt.Errorf("failed to create/update stream: %w", err)
}
}
np.stream = stream
return nil
}
// consumeMessages consumes messages from NATS for a subscription
func (np *NATSProvider) consumeMessages(sub *natsSubscription) {
defer np.wg.Done()
defer close(sub.ch)
defer func() {
np.mu.Lock()
delete(np.subscribers, sub.pattern)
np.stats.ActiveSubscribers.Add(-1)
np.mu.Unlock()
}()
logger.Debug("Starting NATS consumer for pattern: %s", sub.pattern)
// Consume messages
cc, err := sub.consumer.Consume(func(msg jetstream.Msg) {
var event Event
if err := json.Unmarshal(msg.Data(), &event); err != nil {
logger.Warn("Failed to unmarshal event: %v", err)
_ = msg.Nak()
return
}
// Check if event matches pattern (additional filtering)
if matchPattern(sub.pattern, event.Type) {
select {
case sub.ch <- &event:
np.stats.EventsConsumed.Add(1)
_ = msg.Ack()
case <-sub.ctx.Done():
_ = msg.Nak()
return
}
} else {
_ = msg.Ack()
}
})
if err != nil {
np.stats.ConsumerErrors.Add(1)
logger.Error("Failed to start consumer: %v", err)
return
}
// Wait for context cancellation
<-sub.ctx.Done()
// Stop consuming
cc.Stop()
logger.Debug("NATS consumer stopped for pattern: %s", sub.pattern)
}
// buildSubject creates a NATS subject from an event
// Format: events.{source}.{schema}.{entity}.{operation}
func (np *NATSProvider) buildSubject(event *Event) string {
return fmt.Sprintf("%s.%s.%s.%s.%s",
np.subjectPrefix,
event.Source,
event.Schema,
event.Entity,
event.Operation,
)
}
// patternToSubject converts a glob pattern to NATS subject pattern
// Examples:
// - "*" -> "events.>"
// - "public.users.*" -> "events.*.public.users.*"
// - "public.*.*" -> "events.*.public.*.*"
func (np *NATSProvider) patternToSubject(pattern string) string {
if pattern == "*" {
return np.subjectPrefix + ".>"
}
// For specific patterns, we need to match the event type structure
// Event type: schema.entity.operation
// NATS subject: events.{source}.{schema}.{entity}.{operation}
// We use wildcard for source since pattern doesn't include it
return fmt.Sprintf("%s.*.%s", np.subjectPrefix, pattern)
}
// matchesFilter checks if an event matches the filter criteria
func (np *NATSProvider) matchesFilter(event *Event, filter *EventFilter) bool {
if filter == nil {
return true
}
if filter.Source != nil && event.Source != *filter.Source {
return false
}
if filter.Status != nil && event.Status != *filter.Status {
return false
}
if filter.UserID != nil && event.UserID != *filter.UserID {
return false
}
if filter.Schema != "" && event.Schema != filter.Schema {
return false
}
if filter.Entity != "" && event.Entity != filter.Entity {
return false
}
if filter.Operation != "" && event.Operation != filter.Operation {
return false
}
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
return false
}
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
return false
}
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
return false
}
return true
}

View File

@@ -0,0 +1,541 @@
package eventbroker
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// RedisProvider implements Provider interface using Redis Streams
// Features:
// - Persistent event storage using Redis Streams
// - Cross-instance pub/sub using consumer groups
// - Pattern-based subscription routing
// - Automatic stream trimming to prevent unbounded growth
type RedisProvider struct {
client *redis.Client
streamName string
consumerGroup string
consumerName string
instanceID string
maxLen int64
// Subscriptions
mu sync.RWMutex
subscribers map[string]*redisSubscription
// Statistics
stats RedisProviderStats
// Lifecycle
stopListeners chan struct{}
wg sync.WaitGroup
isRunning atomic.Bool
}
// RedisProviderStats contains statistics for the Redis provider
type RedisProviderStats struct {
TotalEvents atomic.Int64
EventsPublished atomic.Int64
EventsConsumed atomic.Int64
ActiveSubscribers atomic.Int32
ConsumerErrors atomic.Int64
}
// redisSubscription represents a single subscription
type redisSubscription struct {
pattern string
ch chan *Event
ctx context.Context
cancel context.CancelFunc
}
// RedisProviderConfig configures the Redis provider
type RedisProviderConfig struct {
Host string
Port int
Password string
DB int
StreamName string
ConsumerGroup string
ConsumerName string
InstanceID string
MaxLen int64 // Maximum stream length (0 = unlimited)
}
// NewRedisProvider creates a new Redis event provider
func NewRedisProvider(cfg RedisProviderConfig) (*RedisProvider, error) {
// Apply defaults
if cfg.Host == "" {
cfg.Host = "localhost"
}
if cfg.Port == 0 {
cfg.Port = 6379
}
if cfg.StreamName == "" {
cfg.StreamName = "resolvespec:events"
}
if cfg.ConsumerGroup == "" {
cfg.ConsumerGroup = "resolvespec-workers"
}
if cfg.ConsumerName == "" {
cfg.ConsumerName = cfg.InstanceID
}
if cfg.MaxLen == 0 {
cfg.MaxLen = 10000 // Default max stream length
}
// Create Redis client
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Password: cfg.Password,
DB: cfg.DB,
PoolSize: 10,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
rp := &RedisProvider{
client: client,
streamName: cfg.StreamName,
consumerGroup: cfg.ConsumerGroup,
consumerName: cfg.ConsumerName,
instanceID: cfg.InstanceID,
maxLen: cfg.MaxLen,
subscribers: make(map[string]*redisSubscription),
stopListeners: make(chan struct{}),
}
rp.isRunning.Store(true)
// Create consumer group if it doesn't exist
if err := rp.ensureConsumerGroup(ctx); err != nil {
logger.Warn("Failed to create consumer group: %v (may already exist)", err)
}
logger.Info("Redis provider initialized (stream: %s, consumer_group: %s, consumer: %s)",
cfg.StreamName, cfg.ConsumerGroup, cfg.ConsumerName)
return rp, nil
}
// Store stores an event
func (rp *RedisProvider) Store(ctx context.Context, event *Event) error {
// Marshal event to JSON
data, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("failed to marshal event: %w", err)
}
// Store in Redis Stream
args := &redis.XAddArgs{
Stream: rp.streamName,
MaxLen: rp.maxLen,
Approx: true, // Use approximate trimming for better performance
Values: map[string]interface{}{
"event": data,
"id": event.ID,
"type": event.Type,
"source": string(event.Source),
"status": string(event.Status),
"instance_id": event.InstanceID,
},
}
if _, err := rp.client.XAdd(ctx, args).Result(); err != nil {
return fmt.Errorf("failed to add event to stream: %w", err)
}
rp.stats.TotalEvents.Add(1)
return nil
}
// Get retrieves an event by ID
// Note: This scans the stream which can be slow for large streams
// Consider using a separate hash for fast lookups if needed
func (rp *RedisProvider) Get(ctx context.Context, id string) (*Event, error) {
// Scan stream for event with matching ID
args := &redis.XReadArgs{
Streams: []string{rp.streamName, "0"},
Count: 1000, // Read in batches
}
for {
streams, err := rp.client.XRead(ctx, args).Result()
if err == redis.Nil {
return nil, fmt.Errorf("event not found: %s", id)
}
if err != nil {
return nil, fmt.Errorf("failed to read stream: %w", err)
}
if len(streams) == 0 {
return nil, fmt.Errorf("event not found: %s", id)
}
for _, stream := range streams {
for _, message := range stream.Messages {
// Check if this is the event we're looking for
if eventID, ok := message.Values["id"].(string); ok && eventID == id {
// Parse event
if eventData, ok := message.Values["event"].(string); ok {
var event Event
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
return &event, nil
}
}
}
// If we've read messages, update start position for next iteration
if len(stream.Messages) > 0 {
args.Streams[1] = stream.Messages[len(stream.Messages)-1].ID
} else {
// No more messages
return nil, fmt.Errorf("event not found: %s", id)
}
}
}
}
// List lists events with optional filters
// Note: This scans the entire stream which can be slow
// Consider using time-based or ID-based ranges for better performance
func (rp *RedisProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
var results []*Event
// Read from stream
args := &redis.XReadArgs{
Streams: []string{rp.streamName, "0"},
Count: 1000,
}
for {
streams, err := rp.client.XRead(ctx, args).Result()
if err == redis.Nil {
break
}
if err != nil {
return nil, fmt.Errorf("failed to read stream: %w", err)
}
if len(streams) == 0 {
break
}
for _, stream := range streams {
for _, message := range stream.Messages {
if eventData, ok := message.Values["event"].(string); ok {
var event Event
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
logger.Warn("Failed to unmarshal event: %v", err)
continue
}
if rp.matchesFilter(&event, filter) {
results = append(results, &event)
}
}
}
// Update start position for next iteration
if len(stream.Messages) > 0 {
args.Streams[1] = stream.Messages[len(stream.Messages)-1].ID
} else {
// No more messages
goto done
}
}
}
done:
// Apply limit and offset
if filter != nil {
if filter.Offset > 0 && filter.Offset < len(results) {
results = results[filter.Offset:]
}
if filter.Limit > 0 && filter.Limit < len(results) {
results = results[:filter.Limit]
}
}
return results, nil
}
// UpdateStatus updates the status of an event
// Note: Redis Streams are append-only, so we need to store status updates separately
// This uses a separate hash for status tracking
func (rp *RedisProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
statusKey := fmt.Sprintf("%s:status:%s", rp.streamName, id)
fields := map[string]interface{}{
"status": string(status),
"updated_at": time.Now().Format(time.RFC3339),
}
if errorMsg != "" {
fields["error"] = errorMsg
}
if err := rp.client.HSet(ctx, statusKey, fields).Err(); err != nil {
return fmt.Errorf("failed to update status: %w", err)
}
// Set TTL on status key to prevent unbounded growth
rp.client.Expire(ctx, statusKey, 7*24*time.Hour) // 7 days
return nil
}
// Delete deletes an event by ID
// Note: Redis Streams don't support deletion by field value
// This marks the event as deleted in a separate set
func (rp *RedisProvider) Delete(ctx context.Context, id string) error {
deletedKey := fmt.Sprintf("%s:deleted", rp.streamName)
if err := rp.client.SAdd(ctx, deletedKey, id).Err(); err != nil {
return fmt.Errorf("failed to mark event as deleted: %w", err)
}
// Also delete the status hash if it exists
statusKey := fmt.Sprintf("%s:status:%s", rp.streamName, id)
rp.client.Del(ctx, statusKey)
return nil
}
// Stream returns a channel of events for real-time consumption
// Uses Redis Streams consumer group for distributed processing
func (rp *RedisProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
ch := make(chan *Event, 100)
subCtx, cancel := context.WithCancel(ctx)
sub := &redisSubscription{
pattern: pattern,
ch: ch,
ctx: subCtx,
cancel: cancel,
}
rp.mu.Lock()
rp.subscribers[pattern] = sub
rp.stats.ActiveSubscribers.Add(1)
rp.mu.Unlock()
// Start consumer goroutine
rp.wg.Add(1)
go rp.consumeStream(sub)
return ch, nil
}
// Publish publishes an event to all subscribers (cross-instance)
func (rp *RedisProvider) Publish(ctx context.Context, event *Event) error {
// Store the event first
if err := rp.Store(ctx, event); err != nil {
return err
}
rp.stats.EventsPublished.Add(1)
return nil
}
// Close closes the provider and releases resources
func (rp *RedisProvider) Close() error {
if !rp.isRunning.Load() {
return nil
}
rp.isRunning.Store(false)
// Cancel all subscriptions
rp.mu.Lock()
for _, sub := range rp.subscribers {
sub.cancel()
}
rp.mu.Unlock()
// Stop listeners
close(rp.stopListeners)
// Wait for goroutines
rp.wg.Wait()
// Close Redis client
if err := rp.client.Close(); err != nil {
return fmt.Errorf("failed to close Redis client: %w", err)
}
logger.Info("Redis provider closed")
return nil
}
// Stats returns provider statistics
func (rp *RedisProvider) Stats(ctx context.Context) (*ProviderStats, error) {
// Get stream info
streamInfo, err := rp.client.XInfoStream(ctx, rp.streamName).Result()
if err != nil && err != redis.Nil {
logger.Warn("Failed to get stream info: %v", err)
}
stats := &ProviderStats{
ProviderType: "redis",
TotalEvents: rp.stats.TotalEvents.Load(),
EventsPublished: rp.stats.EventsPublished.Load(),
EventsConsumed: rp.stats.EventsConsumed.Load(),
ActiveSubscribers: int(rp.stats.ActiveSubscribers.Load()),
ProviderSpecific: map[string]interface{}{
"stream_name": rp.streamName,
"consumer_group": rp.consumerGroup,
"consumer_name": rp.consumerName,
"max_len": rp.maxLen,
"consumer_errors": rp.stats.ConsumerErrors.Load(),
},
}
if streamInfo != nil {
stats.ProviderSpecific["stream_length"] = streamInfo.Length
stats.ProviderSpecific["first_entry_id"] = streamInfo.FirstEntry.ID
stats.ProviderSpecific["last_entry_id"] = streamInfo.LastEntry.ID
}
return stats, nil
}
// consumeStream consumes events from the Redis Stream for a subscription
func (rp *RedisProvider) consumeStream(sub *redisSubscription) {
defer rp.wg.Done()
defer close(sub.ch)
defer func() {
rp.mu.Lock()
delete(rp.subscribers, sub.pattern)
rp.stats.ActiveSubscribers.Add(-1)
rp.mu.Unlock()
}()
logger.Debug("Starting stream consumer for pattern: %s", sub.pattern)
// Use consumer group for distributed processing
for {
select {
case <-sub.ctx.Done():
logger.Debug("Stream consumer stopped for pattern: %s", sub.pattern)
return
default:
// Read from consumer group
args := &redis.XReadGroupArgs{
Group: rp.consumerGroup,
Consumer: rp.consumerName,
Streams: []string{rp.streamName, ">"},
Count: 10,
Block: 1 * time.Second,
}
streams, err := rp.client.XReadGroup(sub.ctx, args).Result()
if err == redis.Nil {
continue
}
if err != nil {
if sub.ctx.Err() != nil {
return
}
rp.stats.ConsumerErrors.Add(1)
logger.Warn("Failed to read from consumer group: %v", err)
time.Sleep(1 * time.Second)
continue
}
for _, stream := range streams {
for _, message := range stream.Messages {
if eventData, ok := message.Values["event"].(string); ok {
var event Event
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
logger.Warn("Failed to unmarshal event: %v", err)
// Acknowledge message anyway to prevent redelivery
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
continue
}
// Check if event matches pattern
if matchPattern(sub.pattern, event.Type) {
select {
case sub.ch <- &event:
rp.stats.EventsConsumed.Add(1)
// Acknowledge message
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
case <-sub.ctx.Done():
return
}
} else {
// Acknowledge message even if it doesn't match pattern
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
}
}
}
}
}
}
}
// ensureConsumerGroup creates the consumer group if it doesn't exist
func (rp *RedisProvider) ensureConsumerGroup(ctx context.Context) error {
// Try to create the stream and consumer group
// MKSTREAM creates the stream if it doesn't exist
err := rp.client.XGroupCreateMkStream(ctx, rp.streamName, rp.consumerGroup, "0").Err()
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
return err
}
return nil
}
// matchesFilter checks if an event matches the filter criteria
func (rp *RedisProvider) matchesFilter(event *Event, filter *EventFilter) bool {
if filter == nil {
return true
}
if filter.Source != nil && event.Source != *filter.Source {
return false
}
if filter.Status != nil && event.Status != *filter.Status {
return false
}
if filter.UserID != nil && event.UserID != *filter.UserID {
return false
}
if filter.Schema != "" && event.Schema != filter.Schema {
return false
}
if filter.Entity != "" && event.Entity != filter.Entity {
return false
}
if filter.Operation != "" && event.Operation != filter.Operation {
return false
}
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
return false
}
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
return false
}
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
return false
}
return true
}

View File

@@ -84,7 +84,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
// Create local copy to avoid modifying the captured parameter across requests
sqlquery := sqlquery
ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second)
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
defer cancel()
var dbobjlist []map[string]interface{}
@@ -423,7 +423,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
// Create local copy to avoid modifying the captured parameter across requests
sqlquery := sqlquery
ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second)
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
defer cancel()
propQry := make(map[string]string)
@@ -522,10 +522,17 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
if strings.HasPrefix(kLower, "x-fieldfilter-") {
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
if strings.Contains(strings.ToLower(sqlquery), colname) {
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
switch val {
case "0":
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
case "":
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
default:
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
}
}
@@ -662,7 +669,10 @@ func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables ma
for k, v := range pathVars {
kword := fmt.Sprintf("[%s]", k)
if strings.Contains(sqlquery, kword) {
sqlquery = strings.ReplaceAll(sqlquery, kword, fmt.Sprintf("%v", v))
// Sanitize the value before replacing
vStr := fmt.Sprintf("%v", v)
sanitized := ValidSQL(vStr, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
variables[k] = v
@@ -690,7 +700,9 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
// Replace in SQL if placeholder exists
if strings.Contains(sqlquery, kword) && len(val) > 0 {
if strings.HasPrefix(parmk, "p-") {
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
// Sanitize the parameter value before replacing
sanitized := ValidSQL(val, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
}
@@ -702,15 +714,36 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
// Apply filters if allowed
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
if len(parmv) > 1 {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ",")))
// Sanitize each value in the IN clause with appropriate quoting
sanitizedValues := make([]string, len(parmv))
for i, v := range parmv {
if IsNumeric(v) {
// Numeric values don't need quotes
sanitizedValues[i] = ValidSQL(v, "colvalue")
} else {
// String values need quotes
sanitized := ValidSQL(v, "colvalue")
sanitizedValues[i] = fmt.Sprintf("'%s'", sanitized)
}
}
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ",")))
} else {
if strings.Contains(val, "match=") {
colval := strings.ReplaceAll(val, "match=", "")
// Escape single quotes and backslashes for LIKE patterns
// But don't escape wildcards % and _ which are intentional
colval = strings.ReplaceAll(colval, "\\", "\\\\")
colval = strings.ReplaceAll(colval, "'", "''")
if colval != "*" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
}
} else if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
// For empty/zero values, treat as literal 0 or empty string with quotes
if val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = 0 OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
}
} else {
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
@@ -743,16 +776,25 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
kword := fmt.Sprintf("[%s]", k)
if strings.Contains(sqlquery, kword) {
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
// Sanitize the header value before replacing
sanitized := ValidSQL(val, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
// Handle special headers
if strings.Contains(k, "x-fieldfilter-") {
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
switch val {
case "0":
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
case "":
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
default:
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
}
@@ -782,12 +824,15 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
if strings.Contains(sqlquery, "[p_meta_default]") {
data, _ := json.Marshal(metainfo)
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data)))
dataStr := strings.ReplaceAll(string(data), "$META$", "/*META*/")
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("$META$%s$META$::jsonb", dataStr))
}
if strings.Contains(sqlquery, "[json_variables]") {
data, _ := json.Marshal(variables)
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data)))
dataStr := strings.ReplaceAll(string(data), "$VAR$", "/*VAR*/")
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("$VAR$%s$VAR$::jsonb", dataStr))
}
if strings.Contains(sqlquery, "[rid_user]") {
@@ -795,7 +840,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
}
if strings.Contains(sqlquery, "[user]") {
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName))
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("$USR$%s$USR$", strings.ReplaceAll(userCtx.UserName, "$USR$", "/*USR*/")))
}
if strings.Contains(sqlquery, "[rid_session]") {
@@ -806,7 +851,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
}
if strings.Contains(sqlquery, "[method]") {
sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method)
sqlquery = strings.ReplaceAll(sqlquery, "[method]", fmt.Sprintf("$M$%s$M$", strings.ReplaceAll(r.Method, "$M$", "/*M*/")))
}
if strings.Contains(sqlquery, "[post_body]") {
@@ -819,7 +864,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
}
}
}
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr))
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("$PBODY$%s$PBODY$", strings.ReplaceAll(bodystr, "$PBODY$", "/*PBODY*/")))
}
return sqlquery
@@ -859,19 +904,23 @@ func ValidSQL(input, mode string) string {
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
return reg.ReplaceAllString(input, "")
case "colvalue":
// For column values, escape single quotes
return strings.ReplaceAll(input, "'", "''")
// For column values, escape single quotes and backslashes
// Note: Backslashes must be escaped first, then single quotes
result := strings.ReplaceAll(input, "\\", "\\\\")
result = strings.ReplaceAll(result, "'", "''")
return result
case "select":
// For SELECT clauses, be more permissive but still safe
// Remove semicolons and common SQL injection patterns
dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "}
result := input
for _, d := range dangerous {
result = strings.ReplaceAll(result, d, "")
result = strings.ReplaceAll(result, strings.ToLower(d), "")
result = strings.ReplaceAll(result, strings.ToUpper(d), "")
// Remove semicolons and common SQL injection patterns (case-insensitive)
dangerous := []string{
";", "--", "/\\*", "\\*/", "xp_", "sp_",
"drop ", "delete ", "truncate ", "update ", "insert ",
"exec ", "execute ", "union ", "declare ", "alter ", "create ",
}
return result
// Build a single regex pattern with all dangerous keywords
pattern := "(?i)(" + strings.Join(dangerous, "|") + ")"
re := regexp.MustCompile(pattern)
return re.ReplaceAllString(input, "")
default:
return input
}

View File

@@ -75,6 +75,28 @@ func CloseErrorTracking() error {
return nil
}
// extractContext attempts to find a context.Context in the given arguments.
// It returns the found context (or context.Background() if not found) and
// the remaining arguments without the context.
func extractContext(args ...interface{}) (ctx context.Context, filteredArgs []interface{}) {
ctx = context.Background()
var newArgs []interface{}
found := false
for _, arg := range args {
if c, ok := arg.(context.Context); ok {
if !found {
ctx = c
found = true
}
// Ignore any additional context.Context arguments after the first one.
continue
}
newArgs = append(newArgs, arg)
}
return ctx, newArgs
}
func Info(template string, args ...interface{}) {
if Logger == nil {
log.Printf(template, args...)
@@ -84,7 +106,8 @@ func Info(template string, args ...interface{}) {
}
func Warn(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
ctx, remainingArgs := extractContext(args...)
message := fmt.Sprintf(template, remainingArgs...)
if Logger == nil {
log.Printf("%s", message)
} else {
@@ -93,14 +116,15 @@ func Warn(template string, args ...interface{}) {
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{
"process_id": os.Getpid(),
})
}
}
func Error(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
ctx, remainingArgs := extractContext(args...)
message := fmt.Sprintf(template, remainingArgs...)
if Logger == nil {
log.Printf("%s", message)
} else {
@@ -109,7 +133,7 @@ func Error(template string, args ...interface{}) {
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{
"process_id": os.Getpid(),
})
}
@@ -124,34 +148,41 @@ func Debug(template string, args ...interface{}) {
}
// CatchPanic - Handle panic
func CatchPanicCallback(location string, cb func(err any)) {
if err := recover(); err != nil {
callstack := debug.Stack()
// Returns a function that should be deferred to catch panics
// Example usage: defer CatchPanicCallback("MyFunction", func(err any) { /* cleanup */ })()
func CatchPanicCallback(location string, cb func(err any), args ...interface{}) func() {
ctx, _ := extractContext(args...)
return func() {
if err := recover(); err != nil {
callstack := debug.Stack()
if Logger != nil {
Error("Panic in %s : %v", location, err)
} else {
fmt.Printf("%s:PANIC->%+v", location, err)
debug.PrintStack()
}
if Logger != nil {
Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly
} else {
fmt.Printf("%s:PANIC->%+v", location, err)
debug.PrintStack()
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
"location": location,
"process_id": os.Getpid(),
})
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{
"location": location,
"process_id": os.Getpid(),
})
}
if cb != nil {
cb(err)
if cb != nil {
cb(err)
}
}
}
}
// CatchPanic - Handle panic
func CatchPanic(location string) {
CatchPanicCallback(location, nil)
// Returns a function that should be deferred to catch panics
// Example usage: defer CatchPanic("MyFunction")()
func CatchPanic(location string, args ...interface{}) func() {
return CatchPanicCallback(location, nil, args...)
}
// HandlePanic logs a panic and returns it as an error
@@ -163,13 +194,14 @@ func CatchPanic(location string) {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
func HandlePanic(methodName string, r any, args ...interface{}) error {
ctx, _ := extractContext(args...)
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{
"method": methodName,
"process_id": os.Getpid(),
})

View File

@@ -7,8 +7,8 @@ A pluggable metrics collection system with Prometheus implementation.
```go
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
// Initialize Prometheus provider
provider := metrics.NewPrometheusProvider()
// Initialize Prometheus provider with default config
provider := metrics.NewPrometheusProvider(nil)
metrics.SetProvider(provider)
// Apply middleware to your router
@@ -18,6 +18,59 @@ router.Use(provider.Middleware)
http.Handle("/metrics", provider.Handler())
```
## Configuration
You can customize the metrics provider using a configuration struct:
```go
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
// Create custom configuration
config := &metrics.Config{
Enabled: true,
Provider: "prometheus",
Namespace: "myapp", // Prefix all metrics with "myapp_"
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5},
DBQueryBuckets: []float64{0.001, 0.01, 0.05, 0.1, 0.5, 1},
}
// Initialize with custom config
provider := metrics.NewPrometheusProvider(config)
metrics.SetProvider(provider)
```
### Configuration Options
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `Enabled` | `bool` | `true` | Enable/disable metrics collection |
| `Provider` | `string` | `"prometheus"` | Metrics provider type |
| `Namespace` | `string` | `""` | Prefix for all metric names |
| `HTTPRequestBuckets` | `[]float64` | See below | Histogram buckets for HTTP duration (seconds) |
| `DBQueryBuckets` | `[]float64` | See below | Histogram buckets for DB query duration (seconds) |
**Default HTTP Request Buckets:** `[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]`
**Default DB Query Buckets:** `[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]`
### Pushgateway Configuration (Optional)
For batch jobs, cron tasks, or short-lived processes, you can push metrics to Prometheus Pushgateway:
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `PushgatewayURL` | `string` | `""` | URL of Pushgateway (e.g., "http://pushgateway:9091") |
| `PushgatewayJobName` | `string` | `"resolvespec"` | Job name for pushed metrics |
| `PushgatewayInterval` | `int` | `0` | Auto-push interval in seconds (0 = disabled) |
```go
config := &metrics.Config{
PushgatewayURL: "http://pushgateway:9091",
PushgatewayJobName: "batch-job",
PushgatewayInterval: 30, // Push every 30 seconds
}
```
## Provider Interface
The package uses a provider interface, allowing you to plug in different metric systems:
@@ -87,6 +140,13 @@ When using `PrometheusProvider`, the following metrics are available:
| `cache_hits_total` | Counter | provider | Total cache hits |
| `cache_misses_total` | Counter | provider | Total cache misses |
| `cache_size_items` | Gauge | provider | Current cache size |
| `events_published_total` | Counter | source, event_type | Total events published |
| `events_processed_total` | Counter | source, event_type, status | Total events processed |
| `event_processing_duration_seconds` | Histogram | source, event_type | Event processing duration |
| `event_queue_size` | Gauge | - | Current event queue size |
| `panics_total` | Counter | method | Total panics recovered |
**Note:** If a custom `Namespace` is configured, all metric names will be prefixed with `{namespace}_`.
## Prometheus Queries
@@ -146,8 +206,126 @@ func (c *CustomProvider) Handler() http.Handler {
metrics.SetProvider(&CustomProvider{})
```
## Pushgateway Usage
### Automatic Push (Batch Jobs)
For jobs that run periodically, use automatic pushing:
```go
package main
import (
"time"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
func main() {
// Configure with automatic pushing every 30 seconds
config := &metrics.Config{
Enabled: true,
Provider: "prometheus",
Namespace: "batch_job",
PushgatewayURL: "http://pushgateway:9091",
PushgatewayJobName: "data-processor",
PushgatewayInterval: 30, // Push every 30 seconds
}
provider := metrics.NewPrometheusProvider(config)
metrics.SetProvider(provider)
// Ensure cleanup on exit
defer provider.StopAutoPush()
// Your batch job logic here
processBatchData()
}
```
### Manual Push (Short-lived Processes)
For one-time jobs or when you want manual control:
```go
package main
import (
"log"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
func main() {
// Configure without automatic pushing
config := &metrics.Config{
Enabled: true,
Provider: "prometheus",
PushgatewayURL: "http://pushgateway:9091",
PushgatewayJobName: "migration-job",
// PushgatewayInterval: 0 (default - no auto-push)
}
provider := metrics.NewPrometheusProvider(config)
metrics.SetProvider(provider)
// Run your job
err := runMigration()
// Push metrics at the end
if pushErr := provider.Push(); pushErr != nil {
log.Printf("Failed to push metrics: %v", pushErr)
}
if err != nil {
log.Fatal(err)
}
}
```
### Docker Compose with Pushgateway
```yaml
version: '3'
services:
batch-job:
build: .
environment:
PUSHGATEWAY_URL: "http://pushgateway:9091"
pushgateway:
image: prom/pushgateway
ports:
- "9091:9091"
prometheus:
image: prom/prometheus
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
command:
- '--config.file=/etc/prometheus/prometheus.yml'
```
**prometheus.yml for Pushgateway:**
```yaml
global:
scrape_interval: 15s
scrape_configs:
# Scrape the pushgateway
- job_name: 'pushgateway'
honor_labels: true # Important: preserve job labels from pushed metrics
static_configs:
- targets: ['pushgateway:9091']
```
## Complete Example
### Basic Usage
```go
package main
@@ -162,8 +340,8 @@ import (
)
func main() {
// Initialize metrics
provider := metrics.NewPrometheusProvider()
// Initialize metrics with default config
provider := metrics.NewPrometheusProvider(nil)
metrics.SetProvider(provider)
// Create router
@@ -198,6 +376,42 @@ func getUsersHandler(w http.ResponseWriter, r *http.Request) {
}
```
### With Custom Configuration
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/gorilla/mux"
)
func main() {
// Custom metrics configuration
metricsConfig := &metrics.Config{
Enabled: true,
Provider: "prometheus",
Namespace: "myapp",
// Custom buckets optimized for your application
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5, 10},
DBQueryBuckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1},
}
// Initialize with custom config
provider := metrics.NewPrometheusProvider(metricsConfig)
metrics.SetProvider(provider)
router := mux.NewRouter()
router.Use(provider.Middleware)
router.Handle("/metrics", provider.Handler())
log.Fatal(http.ListenAndServe(":8080", router))
}
```
## Docker Compose Example
```yaml
@@ -257,3 +471,8 @@ scrape_configs:
4. **Performance**: Metrics collection is lock-free and highly performant
- Safe for high-throughput applications
- Minimal overhead (<1% in most cases)
5. **Pull vs Push**:
- **Use Pull (default)**: Long-running services, web servers, microservices
- **Use Push (Pushgateway)**: Batch jobs, cron tasks, short-lived processes, serverless functions
- Pull is preferred for most applications as it allows Prometheus to detect if your service is down

64
pkg/metrics/config.go Normal file
View File

@@ -0,0 +1,64 @@
package metrics
// Config holds configuration for the metrics provider
type Config struct {
// Enabled determines whether metrics collection is enabled
Enabled bool `mapstructure:"enabled"`
// Provider specifies which metrics provider to use (prometheus, noop)
Provider string `mapstructure:"provider"`
// Namespace is an optional prefix for all metric names
Namespace string `mapstructure:"namespace"`
// HTTPRequestBuckets defines histogram buckets for HTTP request duration (in seconds)
// Default: [0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]
HTTPRequestBuckets []float64 `mapstructure:"http_request_buckets"`
// DBQueryBuckets defines histogram buckets for database query duration (in seconds)
// Default: [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]
DBQueryBuckets []float64 `mapstructure:"db_query_buckets"`
// PushgatewayURL is the URL of the Prometheus Pushgateway (optional)
// If set, metrics will be pushed to this gateway instead of only being scraped
// Example: "http://pushgateway:9091"
PushgatewayURL string `mapstructure:"pushgateway_url"`
// PushgatewayJobName is the job name to use when pushing metrics to Pushgateway
// Default: "resolvespec"
PushgatewayJobName string `mapstructure:"pushgateway_job_name"`
// PushgatewayInterval is the interval at which to push metrics to Pushgateway
// Only used if PushgatewayURL is set. If 0, automatic pushing is disabled.
// Default: 0 (no automatic pushing)
PushgatewayInterval int `mapstructure:"pushgateway_interval"`
}
// DefaultConfig returns a Config with sensible defaults
func DefaultConfig() *Config {
return &Config{
Enabled: true,
Provider: "prometheus",
// HTTP requests typically take longer than DB queries
HTTPRequestBuckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
// DB queries are usually faster
DBQueryBuckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5},
}
}
// ApplyDefaults fills in any missing values with defaults
func (c *Config) ApplyDefaults() {
if c.Provider == "" {
c.Provider = "prometheus"
}
if len(c.HTTPRequestBuckets) == 0 {
c.HTTPRequestBuckets = []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
}
if len(c.DBQueryBuckets) == 0 {
c.DBQueryBuckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5}
}
// Set default job name if pushgateway is configured but job name is empty
if c.PushgatewayURL != "" && c.PushgatewayJobName == "" {
c.PushgatewayJobName = "resolvespec"
}
}

View File

@@ -0,0 +1,64 @@
package metrics_test
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
// ExampleNewPrometheusProvider_default demonstrates using default configuration
func ExampleNewPrometheusProvider_default() {
// Initialize with default configuration
provider := metrics.NewPrometheusProvider(nil)
metrics.SetProvider(provider)
fmt.Println("Provider initialized with defaults")
// Output: Provider initialized with defaults
}
// ExampleNewPrometheusProvider_custom demonstrates using custom configuration
func ExampleNewPrometheusProvider_custom() {
// Create custom configuration
config := &metrics.Config{
Enabled: true,
Provider: "prometheus",
Namespace: "myapp",
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5},
DBQueryBuckets: []float64{0.001, 0.01, 0.05, 0.1, 0.5, 1},
}
// Initialize with custom configuration
provider := metrics.NewPrometheusProvider(config)
metrics.SetProvider(provider)
fmt.Println("Provider initialized with custom config")
// Output: Provider initialized with custom config
}
// ExampleDefaultConfig demonstrates getting default configuration
func ExampleDefaultConfig() {
config := metrics.DefaultConfig()
fmt.Printf("Default provider: %s\n", config.Provider)
fmt.Printf("Default enabled: %v\n", config.Enabled)
// Output:
// Default provider: prometheus
// Default enabled: true
}
// ExampleConfig_ApplyDefaults demonstrates applying defaults to partial config
func ExampleConfig_ApplyDefaults() {
// Create partial configuration
config := &metrics.Config{
Namespace: "myapp",
// Other fields will be filled with defaults
}
// Apply defaults
config.ApplyDefaults()
fmt.Printf("Provider: %s\n", config.Provider)
fmt.Printf("Namespace: %s\n", config.Namespace)
// Output:
// Provider: prometheus
// Namespace: myapp
}

View File

@@ -39,6 +39,9 @@ type Provider interface {
// UpdateEventQueueSize updates the event queue size metric
UpdateEventQueueSize(size int64)
// RecordPanic records a panic event
RecordPanic(methodName string)
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
Handler() http.Handler
}
@@ -75,6 +78,7 @@ func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
}
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
func (n *NoOpProvider) RecordPanic(methodName string) {}
func (n *NoOpProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)

View File

@@ -8,6 +8,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/prometheus/client_golang/prometheus/push"
)
// PrometheusProvider implements the Provider interface using Prometheus
@@ -20,22 +21,51 @@ type PrometheusProvider struct {
cacheHits *prometheus.CounterVec
cacheMisses *prometheus.CounterVec
cacheSize *prometheus.GaugeVec
eventPublished *prometheus.CounterVec
eventProcessed *prometheus.CounterVec
eventDuration *prometheus.HistogramVec
eventQueueSize prometheus.Gauge
panicsTotal *prometheus.CounterVec
// Pushgateway fields (optional)
pushgatewayURL string
pushgatewayJobName string
pusher *push.Pusher
pushTicker *time.Ticker
pushStop chan bool
}
// NewPrometheusProvider creates a new Prometheus metrics provider
func NewPrometheusProvider() *PrometheusProvider {
return &PrometheusProvider{
// If cfg is nil, default configuration will be used
func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
// Use default config if none provided
if cfg == nil {
cfg = DefaultConfig()
} else {
// Apply defaults for any missing values
cfg.ApplyDefaults()
}
// Helper to add namespace prefix if configured
metricName := func(name string) string {
if cfg.Namespace != "" {
return cfg.Namespace + "_" + name
}
return name
}
p := &PrometheusProvider{
requestDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Name: metricName("http_request_duration_seconds"),
Help: "HTTP request duration in seconds",
Buckets: prometheus.DefBuckets,
Buckets: cfg.HTTPRequestBuckets,
},
[]string{"method", "path", "status"},
),
requestTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Name: metricName("http_requests_total"),
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
@@ -43,47 +73,100 @@ func NewPrometheusProvider() *PrometheusProvider {
requestsInFlight: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "http_requests_in_flight",
Name: metricName("http_requests_in_flight"),
Help: "Current number of HTTP requests being processed",
},
),
dbQueryDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "db_query_duration_seconds",
Name: metricName("db_query_duration_seconds"),
Help: "Database query duration in seconds",
Buckets: prometheus.DefBuckets,
Buckets: cfg.DBQueryBuckets,
},
[]string{"operation", "table"},
),
dbQueryTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "db_queries_total",
Name: metricName("db_queries_total"),
Help: "Total number of database queries",
},
[]string{"operation", "table", "status"},
),
cacheHits: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_hits_total",
Name: metricName("cache_hits_total"),
Help: "Total number of cache hits",
},
[]string{"provider"},
),
cacheMisses: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_misses_total",
Name: metricName("cache_misses_total"),
Help: "Total number of cache misses",
},
[]string{"provider"},
),
cacheSize: promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "cache_size_items",
Name: metricName("cache_size_items"),
Help: "Number of items in cache",
},
[]string{"provider"},
),
eventPublished: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: metricName("events_published_total"),
Help: "Total number of events published",
},
[]string{"source", "event_type"},
),
eventProcessed: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: metricName("events_processed_total"),
Help: "Total number of events processed",
},
[]string{"source", "event_type", "status"},
),
eventDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: metricName("event_processing_duration_seconds"),
Help: "Event processing duration in seconds",
Buckets: cfg.DBQueryBuckets, // Events are typically fast like DB queries
},
[]string{"source", "event_type"},
),
eventQueueSize: promauto.NewGauge(
prometheus.GaugeOpts{
Name: metricName("event_queue_size"),
Help: "Current number of events in queue",
},
),
panicsTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: metricName("panics_total"),
Help: "Total number of panics",
},
[]string{"method"},
),
pushgatewayURL: cfg.PushgatewayURL,
pushgatewayJobName: cfg.PushgatewayJobName,
}
// Initialize pushgateway if configured
if cfg.PushgatewayURL != "" {
p.pusher = push.New(cfg.PushgatewayURL, cfg.PushgatewayJobName).
Gatherer(prometheus.DefaultGatherer)
// Start automatic pushing if interval is configured
if cfg.PushgatewayInterval > 0 {
p.pushStop = make(chan bool)
p.pushTicker = time.NewTicker(time.Duration(cfg.PushgatewayInterval) * time.Second)
go p.startAutoPush()
}
}
return p
}
// ResponseWriter wraps http.ResponseWriter to capture status code
@@ -145,6 +228,27 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
p.cacheSize.WithLabelValues(provider).Set(float64(size))
}
// RecordEventPublished implements Provider interface
func (p *PrometheusProvider) RecordEventPublished(source, eventType string) {
p.eventPublished.WithLabelValues(source, eventType).Inc()
}
// RecordEventProcessed implements Provider interface
func (p *PrometheusProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
p.eventProcessed.WithLabelValues(source, eventType, status).Inc()
p.eventDuration.WithLabelValues(source, eventType).Observe(duration.Seconds())
}
// UpdateEventQueueSize implements Provider interface
func (p *PrometheusProvider) UpdateEventQueueSize(size int64) {
p.eventQueueSize.Set(float64(size))
}
// RecordPanic implements the Provider interface
func (p *PrometheusProvider) RecordPanic(methodName string) {
p.panicsTotal.WithLabelValues(methodName).Inc()
}
// Handler implements Provider interface
func (p *PrometheusProvider) Handler() http.Handler {
return promhttp.Handler()
@@ -172,3 +276,37 @@ func (p *PrometheusProvider) Middleware(next http.Handler) http.Handler {
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
})
}
// Push manually pushes metrics to the configured Pushgateway
// Returns an error if pushing fails or if Pushgateway is not configured
func (p *PrometheusProvider) Push() error {
if p.pusher == nil {
return nil // Pushgateway not configured, silently skip
}
return p.pusher.Push()
}
// startAutoPush runs in a goroutine and periodically pushes metrics to Pushgateway
func (p *PrometheusProvider) startAutoPush() {
for {
select {
case <-p.pushTicker.C:
if err := p.Push(); err != nil {
// Log error but continue pushing
// Note: In production, you might want to use a proper logger
_ = err
}
case <-p.pushStop:
p.pushTicker.Stop()
return
}
}
}
// StopAutoPush stops the automatic push goroutine
// This should be called when shutting down the application
func (p *PrometheusProvider) StopAutoPush() {
if p.pushStop != nil {
close(p.pushStop)
}
}

33
pkg/middleware/panic.go Normal file
View File

@@ -0,0 +1,33 @@
package middleware
import (
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
const panicMiddlewareMethodName = "PanicMiddleware"
// PanicRecovery is a middleware that recovers from panics, logs the error,
// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error.
func PanicRecovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rcv := recover(); rcv != nil {
// Record the panic metric
metrics.GetProvider().RecordPanic(panicMiddlewareMethodName)
// Log the panic and send to error tracker
// We pass the request context so the error tracker can potentially
// link the panic to the request trace.
ctx := r.Context()
err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx)
// Respond with a 500 error
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,86 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/stretchr/testify/assert"
)
// mockMetricsProvider is a mock for the metrics provider to check if methods are called.
type mockMetricsProvider struct {
metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods
panicRecorded bool
methodName string
}
func (m *mockMetricsProvider) RecordPanic(methodName string) {
m.panicRecorded = true
m.methodName = methodName
}
func TestPanicRecovery(t *testing.T) {
// Initialize a mock logger to avoid actual logging output during tests
logger.Init(true)
// Setup mock metrics provider
mockProvider := &mockMetricsProvider{}
originalProvider := metrics.GetProvider()
metrics.SetProvider(mockProvider)
defer metrics.SetProvider(originalProvider) // Restore original provider after test
// 1. Test case: A handler that panics
t.Run("recovers from panic and returns 500", func(t *testing.T) {
// Reset mock state for this sub-test
mockProvider.panicRecorded = false
mockProvider.methodName = ""
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("something went terribly wrong")
})
// Create the middleware wrapping the panicking handler
testHandler := PanicRecovery(panicHandler)
// Create a test request and response recorder
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
// Serve the request
testHandler.ServeHTTP(rr, req)
// Assertions
assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500")
assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body")
// Assert that the metric was recorded
assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider")
assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name")
})
// 2. Test case: A handler that does NOT panic
t.Run("does not interfere with a non-panicking handler", func(t *testing.T) {
// Reset mock state for this sub-test
mockProvider.panicRecorded = false
successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
testHandler := PanicRecovery(successHandler)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
testHandler.ServeHTTP(rr, req)
// Assertions
assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200")
assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body")
assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic")
})
}

724
pkg/mqttspec/README.md Normal file
View File

@@ -0,0 +1,724 @@
# MQTTSpec - MQTT-based Database Query Framework
MQTTSpec is an MQTT-based database query framework that enables real-time database operations and subscriptions via MQTT protocol. It mirrors the functionality of WebSocketSpec but uses MQTT as the transport layer, making it ideal for IoT applications, mobile apps with unreliable networks, and distributed systems requiring QoS guarantees.
## Features
- **Dual Broker Support**: Embedded broker (Mochi MQTT) or external broker connection (Paho MQTT)
- **QoS 1 (At-least-once delivery)**: Reliable message delivery for all operations
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
- **Database Agnostic**: GORM and Bun ORM support
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
- **Thread-safe**: Proper concurrency handling throughout
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/mqttspec
```
## Quick Start
### Embedded Broker (Default)
```go
package main
import (
"github.com/bitechdev/ResolveSpec/pkg/mqttspec"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
}
func main() {
// Connect to database
db, _ := gorm.Open(postgres.Open("postgres://..."), &gorm.Config{})
db.AutoMigrate(&User{})
// Create MQTT handler with embedded broker
handler, err := mqttspec.NewHandlerWithGORM(db)
if err != nil {
panic(err)
}
// Register models
handler.Registry().RegisterModel("public.users", &User{})
// Start handler (starts embedded broker on localhost:1883)
if err := handler.Start(); err != nil {
panic(err)
}
// Handler is now listening for MQTT messages
select {} // Keep running
}
```
### External Broker
```go
handler, err := mqttspec.NewHandlerWithGORM(db,
mqttspec.WithExternalBroker(mqttspec.ExternalBrokerConfig{
BrokerURL: "tcp://mqtt.example.com:1883",
ClientID: "mqttspec-server",
Username: "admin",
Password: "secret",
ConnectTimeout: 10 * time.Second,
}),
)
```
### Custom Port (Embedded Broker)
```go
handler, err := mqttspec.NewHandlerWithGORM(db,
mqttspec.WithEmbeddedBroker(mqttspec.BrokerConfig{
Host: "0.0.0.0",
Port: 1884,
}),
)
```
## Topic Structure
MQTTSpec uses a client-based topic hierarchy:
```
spec/{client_id}/request # Client publishes requests
spec/{client_id}/response # Server publishes responses
spec/{client_id}/notify/{sub_id} # Server publishes notifications
```
### Wildcard Subscriptions
- **Server**: `spec/+/request` (receives all client requests)
- **Client**: `spec/{client_id}/response` + `spec/{client_id}/notify/+`
## Message Protocol
MQTTSpec uses the same JSON message structure as WebSocketSpec and ResolveSpec for consistency.
### Request Message
```json
{
"id": "msg-123",
"type": "request",
"operation": "read",
"schema": "public",
"entity": "users",
"options": {
"filters": [
{"column": "status", "operator": "eq", "value": "active"}
],
"sort": [{"column": "created_at", "direction": "desc"}],
"limit": 10
}
}
```
### Response Message
```json
{
"id": "msg-123",
"type": "response",
"success": true,
"data": [
{"id": 1, "name": "John Doe", "email": "john@example.com", "status": "active"},
{"id": 2, "name": "Jane Smith", "email": "jane@example.com", "status": "active"}
],
"metadata": {
"total": 50,
"count": 2
}
}
```
### Notification Message
```json
{
"type": "notification",
"operation": "create",
"subscription_id": "sub-xyz",
"schema": "public",
"entity": "users",
"data": {
"id": 3,
"name": "New User",
"email": "new@example.com",
"status": "active"
}
}
```
## CRUD Operations
### Read (Single Record)
**MQTT Client Publishes to**: `spec/{client_id}/request`
```json
{
"id": "msg-1",
"type": "request",
"operation": "read",
"schema": "public",
"entity": "users",
"data": {"id": 1}
}
```
**Server Publishes Response to**: `spec/{client_id}/response`
```json
{
"id": "msg-1",
"success": true,
"data": {"id": 1, "name": "John Doe", "email": "john@example.com"}
}
```
### Read (Multiple Records with Filtering)
```json
{
"id": "msg-2",
"type": "request",
"operation": "read",
"schema": "public",
"entity": "users",
"options": {
"filters": [
{"column": "status", "operator": "eq", "value": "active"}
],
"sort": [{"column": "name", "direction": "asc"}],
"limit": 20,
"offset": 0
}
}
```
### Create
```json
{
"id": "msg-3",
"type": "request",
"operation": "create",
"schema": "public",
"entity": "users",
"data": {
"name": "Alice Brown",
"email": "alice@example.com",
"status": "active"
}
}
```
### Update
```json
{
"id": "msg-4",
"type": "request",
"operation": "update",
"schema": "public",
"entity": "users",
"data": {
"id": 1,
"status": "inactive"
}
}
```
### Delete
```json
{
"id": "msg-5",
"type": "request",
"operation": "delete",
"schema": "public",
"entity": "users",
"data": {"id": 1}
}
```
## Real-time Subscriptions
### Subscribe to Entity Changes
**Client Publishes to**: `spec/{client_id}/request`
```json
{
"id": "msg-6",
"type": "subscription",
"operation": "subscribe",
"schema": "public",
"entity": "users",
"options": {
"filters": [
{"column": "status", "operator": "eq", "value": "active"}
]
}
}
```
**Server Response** (published to `spec/{client_id}/response`):
```json
{
"id": "msg-6",
"success": true,
"data": {
"subscription_id": "sub-abc123",
"notify_topic": "spec/{client_id}/notify/sub-abc123"
}
}
```
**Client Then Subscribes** to MQTT topic: `spec/{client_id}/notify/sub-abc123`
### Receiving Notifications
When any client creates/updates/deletes a user matching the subscription filters, the subscriber receives:
```json
{
"type": "notification",
"operation": "create",
"subscription_id": "sub-abc123",
"schema": "public",
"entity": "users",
"data": {
"id": 10,
"name": "New User",
"email": "newuser@example.com",
"status": "active"
}
}
```
### Unsubscribe
```json
{
"id": "msg-7",
"type": "subscription",
"operation": "unsubscribe",
"data": {
"subscription_id": "sub-abc123"
}
}
```
## Lifecycle Hooks
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
### Hook Types
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
- `BeforeRead` / `AfterRead` - Read operations
- `BeforeCreate` / `AfterCreate` - Create operations
- `BeforeUpdate` / `AfterUpdate` - Update operations
- `BeforeDelete` / `AfterDelete` - Delete operations
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
### Authentication Example (JWT)
```go
handler.Hooks().Register(mqttspec.BeforeConnect, func(ctx *mqttspec.HookContext) error {
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
// MQTT username contains JWT token
token := client.Username
claims, err := jwt.Validate(token)
if err != nil {
return fmt.Errorf("invalid token: %w", err)
}
// Store user info in client metadata for later use
client.SetMetadata("user_id", claims.UserID)
client.SetMetadata("tenant_id", claims.TenantID)
client.SetMetadata("roles", claims.Roles)
logger.Info("Client authenticated: user_id=%d, tenant=%s", claims.UserID, claims.TenantID)
return nil
})
```
### Multi-tenancy Example
```go
// Auto-inject tenant filter for all read operations
handler.Hooks().Register(mqttspec.BeforeRead, func(ctx *mqttspec.HookContext) error {
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
tenantID, _ := client.GetMetadata("tenant_id")
// Add tenant filter to ensure users only see their own data
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
Column: "tenant_id",
Operator: "eq",
Value: tenantID,
})
return nil
})
// Auto-set tenant_id for all create operations
handler.Hooks().Register(mqttspec.BeforeCreate, func(ctx *mqttspec.HookContext) error {
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
tenantID, _ := client.GetMetadata("tenant_id")
// Inject tenant_id into new records
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
dataMap["tenant_id"] = tenantID
}
return nil
})
```
### Role-based Access Control (RBAC)
```go
handler.Hooks().Register(mqttspec.BeforeDelete, func(ctx *mqttspec.HookContext) error {
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
roles, _ := client.GetMetadata("roles")
roleList := roles.([]string)
hasAdminRole := false
for _, role := range roleList {
if role == "admin" {
hasAdminRole = true
break
}
}
if !hasAdminRole {
return fmt.Errorf("permission denied: delete requires admin role")
}
return nil
})
```
### Audit Logging Example
```go
handler.Hooks().Register(mqttspec.AfterCreate, func(ctx *mqttspec.HookContext) error {
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
userID, _ := client.GetMetadata("user_id")
logger.Info("Audit: user %d created %s.%s record: %+v",
userID, ctx.Schema, ctx.Entity, ctx.Result)
// Could also write to audit log table
return nil
})
```
## Client Examples
### JavaScript (MQTT.js)
```javascript
const mqtt = require('mqtt');
// Connect to MQTT broker
const client = mqtt.connect('mqtt://localhost:1883', {
clientId: 'client-abc123',
username: 'your-jwt-token',
password: '', // JWT in username, password can be empty
});
client.on('connect', () => {
console.log('Connected to MQTT broker');
// Subscribe to responses
client.subscribe('spec/client-abc123/response');
// Read users
const readMsg = {
id: 'msg-1',
type: 'request',
operation: 'read',
schema: 'public',
entity: 'users',
options: {
filters: [
{ column: 'status', operator: 'eq', value: 'active' }
]
}
};
client.publish('spec/client-abc123/request', JSON.stringify(readMsg));
});
client.on('message', (topic, payload) => {
const message = JSON.parse(payload.toString());
console.log('Received:', message);
if (message.type === 'response') {
console.log('Response data:', message.data);
} else if (message.type === 'notification') {
console.log('Notification:', message.operation, message.data);
}
});
```
### Python (paho-mqtt)
```python
import paho.mqtt.client as mqtt
import json
client_id = 'client-python-123'
def on_connect(client, userdata, flags, rc):
print(f"Connected with result code {rc}")
# Subscribe to responses
client.subscribe(f"spec/{client_id}/response")
# Create a user
create_msg = {
'id': 'msg-create-1',
'type': 'request',
'operation': 'create',
'schema': 'public',
'entity': 'users',
'data': {
'name': 'Python User',
'email': 'python@example.com',
'status': 'active'
}
}
client.publish(f"spec/{client_id}/request", json.dumps(create_msg))
def on_message(client, userdata, msg):
message = json.loads(msg.payload.decode())
print(f"Received on {msg.topic}: {message}")
client = mqtt.Client(client_id=client_id)
client.username_pw_set('your-jwt-token', '')
client.on_connect = on_connect
client.on_message = on_message
client.connect('localhost', 1883, 60)
client.loop_forever()
```
### Go (paho.mqtt.golang)
```go
package main
import (
"encoding/json"
"fmt"
"time"
mqtt "github.com/eclipse/paho.mqtt.golang"
)
func main() {
clientID := "client-go-123"
opts := mqtt.NewClientOptions()
opts.AddBroker("tcp://localhost:1883")
opts.SetClientID(clientID)
opts.SetUsername("your-jwt-token")
opts.SetPassword("")
opts.SetDefaultPublishHandler(func(client mqtt.Client, msg mqtt.Message) {
var message map[string]interface{}
json.Unmarshal(msg.Payload(), &message)
fmt.Printf("Received on %s: %+v\n", msg.Topic(), message)
})
opts.OnConnect = func(client mqtt.Client) {
fmt.Println("Connected to MQTT broker")
// Subscribe to responses
client.Subscribe(fmt.Sprintf("spec/%s/response", clientID), 1, nil)
// Read users
readMsg := map[string]interface{}{
"id": "msg-1",
"type": "request",
"operation": "read",
"schema": "public",
"entity": "users",
"options": map[string]interface{}{
"filters": []map[string]interface{}{
{"column": "status", "operator": "eq", "value": "active"},
},
},
}
payload, _ := json.Marshal(readMsg)
client.Publish(fmt.Sprintf("spec/%s/request", clientID), 1, false, payload)
}
client := mqtt.NewClient(opts)
if token := client.Connect(); token.Wait() && token.Error() != nil {
panic(token.Error())
}
// Keep running
select {}
}
```
## Configuration Options
### BrokerConfig (Embedded Broker)
```go
type BrokerConfig struct {
Host string // Default: "localhost"
Port int // Default: 1883
EnableWebSocket bool // Enable WebSocket listener
WSPort int // WebSocket port (default: 1884)
MaxConnections int // Max concurrent connections
KeepAlive time.Duration // MQTT keep-alive interval
EnableAuth bool // Enable authentication
}
```
### ExternalBrokerConfig
```go
type ExternalBrokerConfig struct {
BrokerURL string // MQTT broker URL (tcp://host:port)
ClientID string // MQTT client ID
Username string // MQTT username
Password string // MQTT password
CleanSession bool // Clean session flag
KeepAlive time.Duration // Keep-alive interval
ConnectTimeout time.Duration // Connection timeout
ReconnectDelay time.Duration // Auto-reconnect delay
MaxReconnect int // Max reconnect attempts
TLSConfig *tls.Config // TLS configuration
}
```
### QoS Configuration
```go
handler, err := mqttspec.NewHandlerWithGORM(db,
mqttspec.WithQoS(1, 1, 1), // Request, Response, Notification
)
```
### Topic Prefix
```go
handler, err := mqttspec.NewHandlerWithGORM(db,
mqttspec.WithTopicPrefix("myapp"), // Changes topics to myapp/{client_id}/...
)
```
## Documentation References
- **ResolveSpec JSON Protocol**: See `/pkg/resolvespec/README.md` for the full message protocol specification
- **WebSocketSpec Documentation**: See `/pkg/websocketspec/README.md` for similar WebSocket-based implementation
- **Common Interfaces**: See `/pkg/common/types.go` for database adapter interfaces and query options
- **Model Registry**: See `/pkg/modelregistry/README.md` for model registration and reflection
- **Hooks Reference**: See `/pkg/websocketspec/hooks.go` for hook types (same as MQTTSpec)
- **Subscription Management**: See `/pkg/websocketspec/subscription.go` for subscription filtering
## Comparison: MQTTSpec vs WebSocketSpec
| Feature | MQTTSpec | WebSocketSpec |
|---------|----------|---------------|
| **Transport** | MQTT (pub/sub broker) | WebSocket (direct connection) |
| **Connection Model** | Broker-mediated | Direct client-server |
| **QoS Levels** | QoS 0, 1, 2 support | No built-in QoS |
| **Offline Messages** | Yes (with QoS 1+) | No |
| **Auto-reconnect** | Yes (built into MQTT) | Manual implementation needed |
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
| **Message Protocol** | Same JSON structure | Same JSON structure |
| **Hooks** | Same 12 hooks | Same 12 hooks |
| **CRUD Operations** | Identical | Identical |
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
## Use Cases
### IoT Sensor Data
```go
// Sensors publish data, backend stores and notifies subscribers
handler.Registry().RegisterModel("public.sensor_readings", &SensorReading{})
// Auto-set device_id from client metadata
handler.Hooks().Register(mqttspec.BeforeCreate, func(ctx *mqttspec.HookContext) error {
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
deviceID, _ := client.GetMetadata("device_id")
if ctx.Entity == "sensor_readings" {
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
dataMap["device_id"] = deviceID
dataMap["timestamp"] = time.Now()
}
}
return nil
})
```
### Mobile App with Offline Support
MQTTSpec's QoS 1 ensures messages are delivered even if the client temporarily disconnects.
### Distributed Microservices
Multiple services can subscribe to entity changes and react accordingly.
## Testing
Run unit tests:
```bash
go test -v ./pkg/mqttspec
```
Run with race detection:
```bash
go test -race -v ./pkg/mqttspec
```
## License
This package is part of the ResolveSpec project.
## Contributing
Contributions are welcome! Please ensure:
- All tests pass (`go test ./pkg/mqttspec`)
- No race conditions (`go test -race ./pkg/mqttspec`)
- Documentation is updated
- Examples are provided for new features
## Support
For issues, questions, or feature requests, please open an issue in the ResolveSpec repository.

417
pkg/mqttspec/broker.go Normal file
View File

@@ -0,0 +1,417 @@
package mqttspec
import (
"context"
"fmt"
"sync"
"time"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/listeners"
pahomqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// BrokerInterface abstracts MQTT broker operations
type BrokerInterface interface {
// Start initializes the broker/client connection
Start(ctx context.Context) error
// Stop gracefully shuts down the broker/client
Stop(ctx context.Context) error
// Publish sends a message to a topic
Publish(topic string, qos byte, payload []byte) error
// Subscribe subscribes to a topic pattern with callback
Subscribe(topicFilter string, qos byte, callback MessageCallback) error
// Unsubscribe removes subscription
Unsubscribe(topicFilter string) error
// IsConnected returns connection status
IsConnected() bool
// GetClientManager returns the client manager
GetClientManager() *ClientManager
// SetHandler sets the handler reference (needed for hooks)
SetHandler(handler *Handler)
}
// MessageCallback is called when a message arrives
type MessageCallback func(topic string, payload []byte)
// EmbeddedBroker wraps Mochi MQTT server
type EmbeddedBroker struct {
config BrokerConfig
server *mqtt.Server
clientManager *ClientManager
handler *Handler
subscriptions map[string]MessageCallback
subMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
mu sync.RWMutex
started bool
}
// NewEmbeddedBroker creates a new embedded broker
func NewEmbeddedBroker(config BrokerConfig, clientManager *ClientManager) *EmbeddedBroker {
return &EmbeddedBroker{
config: config,
clientManager: clientManager,
subscriptions: make(map[string]MessageCallback),
}
}
// SetHandler sets the handler reference
func (eb *EmbeddedBroker) SetHandler(handler *Handler) {
eb.mu.Lock()
defer eb.mu.Unlock()
eb.handler = handler
}
// Start starts the embedded MQTT broker
func (eb *EmbeddedBroker) Start(ctx context.Context) error {
eb.mu.Lock()
defer eb.mu.Unlock()
if eb.started {
return fmt.Errorf("broker already started")
}
eb.ctx, eb.cancel = context.WithCancel(ctx)
// Create Mochi MQTT server
eb.server = mqtt.New(&mqtt.Options{
InlineClient: true,
})
// Note: Authentication is handled at the handler level via BeforeConnect hook
// Mochi MQTT auth can be configured via custom hooks if needed
// Add TCP listener
tcp := listeners.NewTCP(
listeners.Config{
ID: "tcp",
Address: fmt.Sprintf("%s:%d", eb.config.Host, eb.config.Port),
},
)
if err := eb.server.AddListener(tcp); err != nil {
return fmt.Errorf("failed to add TCP listener: %w", err)
}
// Add WebSocket listener if enabled
if eb.config.EnableWebSocket {
ws := listeners.NewWebsocket(
listeners.Config{
ID: "ws",
Address: fmt.Sprintf("%s:%d", eb.config.Host, eb.config.WSPort),
},
)
if err := eb.server.AddListener(ws); err != nil {
return fmt.Errorf("failed to add WebSocket listener: %w", err)
}
}
// Start server in goroutine
go func() {
if err := eb.server.Serve(); err != nil {
logger.Error("[MQTTSpec] Embedded broker error: %v", err)
}
}()
// Wait for server to be ready
select {
case <-time.After(2 * time.Second):
// Server should be ready
case <-eb.ctx.Done():
return fmt.Errorf("context cancelled during startup")
}
eb.started = true
logger.Info("[MQTTSpec] Embedded broker started on %s:%d", eb.config.Host, eb.config.Port)
return nil
}
// Stop stops the embedded broker
func (eb *EmbeddedBroker) Stop(ctx context.Context) error {
eb.mu.Lock()
defer eb.mu.Unlock()
if !eb.started {
return nil
}
if eb.cancel != nil {
eb.cancel()
}
if eb.server != nil {
if err := eb.server.Close(); err != nil {
logger.Error("[MQTTSpec] Error closing embedded broker: %v", err)
}
}
eb.started = false
logger.Info("[MQTTSpec] Embedded broker stopped")
return nil
}
// Publish publishes a message to a topic
func (eb *EmbeddedBroker) Publish(topic string, qos byte, payload []byte) error {
if !eb.started {
return fmt.Errorf("broker not started")
}
if eb.server == nil {
return fmt.Errorf("server not initialized")
}
// Use inline client to publish
return eb.server.Publish(topic, payload, false, qos)
}
// Subscribe subscribes to a topic
func (eb *EmbeddedBroker) Subscribe(topicFilter string, qos byte, callback MessageCallback) error {
if !eb.started {
return fmt.Errorf("broker not started")
}
// Store callback
eb.subMu.Lock()
eb.subscriptions[topicFilter] = callback
eb.subMu.Unlock()
// Create inline subscription handler
// Note: Mochi MQTT internal subscriptions are more complex
// For now, we'll use a publishing hook to intercept messages
// This is a simplified implementation
logger.Info("[MQTTSpec] Subscribed to topic filter: %s", topicFilter)
return nil
}
// Unsubscribe unsubscribes from a topic
func (eb *EmbeddedBroker) Unsubscribe(topicFilter string) error {
eb.subMu.Lock()
defer eb.subMu.Unlock()
delete(eb.subscriptions, topicFilter)
logger.Info("[MQTTSpec] Unsubscribed from topic filter: %s", topicFilter)
return nil
}
// IsConnected returns whether the broker is running
func (eb *EmbeddedBroker) IsConnected() bool {
eb.mu.RLock()
defer eb.mu.RUnlock()
return eb.started
}
// GetClientManager returns the client manager
func (eb *EmbeddedBroker) GetClientManager() *ClientManager {
return eb.clientManager
}
// ExternalBrokerClient wraps Paho MQTT client
type ExternalBrokerClient struct {
config ExternalBrokerConfig
client pahomqtt.Client
clientManager *ClientManager
handler *Handler
subscriptions map[string]MessageCallback
subMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
mu sync.RWMutex
connected bool
}
// NewExternalBrokerClient creates a new external broker client
func NewExternalBrokerClient(config ExternalBrokerConfig, clientManager *ClientManager) *ExternalBrokerClient {
return &ExternalBrokerClient{
config: config,
clientManager: clientManager,
subscriptions: make(map[string]MessageCallback),
}
}
// SetHandler sets the handler reference
func (ebc *ExternalBrokerClient) SetHandler(handler *Handler) {
ebc.mu.Lock()
defer ebc.mu.Unlock()
ebc.handler = handler
}
// Start connects to the external MQTT broker
func (ebc *ExternalBrokerClient) Start(ctx context.Context) error {
ebc.mu.Lock()
defer ebc.mu.Unlock()
if ebc.connected {
return fmt.Errorf("already connected")
}
ebc.ctx, ebc.cancel = context.WithCancel(ctx)
// Create Paho client options
opts := pahomqtt.NewClientOptions()
opts.AddBroker(ebc.config.BrokerURL)
opts.SetClientID(ebc.config.ClientID)
opts.SetUsername(ebc.config.Username)
opts.SetPassword(ebc.config.Password)
opts.SetCleanSession(ebc.config.CleanSession)
opts.SetKeepAlive(ebc.config.KeepAlive)
opts.SetAutoReconnect(true)
opts.SetMaxReconnectInterval(ebc.config.ReconnectDelay)
// Set connection lost handler
opts.SetConnectionLostHandler(func(client pahomqtt.Client, err error) {
logger.Error("[MQTTSpec] External broker connection lost: %v", err)
ebc.mu.Lock()
ebc.connected = false
ebc.mu.Unlock()
})
// Set on-connect handler
opts.SetOnConnectHandler(func(client pahomqtt.Client) {
logger.Info("[MQTTSpec] Connected to external broker")
ebc.mu.Lock()
ebc.connected = true
ebc.mu.Unlock()
// Resubscribe to topics
ebc.resubscribeAll()
})
// Create and connect client
ebc.client = pahomqtt.NewClient(opts)
token := ebc.client.Connect()
if !token.WaitTimeout(ebc.config.ConnectTimeout) {
return fmt.Errorf("connection timeout")
}
if err := token.Error(); err != nil {
return fmt.Errorf("failed to connect to external broker: %w", err)
}
ebc.connected = true
logger.Info("[MQTTSpec] Connected to external MQTT broker: %s", ebc.config.BrokerURL)
return nil
}
// Stop disconnects from the external broker
func (ebc *ExternalBrokerClient) Stop(ctx context.Context) error {
ebc.mu.Lock()
defer ebc.mu.Unlock()
if !ebc.connected {
return nil
}
if ebc.cancel != nil {
ebc.cancel()
}
if ebc.client != nil && ebc.client.IsConnected() {
ebc.client.Disconnect(uint(ebc.config.ConnectTimeout.Milliseconds()))
}
ebc.connected = false
logger.Info("[MQTTSpec] Disconnected from external broker")
return nil
}
// Publish publishes a message to a topic
func (ebc *ExternalBrokerClient) Publish(topic string, qos byte, payload []byte) error {
if !ebc.connected {
return fmt.Errorf("not connected to broker")
}
token := ebc.client.Publish(topic, qos, false, payload)
token.Wait()
return token.Error()
}
// Subscribe subscribes to a topic
func (ebc *ExternalBrokerClient) Subscribe(topicFilter string, qos byte, callback MessageCallback) error {
if !ebc.connected {
return fmt.Errorf("not connected to broker")
}
// Store callback
ebc.subMu.Lock()
ebc.subscriptions[topicFilter] = callback
ebc.subMu.Unlock()
// Subscribe via Paho client
token := ebc.client.Subscribe(topicFilter, qos, func(client pahomqtt.Client, msg pahomqtt.Message) {
callback(msg.Topic(), msg.Payload())
})
token.Wait()
if err := token.Error(); err != nil {
return fmt.Errorf("failed to subscribe to %s: %w", topicFilter, err)
}
logger.Info("[MQTTSpec] Subscribed to topic filter: %s", topicFilter)
return nil
}
// Unsubscribe unsubscribes from a topic
func (ebc *ExternalBrokerClient) Unsubscribe(topicFilter string) error {
ebc.subMu.Lock()
defer ebc.subMu.Unlock()
if ebc.client != nil && ebc.connected {
token := ebc.client.Unsubscribe(topicFilter)
token.Wait()
if err := token.Error(); err != nil {
logger.Error("[MQTTSpec] Failed to unsubscribe from %s: %v", topicFilter, err)
}
}
delete(ebc.subscriptions, topicFilter)
logger.Info("[MQTTSpec] Unsubscribed from topic filter: %s", topicFilter)
return nil
}
// IsConnected returns connection status
func (ebc *ExternalBrokerClient) IsConnected() bool {
ebc.mu.RLock()
defer ebc.mu.RUnlock()
return ebc.connected
}
// GetClientManager returns the client manager
func (ebc *ExternalBrokerClient) GetClientManager() *ClientManager {
return ebc.clientManager
}
// resubscribeAll resubscribes to all topics after reconnection
func (ebc *ExternalBrokerClient) resubscribeAll() {
ebc.subMu.RLock()
defer ebc.subMu.RUnlock()
for topicFilter, callback := range ebc.subscriptions {
logger.Info("[MQTTSpec] Resubscribing to topic: %s", topicFilter)
token := ebc.client.Subscribe(topicFilter, 1, func(client pahomqtt.Client, msg pahomqtt.Message) {
callback(msg.Topic(), msg.Payload())
})
if token.Wait() && token.Error() != nil {
logger.Error("[MQTTSpec] Failed to resubscribe to %s: %v", topicFilter, token.Error())
}
}
}

409
pkg/mqttspec/broker_test.go Normal file
View File

@@ -0,0 +1,409 @@
package mqttspec
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewEmbeddedBroker(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 1883,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
assert.NotNil(t, broker)
assert.Equal(t, config, broker.config)
assert.Equal(t, cm, broker.clientManager)
assert.NotNil(t, broker.subscriptions)
assert.False(t, broker.started)
}
func TestEmbeddedBroker_StartStop(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11883, // Use non-standard port for testing
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
ctx := context.Background()
// Start broker
err := broker.Start(ctx)
require.NoError(t, err)
// Verify started
assert.True(t, broker.IsConnected())
// Stop broker
err = broker.Stop(ctx)
require.NoError(t, err)
// Verify stopped
assert.False(t, broker.IsConnected())
}
func TestEmbeddedBroker_StartTwice(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11884,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
ctx := context.Background()
// Start broker
err := broker.Start(ctx)
require.NoError(t, err)
defer broker.Stop(ctx)
// Try to start again - should fail
err = broker.Start(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already started")
}
func TestEmbeddedBroker_StopWithoutStart(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11885,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
ctx := context.Background()
// Stop without starting - should not error
err := broker.Stop(ctx)
assert.NoError(t, err)
}
func TestEmbeddedBroker_PublishWithoutStart(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11886,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
// Try to publish without starting - should fail
err := broker.Publish("test/topic", 1, []byte("test"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "broker not started")
}
func TestEmbeddedBroker_SubscribeWithoutStart(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11887,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
// Try to subscribe without starting - should fail
err := broker.Subscribe("test/topic", 1, func(topic string, payload []byte) {})
assert.Error(t, err)
assert.Contains(t, err.Error(), "broker not started")
}
func TestEmbeddedBroker_PublishSubscribe(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11888,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
ctx := context.Background()
// Start broker
err := broker.Start(ctx)
require.NoError(t, err)
defer broker.Stop(ctx)
// Subscribe to topic
callback := func(topic string, payload []byte) {
// Callback for subscription - actual message delivery would require
// integration with Mochi MQTT's hook system
}
err = broker.Subscribe("test/topic", 1, callback)
require.NoError(t, err)
// Note: Embedded broker's Subscribe is simplified and doesn't fully integrate
// with Mochi MQTT's internal pub/sub. This test verifies the subscription
// is registered but actual message delivery would require more complex
// integration with Mochi MQTT's hook system.
// Verify subscription was registered
broker.subMu.RLock()
_, exists := broker.subscriptions["test/topic"]
broker.subMu.RUnlock()
assert.True(t, exists)
}
func TestEmbeddedBroker_Unsubscribe(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11889,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
ctx := context.Background()
// Start broker
err := broker.Start(ctx)
require.NoError(t, err)
defer broker.Stop(ctx)
// Subscribe
callback := func(topic string, payload []byte) {}
err = broker.Subscribe("test/topic", 1, callback)
require.NoError(t, err)
// Verify subscription exists
broker.subMu.RLock()
_, exists := broker.subscriptions["test/topic"]
broker.subMu.RUnlock()
assert.True(t, exists)
// Unsubscribe
err = broker.Unsubscribe("test/topic")
require.NoError(t, err)
// Verify subscription removed
broker.subMu.RLock()
_, exists = broker.subscriptions["test/topic"]
broker.subMu.RUnlock()
assert.False(t, exists)
}
func TestEmbeddedBroker_SetHandler(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11890,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
// Create a mock handler (nil is fine for this test)
var handler *Handler = nil
// Set handler
broker.SetHandler(handler)
// Verify handler was set
broker.mu.RLock()
assert.Equal(t, handler, broker.handler)
broker.mu.RUnlock()
}
func TestEmbeddedBroker_GetClientManager(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11891,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
// Get client manager
retrievedCM := broker.GetClientManager()
// Verify it's the same instance
assert.Equal(t, cm, retrievedCM)
}
func TestEmbeddedBroker_ConcurrentPublish(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := BrokerConfig{
Host: "localhost",
Port: 11892,
MaxConnections: 100,
KeepAlive: 60 * time.Second,
}
broker := NewEmbeddedBroker(config, cm)
ctx := context.Background()
// Start broker
err := broker.Start(ctx)
require.NoError(t, err)
defer broker.Stop(ctx)
// Test concurrent publishing
var wg sync.WaitGroup
numPublishers := 10
for i := 0; i < numPublishers; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 10; j++ {
err := broker.Publish("test/topic", 1, []byte("test"))
// Errors are acceptable in concurrent scenario
_ = err
}
}(i)
}
wg.Wait()
}
func TestNewExternalBrokerClient(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := ExternalBrokerConfig{
BrokerURL: "tcp://localhost:1883",
ClientID: "test-client",
Username: "user",
Password: "pass",
CleanSession: true,
KeepAlive: 60 * time.Second,
ConnectTimeout: 5 * time.Second,
ReconnectDelay: 1 * time.Second,
}
broker := NewExternalBrokerClient(config, cm)
assert.NotNil(t, broker)
assert.Equal(t, config, broker.config)
assert.Equal(t, cm, broker.clientManager)
assert.NotNil(t, broker.subscriptions)
assert.False(t, broker.connected)
}
func TestExternalBrokerClient_SetHandler(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := ExternalBrokerConfig{
BrokerURL: "tcp://localhost:1883",
ClientID: "test-client",
Username: "user",
Password: "pass",
CleanSession: true,
KeepAlive: 60 * time.Second,
ConnectTimeout: 5 * time.Second,
ReconnectDelay: 1 * time.Second,
}
broker := NewExternalBrokerClient(config, cm)
// Create a mock handler (nil is fine for this test)
var handler *Handler = nil
// Set handler
broker.SetHandler(handler)
// Verify handler was set
broker.mu.RLock()
assert.Equal(t, handler, broker.handler)
broker.mu.RUnlock()
}
func TestExternalBrokerClient_GetClientManager(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := ExternalBrokerConfig{
BrokerURL: "tcp://localhost:1883",
ClientID: "test-client",
Username: "user",
Password: "pass",
CleanSession: true,
KeepAlive: 60 * time.Second,
ConnectTimeout: 5 * time.Second,
ReconnectDelay: 1 * time.Second,
}
broker := NewExternalBrokerClient(config, cm)
// Get client manager
retrievedCM := broker.GetClientManager()
// Verify it's the same instance
assert.Equal(t, cm, retrievedCM)
}
func TestExternalBrokerClient_IsConnected(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
config := ExternalBrokerConfig{
BrokerURL: "tcp://localhost:1883",
ClientID: "test-client",
Username: "user",
Password: "pass",
CleanSession: true,
KeepAlive: 60 * time.Second,
ConnectTimeout: 5 * time.Second,
ReconnectDelay: 1 * time.Second,
}
broker := NewExternalBrokerClient(config, cm)
// Should not be connected initially
assert.False(t, broker.IsConnected())
}
// Note: Tests for ExternalBrokerClient Start/Stop/Publish/Subscribe require
// a running MQTT broker and are better suited for integration tests.
// These tests would be included in integration_test.go with proper test
// broker setup (e.g., using Docker Compose).

184
pkg/mqttspec/client.go Normal file
View File

@@ -0,0 +1,184 @@
package mqttspec
import (
"context"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Client represents an MQTT client connection
type Client struct {
// ID is the MQTT client ID (unique per connection)
ID string
// Username from MQTT CONNECT packet
Username string
// ConnectedAt is when the client connected
ConnectedAt time.Time
// subscriptions holds active subscriptions for this client
subscriptions map[string]*Subscription
subMu sync.RWMutex
// metadata stores client-specific data (user_id, roles, tenant_id, etc.)
// Set by BeforeConnect hook for authentication/authorization
metadata map[string]interface{}
metaMu sync.RWMutex
// ctx is the client context
ctx context.Context
cancel context.CancelFunc
// handler reference for callback access
handler *Handler
}
// ClientManager manages all MQTT client connections
type ClientManager struct {
// clients maps client_id to Client
clients map[string]*Client
mu sync.RWMutex
// ctx for lifecycle management
ctx context.Context
cancel context.CancelFunc
}
// NewClient creates a new MQTT client
func NewClient(id, username string, handler *Handler) *Client {
ctx, cancel := context.WithCancel(context.Background())
return &Client{
ID: id,
Username: username,
ConnectedAt: time.Now(),
subscriptions: make(map[string]*Subscription),
metadata: make(map[string]interface{}),
ctx: ctx,
cancel: cancel,
handler: handler,
}
}
// SetMetadata sets metadata for this client
func (c *Client) SetMetadata(key string, value interface{}) {
c.metaMu.Lock()
defer c.metaMu.Unlock()
c.metadata[key] = value
}
// GetMetadata retrieves metadata for this client
func (c *Client) GetMetadata(key string) (interface{}, bool) {
c.metaMu.RLock()
defer c.metaMu.RUnlock()
val, ok := c.metadata[key]
return val, ok
}
// AddSubscription adds a subscription to this client
func (c *Client) AddSubscription(sub *Subscription) {
c.subMu.Lock()
defer c.subMu.Unlock()
c.subscriptions[sub.ID] = sub
}
// RemoveSubscription removes a subscription from this client
func (c *Client) RemoveSubscription(subID string) {
c.subMu.Lock()
defer c.subMu.Unlock()
delete(c.subscriptions, subID)
}
// GetSubscription retrieves a subscription by ID
func (c *Client) GetSubscription(subID string) (*Subscription, bool) {
c.subMu.RLock()
defer c.subMu.RUnlock()
sub, ok := c.subscriptions[subID]
return sub, ok
}
// Close cleans up the client
func (c *Client) Close() {
if c.cancel != nil {
c.cancel()
}
// Clean up subscriptions
c.subMu.Lock()
for subID := range c.subscriptions {
if c.handler != nil && c.handler.subscriptionManager != nil {
c.handler.subscriptionManager.Unsubscribe(subID)
}
}
c.subscriptions = make(map[string]*Subscription)
c.subMu.Unlock()
}
// NewClientManager creates a new client manager
func NewClientManager(ctx context.Context) *ClientManager {
ctx, cancel := context.WithCancel(ctx)
return &ClientManager{
clients: make(map[string]*Client),
ctx: ctx,
cancel: cancel,
}
}
// Register registers a new MQTT client
func (cm *ClientManager) Register(clientID, username string, handler *Handler) *Client {
cm.mu.Lock()
defer cm.mu.Unlock()
client := NewClient(clientID, username, handler)
cm.clients[clientID] = client
count := len(cm.clients)
logger.Info("[MQTTSpec] Client registered: %s (username: %s, total: %d)", clientID, username, count)
return client
}
// Unregister removes a client
func (cm *ClientManager) Unregister(clientID string) {
cm.mu.Lock()
defer cm.mu.Unlock()
if client, ok := cm.clients[clientID]; ok {
client.Close()
delete(cm.clients, clientID)
count := len(cm.clients)
logger.Info("[MQTTSpec] Client unregistered: %s (total: %d)", clientID, count)
}
}
// GetClient retrieves a client by ID
func (cm *ClientManager) GetClient(clientID string) (*Client, bool) {
cm.mu.RLock()
defer cm.mu.RUnlock()
client, ok := cm.clients[clientID]
return client, ok
}
// Count returns the number of active clients
func (cm *ClientManager) Count() int {
cm.mu.RLock()
defer cm.mu.RUnlock()
return len(cm.clients)
}
// Shutdown gracefully shuts down the client manager
func (cm *ClientManager) Shutdown() {
cm.cancel()
// Close all clients
cm.mu.Lock()
for _, client := range cm.clients {
client.Close()
}
cm.clients = make(map[string]*Client)
cm.mu.Unlock()
logger.Info("[MQTTSpec] Client manager shut down")
}

256
pkg/mqttspec/client_test.go Normal file
View File

@@ -0,0 +1,256 @@
package mqttspec
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewClient(t *testing.T) {
client := NewClient("client-123", "user@example.com", nil)
assert.Equal(t, "client-123", client.ID)
assert.Equal(t, "user@example.com", client.Username)
assert.NotNil(t, client.subscriptions)
assert.NotNil(t, client.metadata)
assert.NotNil(t, client.ctx)
assert.NotNil(t, client.cancel)
}
func TestClient_Metadata(t *testing.T) {
client := NewClient("client-123", "user", nil)
// Set metadata
client.SetMetadata("user_id", 456)
client.SetMetadata("tenant_id", "tenant-abc")
client.SetMetadata("roles", []string{"admin", "user"})
// Get metadata
userID, exists := client.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 456, userID)
tenantID, exists := client.GetMetadata("tenant_id")
assert.True(t, exists)
assert.Equal(t, "tenant-abc", tenantID)
roles, exists := client.GetMetadata("roles")
assert.True(t, exists)
assert.Equal(t, []string{"admin", "user"}, roles)
// Non-existent key
_, exists = client.GetMetadata("nonexistent")
assert.False(t, exists)
}
func TestClient_Subscriptions(t *testing.T) {
client := NewClient("client-123", "user", nil)
// Create mock subscription
sub := &Subscription{
ID: "sub-1",
ConnectionID: "client-123",
Schema: "public",
Entity: "users",
Active: true,
}
// Add subscription
client.AddSubscription(sub)
// Get subscription
retrieved, exists := client.GetSubscription("sub-1")
assert.True(t, exists)
assert.Equal(t, "sub-1", retrieved.ID)
// Remove subscription
client.RemoveSubscription("sub-1")
// Verify removed
_, exists = client.GetSubscription("sub-1")
assert.False(t, exists)
}
func TestClient_Close(t *testing.T) {
client := NewClient("client-123", "user", nil)
// Add some subscriptions
client.AddSubscription(&Subscription{ID: "sub-1"})
client.AddSubscription(&Subscription{ID: "sub-2"})
// Close client
client.Close()
// Verify subscriptions cleared
client.subMu.RLock()
assert.Empty(t, client.subscriptions)
client.subMu.RUnlock()
// Verify context cancelled
select {
case <-client.ctx.Done():
// Context was cancelled
default:
t.Fatal("Context should be cancelled after Close()")
}
}
func TestNewClientManager(t *testing.T) {
cm := NewClientManager(context.Background())
assert.NotNil(t, cm)
assert.NotNil(t, cm.clients)
assert.Equal(t, 0, cm.Count())
}
func TestClientManager_Register(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
client := cm.Register("client-1", "user@example.com", nil)
assert.NotNil(t, client)
assert.Equal(t, "client-1", client.ID)
assert.Equal(t, "user@example.com", client.Username)
assert.Equal(t, 1, cm.Count())
}
func TestClientManager_Unregister(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
cm.Register("client-1", "user1", nil)
assert.Equal(t, 1, cm.Count())
cm.Unregister("client-1")
assert.Equal(t, 0, cm.Count())
}
func TestClientManager_GetClient(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
cm.Register("client-1", "user1", nil)
// Get existing client
client, exists := cm.GetClient("client-1")
assert.True(t, exists)
assert.NotNil(t, client)
assert.Equal(t, "client-1", client.ID)
// Get non-existent client
_, exists = cm.GetClient("nonexistent")
assert.False(t, exists)
}
func TestClientManager_MultipleClients(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
cm.Register("client-1", "user1", nil)
cm.Register("client-2", "user2", nil)
cm.Register("client-3", "user3", nil)
assert.Equal(t, 3, cm.Count())
cm.Unregister("client-2")
assert.Equal(t, 2, cm.Count())
// Verify correct client was removed
_, exists := cm.GetClient("client-2")
assert.False(t, exists)
_, exists = cm.GetClient("client-1")
assert.True(t, exists)
_, exists = cm.GetClient("client-3")
assert.True(t, exists)
}
func TestClientManager_Shutdown(t *testing.T) {
cm := NewClientManager(context.Background())
cm.Register("client-1", "user1", nil)
cm.Register("client-2", "user2", nil)
assert.Equal(t, 2, cm.Count())
cm.Shutdown()
// All clients should be removed
assert.Equal(t, 0, cm.Count())
// Context should be cancelled
select {
case <-cm.ctx.Done():
// Context was cancelled
default:
t.Fatal("Context should be cancelled after Shutdown()")
}
}
func TestClientManager_ConcurrentOperations(t *testing.T) {
cm := NewClientManager(context.Background())
defer cm.Shutdown()
// This test verifies that concurrent operations don't cause race conditions
// Run with: go test -race
var wg sync.WaitGroup
// Goroutine 1: Register clients
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
cm.Register("client-"+string(rune(i)), "user", nil)
}
}()
// Goroutine 2: Get clients
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
cm.GetClient("client-" + string(rune(i)))
}
}()
// Goroutine 3: Count
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
cm.Count()
}
}()
wg.Wait()
}
func TestClient_ConcurrentMetadata(t *testing.T) {
client := NewClient("client-123", "user", nil)
var wg sync.WaitGroup
// Concurrent writes
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
client.SetMetadata("key1", i)
}
}()
// Concurrent reads
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
client.GetMetadata("key1")
}
}()
wg.Wait()
}

178
pkg/mqttspec/config.go Normal file
View File

@@ -0,0 +1,178 @@
package mqttspec
import (
"crypto/tls"
"time"
)
// BrokerMode specifies how to connect to MQTT
type BrokerMode string
const (
// BrokerModeEmbedded runs Mochi MQTT broker in-process
BrokerModeEmbedded BrokerMode = "embedded"
// BrokerModeExternal connects to external MQTT broker as client
BrokerModeExternal BrokerMode = "external"
)
// Config holds all mqttspec configuration
type Config struct {
// BrokerMode determines whether to use embedded or external broker
BrokerMode BrokerMode
// Broker configuration for embedded mode
Broker BrokerConfig
// ExternalBroker configuration for external client mode
ExternalBroker ExternalBrokerConfig
// Topics configuration
Topics TopicConfig
// QoS configuration for different message types
QoS QoSConfig
// Auth configuration
Auth AuthConfig
// Timeouts for various operations
Timeouts TimeoutConfig
}
// BrokerConfig configures the embedded Mochi MQTT broker
type BrokerConfig struct {
// Host to bind to (default: "localhost")
Host string
// Port to listen on (default: 1883)
Port int
// EnableWebSocket enables WebSocket support
EnableWebSocket bool
// WSPort is the WebSocket port (default: 8883)
WSPort int
// MaxConnections limits concurrent client connections
MaxConnections int
// KeepAlive is the client keepalive interval
KeepAlive time.Duration
// EnableAuth enables username/password authentication
EnableAuth bool
}
// ExternalBrokerConfig for connecting as a client to external broker
type ExternalBrokerConfig struct {
// BrokerURL is the broker address (e.g., tcp://host:port or ssl://host:port)
BrokerURL string
// ClientID is a unique identifier for this handler instance
ClientID string
// Username for MQTT authentication
Username string
// Password for MQTT authentication
Password string
// CleanSession flag (default: true)
CleanSession bool
// KeepAlive interval (default: 60s)
KeepAlive time.Duration
// ConnectTimeout for initial connection (default: 30s)
ConnectTimeout time.Duration
// ReconnectDelay between reconnection attempts (default: 5s)
ReconnectDelay time.Duration
// MaxReconnect attempts (0 = unlimited, default: 0)
MaxReconnect int
// TLSConfig for SSL/TLS connections
TLSConfig *tls.Config
}
// TopicConfig defines the MQTT topic structure
type TopicConfig struct {
// Prefix for all topics (default: "spec")
// Topics will be: {Prefix}/{client_id}/request|response|notify/{sub_id}
Prefix string
}
// QoSConfig defines quality of service levels for different message types
type QoSConfig struct {
// Request messages QoS (default: 1 - at-least-once)
Request byte
// Response messages QoS (default: 1 - at-least-once)
Response byte
// Notification messages QoS (default: 1 - at-least-once)
Notification byte
}
// AuthConfig for MQTT-level authentication
type AuthConfig struct {
// ValidateCredentials is called to validate username/password for embedded broker
// Return true if credentials are valid, false otherwise
ValidateCredentials func(username, password string) bool
}
// TimeoutConfig defines timeouts for various operations
type TimeoutConfig struct {
// Connect timeout for MQTT connection (default: 30s)
Connect time.Duration
// Publish timeout for publishing messages (default: 5s)
Publish time.Duration
// Disconnect timeout for graceful shutdown (default: 10s)
Disconnect time.Duration
}
// DefaultConfig returns a configuration with sensible defaults
func DefaultConfig() *Config {
return &Config{
BrokerMode: BrokerModeEmbedded,
Broker: BrokerConfig{
Host: "localhost",
Port: 1883,
EnableWebSocket: false,
WSPort: 8883,
MaxConnections: 1000,
KeepAlive: 60 * time.Second,
EnableAuth: false,
},
ExternalBroker: ExternalBrokerConfig{
BrokerURL: "",
ClientID: "",
Username: "",
Password: "",
CleanSession: true,
KeepAlive: 60 * time.Second,
ConnectTimeout: 30 * time.Second,
ReconnectDelay: 5 * time.Second,
MaxReconnect: 0, // Unlimited
},
Topics: TopicConfig{
Prefix: "spec",
},
QoS: QoSConfig{
Request: 1, // At-least-once
Response: 1, // At-least-once
Notification: 1, // At-least-once
},
Auth: AuthConfig{
ValidateCredentials: nil,
},
Timeouts: TimeoutConfig{
Connect: 30 * time.Second,
Publish: 5 * time.Second,
Disconnect: 10 * time.Second,
},
}
}

846
pkg/mqttspec/handler.go Normal file
View File

@@ -0,0 +1,846 @@
package mqttspec
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"github.com/google/uuid"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Handler handles MQTT messages and operations
type Handler struct {
// Database adapter (GORM/Bun)
db common.Database
// Model registry
registry common.ModelRegistry
// Hook registry
hooks *HookRegistry
// Client manager
clientManager *ClientManager
// Subscription manager
subscriptionManager *SubscriptionManager
// Broker interface (embedded or external)
broker BrokerInterface
// Configuration
config *Config
// Context for lifecycle management
ctx context.Context
cancel context.CancelFunc
// Started flag
started bool
mu sync.RWMutex
}
// NewHandler creates a new MQTT handler
func NewHandler(db common.Database, registry common.ModelRegistry, config *Config) (*Handler, error) {
ctx, cancel := context.WithCancel(context.Background())
h := &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
clientManager: NewClientManager(ctx),
subscriptionManager: NewSubscriptionManager(),
config: config,
ctx: ctx,
cancel: cancel,
started: false,
}
// Initialize broker based on mode
if config.BrokerMode == BrokerModeEmbedded {
h.broker = NewEmbeddedBroker(config.Broker, h.clientManager)
} else {
h.broker = NewExternalBrokerClient(config.ExternalBroker, h.clientManager)
}
// Set handler reference in broker
h.broker.SetHandler(h)
return h, nil
}
// Start initializes and starts the handler
func (h *Handler) Start() error {
h.mu.Lock()
defer h.mu.Unlock()
if h.started {
return fmt.Errorf("handler already started")
}
// Start broker
if err := h.broker.Start(h.ctx); err != nil {
return fmt.Errorf("failed to start broker: %w", err)
}
// Subscribe to all request topics: spec/+/request
requestTopic := fmt.Sprintf("%s/+/request", h.config.Topics.Prefix)
if err := h.broker.Subscribe(requestTopic, h.config.QoS.Request, h.handleIncomingMessage); err != nil {
_ = h.broker.Stop(h.ctx)
return fmt.Errorf("failed to subscribe to request topic: %w", err)
}
h.started = true
logger.Info("[MQTTSpec] Handler started, listening on topic: %s", requestTopic)
return nil
}
// Shutdown gracefully shuts down the handler
func (h *Handler) Shutdown() error {
h.mu.Lock()
defer h.mu.Unlock()
if !h.started {
return nil
}
logger.Info("[MQTTSpec] Shutting down handler...")
// Execute disconnect hooks for all clients
h.clientManager.mu.RLock()
clients := make([]*Client, 0, len(h.clientManager.clients))
for _, client := range h.clientManager.clients {
clients = append(clients, client)
}
h.clientManager.mu.RUnlock()
for _, client := range clients {
hookCtx := &HookContext{
Context: h.ctx,
Handler: nil, // Not used for MQTT
Metadata: map[string]interface{}{
"mqtt_client": client,
},
}
_ = h.hooks.Execute(BeforeDisconnect, hookCtx)
h.clientManager.Unregister(client.ID)
_ = h.hooks.Execute(AfterDisconnect, hookCtx)
}
// Unsubscribe from request topic
requestTopic := fmt.Sprintf("%s/+/request", h.config.Topics.Prefix)
_ = h.broker.Unsubscribe(requestTopic)
// Stop broker
if err := h.broker.Stop(h.ctx); err != nil {
logger.Error("[MQTTSpec] Error stopping broker: %v", err)
}
// Cancel context
if h.cancel != nil {
h.cancel()
}
h.started = false
logger.Info("[MQTTSpec] Handler stopped")
return nil
}
// Hooks returns the hook registry
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// Registry returns the model registry
func (h *Handler) Registry() common.ModelRegistry {
return h.registry
}
// GetDatabase returns the database adapter
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// GetRelationshipInfo is a placeholder for relationship detection
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
// TODO: Implement full relationship detection if needed
return nil
}
// handleIncomingMessage is called when a message arrives on spec/+/request
func (h *Handler) handleIncomingMessage(topic string, payload []byte) {
// Extract client_id from topic: spec/{client_id}/request
parts := strings.Split(topic, "/")
if len(parts) < 3 {
logger.Error("[MQTTSpec] Invalid topic format: %s", topic)
return
}
clientID := parts[len(parts)-2] // Second to last part is client_id
// Parse message
msg, err := ParseMessage(payload)
if err != nil {
logger.Error("[MQTTSpec] Failed to parse message from %s: %v", clientID, err)
h.sendError(clientID, "", "invalid_message", "Failed to parse message")
return
}
// Validate message
if !msg.IsValid() {
logger.Error("[MQTTSpec] Invalid message from %s", clientID)
h.sendError(clientID, msg.ID, "invalid_message", "Message validation failed")
return
}
// Get or register client
client, exists := h.clientManager.GetClient(clientID)
if !exists {
// First request from this client - register it
client = h.clientManager.Register(clientID, "", h)
// Execute connect hooks
hookCtx := &HookContext{
Context: h.ctx,
Handler: nil, // Not used for MQTT, handler ref stored in metadata if needed
Metadata: map[string]interface{}{
"mqtt_client": client,
},
}
if err := h.hooks.Execute(BeforeConnect, hookCtx); err != nil {
logger.Error("[MQTTSpec] BeforeConnect hook failed for %s: %v", clientID, err)
h.sendError(clientID, msg.ID, "auth_error", err.Error())
h.clientManager.Unregister(clientID)
return
}
_ = h.hooks.Execute(AfterConnect, hookCtx)
}
// Route message by type
switch msg.Type {
case MessageTypeRequest:
h.handleRequest(client, msg)
case MessageTypeSubscription:
h.handleSubscription(client, msg)
case MessageTypePing:
h.handlePing(client, msg)
default:
h.sendError(clientID, msg.ID, "invalid_message_type", fmt.Sprintf("Unknown message type: %s", msg.Type))
}
}
// handleRequest processes CRUD requests
func (h *Handler) handleRequest(client *Client, msg *Message) {
ctx := client.ctx
schema := msg.Schema
entity := msg.Entity
recordID := msg.RecordID
// Get model from registry
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("[MQTTSpec] Model not found for %s.%s: %v", schema, entity, err)
h.sendError(client.ID, msg.ID, "model_not_found", fmt.Sprintf("Model not found: %s.%s", schema, entity))
return
}
// Validate and unwrap model
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
logger.Error("[MQTTSpec] Model validation failed for %s.%s: %v", schema, entity, err)
h.sendError(client.ID, msg.ID, "invalid_model", err.Error())
return
}
model = result.Model
modelPtr := result.ModelPtr
tableName := h.getTableName(schema, entity, model)
// Create hook context
hookCtx := &HookContext{
Context: ctx,
Handler: nil, // Not used for MQTT
Message: msg,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
ModelPtr: modelPtr,
Options: msg.Options,
ID: recordID,
Data: msg.Data,
Metadata: map[string]interface{}{
"mqtt_client": client,
},
}
// Route to operation handler
switch msg.Operation {
case OperationRead:
h.handleRead(client, msg, hookCtx)
case OperationCreate:
h.handleCreate(client, msg, hookCtx)
case OperationUpdate:
h.handleUpdate(client, msg, hookCtx)
case OperationDelete:
h.handleDelete(client, msg, hookCtx)
case OperationMeta:
h.handleMeta(client, msg, hookCtx)
default:
h.sendError(client.ID, msg.ID, "invalid_operation", fmt.Sprintf("Unknown operation: %s", msg.Operation))
}
}
// handleRead processes a read operation
func (h *Handler) handleRead(client *Client, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
logger.Error("[MQTTSpec] BeforeRead hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Perform read operation
var data interface{}
var metadata map[string]interface{}
var err error
if hookCtx.ID != "" {
// Read single record by ID
data, err = h.readByID(hookCtx)
metadata = map[string]interface{}{"total": 1}
} else {
// Read multiple records
data, metadata, err = h.readMultiple(hookCtx)
}
if err != nil {
logger.Error("[MQTTSpec] Read operation failed: %v", err)
h.sendError(client.ID, msg.ID, "read_error", err.Error())
return
}
// Update hook context
hookCtx.Result = data
// Execute after hook
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
logger.Error("[MQTTSpec] AfterRead hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Send response
h.sendResponse(client.ID, msg.ID, hookCtx.Result, metadata)
}
// handleCreate processes a create operation
func (h *Handler) handleCreate(client *Client, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
logger.Error("[MQTTSpec] BeforeCreate hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Perform create operation
data, err := h.create(hookCtx)
if err != nil {
logger.Error("[MQTTSpec] Create operation failed: %v", err)
h.sendError(client.ID, msg.ID, "create_error", err.Error())
return
}
// Update hook context
hookCtx.Result = data
// Execute after hook
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("[MQTTSpec] AfterCreate hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Send response
h.sendResponse(client.ID, msg.ID, hookCtx.Result, nil)
// Notify subscribers
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationCreate, data)
}
// handleUpdate processes an update operation
func (h *Handler) handleUpdate(client *Client, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
logger.Error("[MQTTSpec] BeforeUpdate hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Perform update operation
data, err := h.update(hookCtx)
if err != nil {
logger.Error("[MQTTSpec] Update operation failed: %v", err)
h.sendError(client.ID, msg.ID, "update_error", err.Error())
return
}
// Update hook context
hookCtx.Result = data
// Execute after hook
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
logger.Error("[MQTTSpec] AfterUpdate hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Send response
h.sendResponse(client.ID, msg.ID, hookCtx.Result, nil)
// Notify subscribers
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationUpdate, data)
}
// handleDelete processes a delete operation
func (h *Handler) handleDelete(client *Client, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Error("[MQTTSpec] BeforeDelete hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Perform delete operation
if err := h.delete(hookCtx); err != nil {
logger.Error("[MQTTSpec] Delete operation failed: %v", err)
h.sendError(client.ID, msg.ID, "delete_error", err.Error())
return
}
// Execute after hook
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
logger.Error("[MQTTSpec] AfterDelete hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Send response
h.sendResponse(client.ID, msg.ID, map[string]interface{}{"deleted": true}, nil)
// Notify subscribers
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationDelete, map[string]interface{}{
"id": hookCtx.ID,
})
}
// handleMeta processes a metadata request
func (h *Handler) handleMeta(client *Client, msg *Message, hookCtx *HookContext) {
metadata, err := h.getMetadata(hookCtx)
if err != nil {
logger.Error("[MQTTSpec] Meta operation failed: %v", err)
h.sendError(client.ID, msg.ID, "meta_error", err.Error())
return
}
h.sendResponse(client.ID, msg.ID, metadata, nil)
}
// handleSubscription manages subscriptions
func (h *Handler) handleSubscription(client *Client, msg *Message) {
switch msg.Operation {
case OperationSubscribe:
h.handleSubscribe(client, msg)
case OperationUnsubscribe:
h.handleUnsubscribe(client, msg)
default:
h.sendError(client.ID, msg.ID, "invalid_subscription_operation", fmt.Sprintf("Unknown subscription operation: %s", msg.Operation))
}
}
// handleSubscribe creates a subscription
func (h *Handler) handleSubscribe(client *Client, msg *Message) {
// Generate subscription ID
subID := uuid.New().String()
// Create hook context
hookCtx := &HookContext{
Context: client.ctx,
Handler: nil, // Not used for MQTT
Message: msg,
Schema: msg.Schema,
Entity: msg.Entity,
Options: msg.Options,
Metadata: map[string]interface{}{
"mqtt_client": client,
},
}
// Execute before hook
if err := h.hooks.Execute(BeforeSubscribe, hookCtx); err != nil {
logger.Error("[MQTTSpec] BeforeSubscribe hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Create subscription
sub := h.subscriptionManager.Subscribe(subID, client.ID, msg.Schema, msg.Entity, msg.Options)
client.AddSubscription(sub)
// Execute after hook
_ = h.hooks.Execute(AfterSubscribe, hookCtx)
// Send response
h.sendResponse(client.ID, msg.ID, map[string]interface{}{
"subscription_id": subID,
"schema": msg.Schema,
"entity": msg.Entity,
"notify_topic": h.getNotifyTopic(client.ID, subID),
}, nil)
logger.Info("[MQTTSpec] Subscription created: %s for %s.%s (client: %s)", subID, msg.Schema, msg.Entity, client.ID)
}
// handleUnsubscribe removes a subscription
func (h *Handler) handleUnsubscribe(client *Client, msg *Message) {
subID := msg.SubscriptionID
if subID == "" {
h.sendError(client.ID, msg.ID, "invalid_subscription", "Subscription ID is required")
return
}
// Create hook context
hookCtx := &HookContext{
Context: client.ctx,
Handler: nil, // Not used for MQTT
Message: msg,
Metadata: map[string]interface{}{
"mqtt_client": client,
},
}
// Execute before hook
if err := h.hooks.Execute(BeforeUnsubscribe, hookCtx); err != nil {
logger.Error("[MQTTSpec] BeforeUnsubscribe hook failed: %v", err)
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
return
}
// Remove subscription
h.subscriptionManager.Unsubscribe(subID)
client.RemoveSubscription(subID)
// Execute after hook
_ = h.hooks.Execute(AfterUnsubscribe, hookCtx)
// Send response
h.sendResponse(client.ID, msg.ID, map[string]interface{}{
"unsubscribed": true,
"subscription_id": subID,
}, nil)
logger.Info("[MQTTSpec] Subscription removed: %s (client: %s)", subID, client.ID)
}
// handlePing responds to ping messages
func (h *Handler) handlePing(client *Client, msg *Message) {
pong := &ResponseMessage{
ID: msg.ID,
Type: MessageTypePong,
Success: true,
}
payload, _ := json.Marshal(pong)
topic := h.getResponseTopic(client.ID)
_ = h.broker.Publish(topic, h.config.QoS.Response, payload)
}
// notifySubscribers sends notifications to subscribers
func (h *Handler) notifySubscribers(schema, entity string, operation OperationType, data interface{}) {
subscriptions := h.subscriptionManager.GetSubscriptionsByEntity(schema, entity)
if len(subscriptions) == 0 {
return
}
for _, sub := range subscriptions {
// Check if data matches subscription filters
if !sub.MatchesFilters(data) {
continue
}
// Get client
client, exists := h.clientManager.GetClient(sub.ConnectionID)
if !exists {
continue
}
// Create notification message
notification := NewNotificationMessage(sub.ID, operation, schema, entity, data)
payload, err := json.Marshal(notification)
if err != nil {
logger.Error("[MQTTSpec] Failed to marshal notification: %v", err)
continue
}
// Publish to client's notify topic
topic := h.getNotifyTopic(client.ID, sub.ID)
if err := h.broker.Publish(topic, h.config.QoS.Notification, payload); err != nil {
logger.Error("[MQTTSpec] Failed to publish notification to %s: %v", topic, err)
}
}
}
// Response helpers
// sendResponse publishes a response message
func (h *Handler) sendResponse(clientID, msgID string, data interface{}, metadata map[string]interface{}) {
resp := NewResponseMessage(msgID, true, data)
resp.Metadata = metadata
payload, err := json.Marshal(resp)
if err != nil {
logger.Error("[MQTTSpec] Failed to marshal response: %v", err)
return
}
topic := h.getResponseTopic(clientID)
if err := h.broker.Publish(topic, h.config.QoS.Response, payload); err != nil {
logger.Error("[MQTTSpec] Failed to publish response to %s: %v", topic, err)
}
}
// sendError publishes an error response
func (h *Handler) sendError(clientID, msgID, code, message string) {
errResp := NewErrorResponse(msgID, code, message)
payload, _ := json.Marshal(errResp)
topic := h.getResponseTopic(clientID)
_ = h.broker.Publish(topic, h.config.QoS.Response, payload)
}
// Topic helpers
func (h *Handler) getRequestTopic(clientID string) string {
return fmt.Sprintf("%s/%s/request", h.config.Topics.Prefix, clientID)
}
func (h *Handler) getResponseTopic(clientID string) string {
return fmt.Sprintf("%s/%s/response", h.config.Topics.Prefix, clientID)
}
func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
return fmt.Sprintf("%s/%s/notify/%s", h.config.Topics.Prefix, clientID, subscriptionID)
}
// Database operation helpers (adapted from websocketspec)
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
// Use entity as table name
tableName := entity
if schema != "" {
tableName = schema + "." + tableName
}
return tableName
}
// readByID reads a single record by ID
func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Add ID filter
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
// Apply columns
if hookCtx.Options != nil && len(hookCtx.Options.Columns) > 0 {
query = query.Column(hookCtx.Options.Columns...)
}
// Apply preloads (simplified)
if hookCtx.Options != nil {
for i := range hookCtx.Options.Preload {
query = query.PreloadRelation(hookCtx.Options.Preload[i].Relation)
}
}
// Execute query
if err := query.ScanModel(hookCtx.Context); err != nil {
return nil, fmt.Errorf("failed to read record: %w", err)
}
return hookCtx.ModelPtr, nil
}
// readMultiple reads multiple records
func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata map[string]interface{}, err error) {
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Apply options
if hookCtx.Options != nil {
// Apply filters
for _, filter := range hookCtx.Options.Filters {
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
}
// Apply sorting
for _, sort := range hookCtx.Options.Sort {
direction := "ASC"
if sort.Direction == "desc" {
direction = "DESC"
}
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
// Apply limit and offset
if hookCtx.Options.Limit != nil {
query = query.Limit(*hookCtx.Options.Limit)
}
if hookCtx.Options.Offset != nil {
query = query.Offset(*hookCtx.Options.Offset)
}
// Apply preloads
for i := range hookCtx.Options.Preload {
query = query.PreloadRelation(hookCtx.Options.Preload[i].Relation)
}
// Apply columns
if len(hookCtx.Options.Columns) > 0 {
query = query.Column(hookCtx.Options.Columns...)
}
}
// Execute query
if err := query.ScanModel(hookCtx.Context); err != nil {
return nil, nil, fmt.Errorf("failed to read records: %w", err)
}
// Get count
metadata = make(map[string]interface{})
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
if hookCtx.Options != nil {
for _, filter := range hookCtx.Options.Filters {
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
}
}
count, _ := countQuery.Count(hookCtx.Context)
metadata["total"] = count
metadata["count"] = reflection.Len(hookCtx.ModelPtr)
return hookCtx.ModelPtr, metadata, nil
}
// create creates a new record
func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
// Marshal and unmarshal data into model
dataBytes, err := json.Marshal(hookCtx.Data)
if err != nil {
return nil, fmt.Errorf("failed to marshal data: %w", err)
}
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
}
// Insert record
query := h.db.NewInsert().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
if _, err := query.Exec(hookCtx.Context); err != nil {
return nil, fmt.Errorf("failed to create record: %w", err)
}
return hookCtx.ModelPtr, nil
}
// update updates an existing record
func (h *Handler) update(hookCtx *HookContext) (interface{}, error) {
// Marshal and unmarshal data into model
dataBytes, err := json.Marshal(hookCtx.Data)
if err != nil {
return nil, fmt.Errorf("failed to marshal data: %w", err)
}
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
}
// Update record
query := h.db.NewUpdate().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Add ID filter
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
if _, err := query.Exec(hookCtx.Context); err != nil {
return nil, fmt.Errorf("failed to update record: %w", err)
}
// Fetch updated record
return h.readByID(hookCtx)
}
// delete deletes a record
func (h *Handler) delete(hookCtx *HookContext) error {
query := h.db.NewDelete().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Add ID filter
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
if _, err := query.Exec(hookCtx.Context); err != nil {
return fmt.Errorf("failed to delete record: %w", err)
}
return nil
}
// getMetadata returns schema metadata for an entity
func (h *Handler) getMetadata(hookCtx *HookContext) (interface{}, error) {
metadata := make(map[string]interface{})
metadata["schema"] = hookCtx.Schema
metadata["entity"] = hookCtx.Entity
metadata["table_name"] = hookCtx.TableName
// Get fields from model using reflection
columns := reflection.GetModelColumns(hookCtx.Model)
metadata["columns"] = columns
metadata["primary_key"] = reflection.GetPrimaryKeyName(hookCtx.Model)
return metadata, nil
}
// getOperatorSQL converts filter operator to SQL operator
func (h *Handler) getOperatorSQL(operator string) string {
switch operator {
case "eq":
return "="
case "neq":
return "!="
case "gt":
return ">"
case "gte":
return ">="
case "lt":
return "<"
case "lte":
return "<="
case "like":
return "LIKE"
case "ilike":
return "ILIKE"
case "in":
return "IN"
default:
return "="
}
}

View File

@@ -0,0 +1,743 @@
package mqttspec
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// Test model
type TestUser struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
TenantID string `json:"tenant_id"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (TestUser) TableName() string {
return "users"
}
// setupTestHandler creates a handler with in-memory SQLite database
func setupTestHandler(t *testing.T) (*Handler, *gorm.DB) {
// Create in-memory SQLite database
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(t, err)
// Auto-migrate test model
err = db.AutoMigrate(&TestUser{})
require.NoError(t, err)
// Create handler
config := DefaultConfig()
config.Broker.Port = 21883 // Use different port for handler tests
adapter := database.NewGormAdapter(db)
registry := modelregistry.NewModelRegistry()
registry.RegisterModel("public.users", &TestUser{})
handler, err := NewHandlerWithDatabase(adapter, registry, WithEmbeddedBroker(config.Broker))
require.NoError(t, err)
return handler, db
}
func TestNewHandler(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
assert.NotNil(t, handler)
assert.NotNil(t, handler.db)
assert.NotNil(t, handler.registry)
assert.NotNil(t, handler.hooks)
assert.NotNil(t, handler.clientManager)
assert.NotNil(t, handler.subscriptionManager)
assert.NotNil(t, handler.broker)
assert.NotNil(t, handler.config)
}
func TestHandler_StartShutdown(t *testing.T) {
handler, _ := setupTestHandler(t)
// Start handler
err := handler.Start()
require.NoError(t, err)
assert.True(t, handler.started)
// Shutdown handler
err = handler.Shutdown()
require.NoError(t, err)
assert.False(t, handler.started)
}
func TestHandler_HandleRead_Single(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Insert test data
user := &TestUser{
ID: 1,
Name: "John Doe",
Email: "john@example.com",
Status: "active",
}
db.Create(user)
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create read request message
msg := &Message{
ID: "msg-1",
Type: MessageTypeRequest,
Operation: OperationRead,
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{},
}
// Create hook context
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
ID: "1",
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
// Handle read
handler.handleRead(client, msg, hookCtx)
// Note: In a full integration test, we would verify the response was published
// to the correct MQTT topic. Here we're just testing that the handler doesn't error.
}
func TestHandler_HandleRead_Multiple(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Insert test data
users := []TestUser{
{ID: 1, Name: "User 1", Email: "user1@example.com", Status: "active"},
{ID: 2, Name: "User 2", Email: "user2@example.com", Status: "active"},
{ID: 3, Name: "User 3", Email: "user3@example.com", Status: "inactive"},
}
for _, user := range users {
db.Create(&user)
}
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create read request with filter
msg := &Message{
ID: "msg-2",
Type: MessageTypeRequest,
Operation: OperationRead,
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
},
}
// Create hook context
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
// Handle read
handler.handleRead(client, msg, hookCtx)
}
func TestHandler_HandleCreate(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Start handler to initialize broker
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create request data
newUser := map[string]interface{}{
"name": "New User",
"email": "new@example.com",
"status": "active",
}
// Create create request message
msg := &Message{
ID: "msg-3",
Type: MessageTypeRequest,
Operation: OperationCreate,
Schema: "public",
Entity: "users",
Data: newUser,
Options: &common.RequestOptions{},
}
// Create hook context
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
Data: newUser,
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
// Handle create
handler.handleCreate(client, msg, hookCtx)
// Verify user was created in database
var user TestUser
result := db.Where("email = ?", "new@example.com").First(&user)
assert.NoError(t, result.Error)
assert.Equal(t, "New User", user.Name)
}
func TestHandler_HandleUpdate(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Insert test data
user := &TestUser{
ID: 1,
Name: "Original Name",
Email: "original@example.com",
Status: "active",
}
db.Create(user)
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Update data
updateData := map[string]interface{}{
"name": "Updated Name",
}
// Create update request message
msg := &Message{
ID: "msg-4",
Type: MessageTypeRequest,
Operation: OperationUpdate,
Schema: "public",
Entity: "users",
Data: updateData,
Options: &common.RequestOptions{},
}
// Create hook context
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
ID: "1",
Data: updateData,
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
// Handle update
handler.handleUpdate(client, msg, hookCtx)
// Verify user was updated
var updatedUser TestUser
db.First(&updatedUser, 1)
assert.Equal(t, "Updated Name", updatedUser.Name)
}
func TestHandler_HandleDelete(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Insert test data
user := &TestUser{
ID: 1,
Name: "To Delete",
Email: "delete@example.com",
Status: "active",
}
db.Create(user)
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create delete request message
msg := &Message{
ID: "msg-5",
Type: MessageTypeRequest,
Operation: OperationDelete,
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{},
}
// Create hook context
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
ID: "1",
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
// Handle delete
handler.handleDelete(client, msg, hookCtx)
// Verify user was deleted
var deletedUser TestUser
result := db.First(&deletedUser, 1)
assert.Error(t, result.Error)
assert.Equal(t, gorm.ErrRecordNotFound, result.Error)
}
func TestHandler_HandleSubscribe(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create subscribe message
msg := &Message{
ID: "msg-6",
Type: MessageTypeSubscription,
Operation: OperationSubscribe,
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
},
}
// Handle subscribe
handler.handleSubscribe(client, msg)
// Verify subscription was created
subscriptions := handler.subscriptionManager.GetSubscriptionsByEntity("public", "users")
assert.Len(t, subscriptions, 1)
assert.Equal(t, client.ID, subscriptions[0].ConnectionID)
}
func TestHandler_HandleUnsubscribe(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create subscription using Subscribe method
sub := handler.subscriptionManager.Subscribe("sub-1", client.ID, "public", "users", &common.RequestOptions{})
client.AddSubscription(sub)
// Create unsubscribe message with subscription ID in Data
msg := &Message{
ID: "msg-7",
Type: MessageTypeSubscription,
Operation: OperationUnsubscribe,
Data: map[string]interface{}{"subscription_id": "sub-1"},
Options: &common.RequestOptions{},
}
// Handle unsubscribe
handler.handleUnsubscribe(client, msg)
// Verify subscription was removed
_, exists := handler.subscriptionManager.GetSubscription("sub-1")
assert.False(t, exists)
}
func TestHandler_NotifySubscribers(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Create mock clients
client1 := handler.clientManager.Register("client-1", "user1", handler)
client2 := handler.clientManager.Register("client-2", "user2", handler)
// Create subscriptions
opts1 := &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
}
sub1 := handler.subscriptionManager.Subscribe("sub-1", client1.ID, "public", "users", opts1)
client1.AddSubscription(sub1)
opts2 := &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "inactive"},
},
}
sub2 := handler.subscriptionManager.Subscribe("sub-2", client2.ID, "public", "users", opts2)
client2.AddSubscription(sub2)
// Notify subscribers with active user
activeUser := map[string]interface{}{
"id": 1,
"name": "Active User",
"status": "active",
}
// This should notify sub-1 only
handler.notifySubscribers("public", "users", OperationCreate, activeUser)
// Note: In a full integration test, we would verify that the notification
// was published to the correct MQTT topic. Here we're just testing that
// the handler doesn't error and finds the correct subscriptions.
}
func TestHandler_Hooks_BeforeRead(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Insert test data with different tenants
users := []TestUser{
{ID: 1, Name: "User 1", TenantID: "tenant-a", Status: "active"},
{ID: 2, Name: "User 2", TenantID: "tenant-b", Status: "active"},
{ID: 3, Name: "User 3", TenantID: "tenant-a", Status: "active"},
}
for _, user := range users {
db.Create(&user)
}
// Register hook to filter by tenant
handler.Hooks().Register(BeforeRead, func(ctx *HookContext) error {
// Auto-inject tenant filter
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
Column: "tenant_id",
Operator: "eq",
Value: "tenant-a",
})
return nil
})
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create read request (no tenant filter)
msg := &Message{
ID: "msg-8",
Type: MessageTypeRequest,
Operation: OperationRead,
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{},
}
// Create hook context
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
// Handle read
handler.handleRead(client, msg, hookCtx)
// The hook should have injected the tenant filter
// In a full test, we would verify only tenant-a users were returned
}
func TestHandler_Hooks_BeforeCreate(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Register hook to set default values
handler.Hooks().Register(BeforeCreate, func(ctx *HookContext) error {
// Auto-set tenant_id
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
dataMap["tenant_id"] = "auto-tenant"
}
return nil
})
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Create mock client
client := NewClient("test-client", "test-user", handler)
// Create user without tenant_id
newUser := map[string]interface{}{
"name": "Test User",
"email": "test@example.com",
"status": "active",
}
msg := &Message{
ID: "msg-9",
Type: MessageTypeRequest,
Operation: OperationCreate,
Schema: "public",
Entity: "users",
Data: newUser,
Options: &common.RequestOptions{},
}
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
Data: newUser,
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
// Handle create
handler.handleCreate(client, msg, hookCtx)
// Verify tenant_id was auto-set
var user TestUser
db.Where("email = ?", "test@example.com").First(&user)
assert.Equal(t, "auto-tenant", user.TenantID)
}
func TestHandler_ConcurrentRequests(t *testing.T) {
handler, db := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Create multiple clients
var wg sync.WaitGroup
numClients := 10
for i := 0; i < numClients; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
client := NewClient(fmt.Sprintf("client-%d", id), fmt.Sprintf("user%d", id), handler)
// Create user
newUser := map[string]interface{}{
"name": fmt.Sprintf("User %d", id),
"email": fmt.Sprintf("user%d@example.com", id),
"status": "active",
}
msg := &Message{
ID: fmt.Sprintf("msg-%d", id),
Type: MessageTypeRequest,
Operation: OperationCreate,
Schema: "public",
Entity: "users",
Data: newUser,
Options: &common.RequestOptions{},
}
hookCtx := &HookContext{
Context: context.Background(),
Handler: nil,
Schema: "public",
Entity: "users",
Data: newUser,
Options: msg.Options,
Metadata: map[string]interface{}{"mqtt_client": client},
}
handler.handleCreate(client, msg, hookCtx)
}(i)
}
wg.Wait()
// Verify all users were created
var count int64
db.Model(&TestUser{}).Count(&count)
assert.Equal(t, int64(numClients), count)
}
func TestHandler_TopicHelpers(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
clientID := "test-client"
subscriptionID := "sub-123"
requestTopic := handler.getRequestTopic(clientID)
assert.Equal(t, "spec/test-client/request", requestTopic)
responseTopic := handler.getResponseTopic(clientID)
assert.Equal(t, "spec/test-client/response", responseTopic)
notifyTopic := handler.getNotifyTopic(clientID, subscriptionID)
assert.Equal(t, "spec/test-client/notify/sub-123", notifyTopic)
}
func TestHandler_SendResponse(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Test data
clientID := "test-client"
msgID := "msg-123"
data := map[string]interface{}{"id": 1, "name": "Test"}
metadata := map[string]interface{}{"total": 1}
// Send response (should not error)
handler.sendResponse(clientID, msgID, data, metadata)
}
func TestHandler_SendError(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Test error
clientID := "test-client"
msgID := "msg-123"
code := "test_error"
message := "Test error message"
// Send error (should not error)
handler.sendError(clientID, msgID, code, message)
}
// extractClientID extracts the client ID from a topic like spec/{client_id}/request
func extractClientID(topic string) string {
parts := strings.Split(topic, "/")
if len(parts) >= 2 {
return parts[len(parts)-2]
}
return ""
}
func TestHandler_ExtractClientID(t *testing.T) {
tests := []struct {
topic string
expected string
}{
{"spec/client-123/request", "client-123"},
{"spec/abc-xyz/request", "abc-xyz"},
{"spec/test/request", "test"},
}
for _, tt := range tests {
result := extractClientID(tt.topic)
assert.Equal(t, tt.expected, result, "topic: %s", tt.topic)
}
}
func TestHandler_HandleIncomingMessage_InvalidJSON(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Invalid JSON payload
payload := []byte("{invalid json")
// Should not panic
handler.handleIncomingMessage("spec/test-client/request", payload)
}
func TestHandler_HandleIncomingMessage_ValidMessage(t *testing.T) {
handler, _ := setupTestHandler(t)
defer handler.Shutdown()
// Start handler
err := handler.Start()
require.NoError(t, err)
defer handler.Shutdown()
// Valid message
msg := &Message{
ID: "msg-1",
Type: MessageTypeRequest,
Operation: OperationRead,
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{},
}
payload, _ := json.Marshal(msg)
// Should not panic or error
handler.handleIncomingMessage("spec/test-client/request", payload)
}

51
pkg/mqttspec/hooks.go Normal file
View File

@@ -0,0 +1,51 @@
package mqttspec
import (
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
)
// Hook types - aliases to websocketspec for lifecycle hook consistency
type (
// HookType defines the type of lifecycle hook
HookType = websocketspec.HookType
// HookFunc is a function that executes during a lifecycle hook
HookFunc = websocketspec.HookFunc
// HookContext contains all context for hook execution
// Note: For MQTT, the Client is stored in Metadata["mqtt_client"]
HookContext = websocketspec.HookContext
// HookRegistry manages all registered hooks
HookRegistry = websocketspec.HookRegistry
)
// Hook type constants - all 12 lifecycle hooks
const (
// CRUD operation hooks
BeforeRead = websocketspec.BeforeRead
AfterRead = websocketspec.AfterRead
BeforeCreate = websocketspec.BeforeCreate
AfterCreate = websocketspec.AfterCreate
BeforeUpdate = websocketspec.BeforeUpdate
AfterUpdate = websocketspec.AfterUpdate
BeforeDelete = websocketspec.BeforeDelete
AfterDelete = websocketspec.AfterDelete
// Subscription hooks
BeforeSubscribe = websocketspec.BeforeSubscribe
AfterSubscribe = websocketspec.AfterSubscribe
BeforeUnsubscribe = websocketspec.BeforeUnsubscribe
AfterUnsubscribe = websocketspec.AfterUnsubscribe
// Connection hooks
BeforeConnect = websocketspec.BeforeConnect
AfterConnect = websocketspec.AfterConnect
BeforeDisconnect = websocketspec.BeforeDisconnect
AfterDisconnect = websocketspec.AfterDisconnect
)
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return websocketspec.NewHookRegistry()
}

63
pkg/mqttspec/message.go Normal file
View File

@@ -0,0 +1,63 @@
package mqttspec
import (
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
)
// Message types - aliases to websocketspec for protocol consistency
type (
// Message represents an MQTT message (identical to WebSocket message protocol)
Message = websocketspec.Message
// MessageType defines the type of message
MessageType = websocketspec.MessageType
// OperationType defines the operation to perform
OperationType = websocketspec.OperationType
// ResponseMessage is sent back to clients after processing requests
ResponseMessage = websocketspec.ResponseMessage
// NotificationMessage is sent to subscribers when data changes
NotificationMessage = websocketspec.NotificationMessage
// ErrorInfo contains error details
ErrorInfo = websocketspec.ErrorInfo
)
// Message type constants
const (
MessageTypeRequest = websocketspec.MessageTypeRequest
MessageTypeResponse = websocketspec.MessageTypeResponse
MessageTypeNotification = websocketspec.MessageTypeNotification
MessageTypeSubscription = websocketspec.MessageTypeSubscription
MessageTypeError = websocketspec.MessageTypeError
MessageTypePing = websocketspec.MessageTypePing
MessageTypePong = websocketspec.MessageTypePong
)
// Operation type constants
const (
OperationRead = websocketspec.OperationRead
OperationCreate = websocketspec.OperationCreate
OperationUpdate = websocketspec.OperationUpdate
OperationDelete = websocketspec.OperationDelete
OperationSubscribe = websocketspec.OperationSubscribe
OperationUnsubscribe = websocketspec.OperationUnsubscribe
OperationMeta = websocketspec.OperationMeta
)
// Helper functions from websocketspec
var (
// NewResponseMessage creates a new response message
NewResponseMessage = websocketspec.NewResponseMessage
// NewErrorResponse creates an error response
NewErrorResponse = websocketspec.NewErrorResponse
// NewNotificationMessage creates a notification message
NewNotificationMessage = websocketspec.NewNotificationMessage
// ParseMessage parses a JSON message into a Message struct
ParseMessage = websocketspec.ParseMessage
)

104
pkg/mqttspec/mqttspec.go Normal file
View File

@@ -0,0 +1,104 @@
package mqttspec
import (
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"gorm.io/gorm"
"github.com/uptrace/bun"
)
// NewHandlerWithGORM creates an MQTT handler with GORM database adapter
func NewHandlerWithGORM(db *gorm.DB, opts ...Option) (*Handler, error) {
adapter := database.NewGormAdapter(db)
registry := modelregistry.NewModelRegistry()
return NewHandlerWithDatabase(adapter, registry, opts...)
}
// NewHandlerWithBun creates an MQTT handler with Bun database adapter
func NewHandlerWithBun(db *bun.DB, opts ...Option) (*Handler, error) {
adapter := database.NewBunAdapter(db)
registry := modelregistry.NewModelRegistry()
return NewHandlerWithDatabase(adapter, registry, opts...)
}
// NewHandlerWithDatabase creates an MQTT handler with a custom database adapter
func NewHandlerWithDatabase(db common.Database, registry common.ModelRegistry, opts ...Option) (*Handler, error) {
// Start with default configuration
config := DefaultConfig()
// Create handler with basic initialization
// Note: broker and clientManager will be initialized after options are applied
handler, err := NewHandler(db, registry, config)
if err != nil {
return nil, err
}
// Apply functional options
for _, opt := range opts {
if err := opt(handler); err != nil {
return nil, err
}
}
// Reinitialize broker based on final config (after options)
if handler.config.BrokerMode == BrokerModeEmbedded {
handler.broker = NewEmbeddedBroker(handler.config.Broker, handler.clientManager)
} else {
handler.broker = NewExternalBrokerClient(handler.config.ExternalBroker, handler.clientManager)
}
// Set handler reference in broker
handler.broker.SetHandler(handler)
return handler, nil
}
// Option is a functional option for configuring the handler
type Option func(*Handler) error
// WithEmbeddedBroker configures the handler to use an embedded MQTT broker
func WithEmbeddedBroker(config BrokerConfig) Option {
return func(h *Handler) error {
h.config.BrokerMode = BrokerModeEmbedded
h.config.Broker = config
return nil
}
}
// WithExternalBroker configures the handler to connect to an external MQTT broker
func WithExternalBroker(config ExternalBrokerConfig) Option {
return func(h *Handler) error {
h.config.BrokerMode = BrokerModeExternal
h.config.ExternalBroker = config
return nil
}
}
// WithHooks sets a pre-configured hook registry
func WithHooks(hooks *HookRegistry) Option {
return func(h *Handler) error {
h.hooks = hooks
return nil
}
}
// WithTopicPrefix sets a custom topic prefix (default: "spec")
func WithTopicPrefix(prefix string) Option {
return func(h *Handler) error {
h.config.Topics.Prefix = prefix
return nil
}
}
// WithQoS sets custom QoS levels for different message types
func WithQoS(request, response, notification byte) Option {
return func(h *Handler) error {
h.config.QoS.Request = request
h.config.QoS.Response = response
h.config.QoS.Notification = notification
return nil
}
}

View File

@@ -0,0 +1,21 @@
package mqttspec
import (
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
)
// Subscription types - aliases to websocketspec for subscription management
type (
// Subscription represents an active subscription to entity changes
// The key difference for MQTT: notifications are delivered via MQTT publish
// to spec/{client_id}/notify/{subscription_id} instead of WebSocket send
Subscription = websocketspec.Subscription
// SubscriptionManager manages all active subscriptions
SubscriptionManager = websocketspec.SubscriptionManager
)
// NewSubscriptionManager creates a new subscription manager
func NewSubscriptionManager() *SubscriptionManager {
return websocketspec.NewSubscriptionManager()
}

View File

@@ -273,25 +273,151 @@ handler.SetOpenAPIGenerator(func() (string, error) {
})
```
## Using with Swagger UI
## Using the Built-in UI Handler
You can serve the generated OpenAPI spec with Swagger UI:
The package includes a built-in UI handler that serves popular OpenAPI visualization tools. No need to download or manage static files - everything is served from CDN.
### Quick Start
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/openapi"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Setup your API routes and OpenAPI generator...
// (see examples above)
// Add the UI handler - defaults to Swagger UI
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
SpecURL: "/openapi",
Title: "My API Documentation",
})
// Now visit http://localhost:8080/docs
http.ListenAndServe(":8080", router)
}
```
### Supported UI Frameworks
The handler supports four popular OpenAPI UI frameworks:
#### 1. Swagger UI (Default)
The most widely used OpenAPI UI with excellent compatibility and features.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
Theme: "dark", // optional: "light" or "dark"
})
```
#### 2. RapiDoc
Modern, customizable, and feature-rich OpenAPI UI.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.RapiDoc,
Theme: "dark",
})
```
#### 3. Redoc
Clean, responsive documentation with great UX.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.Redoc,
})
```
#### 4. Scalar
Modern and sleek OpenAPI documentation.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.Scalar,
Theme: "dark",
})
```
### Configuration Options
```go
type UIConfig struct {
UIType UIType // SwaggerUI, RapiDoc, Redoc, or Scalar
SpecURL string // URL to OpenAPI spec (default: "/openapi")
Title string // Page title (default: "API Documentation")
FaviconURL string // Custom favicon URL (optional)
CustomCSS string // Custom CSS to inject (optional)
Theme string // "light" or "dark" (support varies by UI)
}
```
### Custom Styling Example
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
Title: "Acme Corp API",
CustomCSS: `
.swagger-ui .topbar {
background-color: #1976d2;
}
.swagger-ui .info .title {
color: #1976d2;
}
`,
})
```
### Using Multiple UIs
You can serve different UIs at different paths:
```go
// Swagger UI at /docs
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
})
// Redoc at /redoc
openapi.SetupUIRoute(router, "/redoc", openapi.UIConfig{
UIType: openapi.Redoc,
})
// RapiDoc at /api-docs
openapi.SetupUIRoute(router, "/api-docs", openapi.UIConfig{
UIType: openapi.RapiDoc,
})
```
### Manual Handler Usage
If you need more control, use the handler directly:
```go
handler := openapi.UIHandler(openapi.UIConfig{
UIType: openapi.SwaggerUI,
SpecURL: "/api/openapi.json",
})
router.Handle("/documentation", handler)
```
## Using with External Swagger UI
Alternatively, you can use an external Swagger UI instance:
1. Get the spec from `/openapi`
2. Load it in Swagger UI at `https://petstore.swagger.io/`
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
Example with self-hosted Swagger UI:
```go
// Serve Swagger UI static files
router.PathPrefix("/swagger/").Handler(
http.StripPrefix("/swagger/", http.FileServer(http.Dir("./swagger-ui"))),
)
// Configure Swagger UI to use /openapi
```
## Testing
You can test the OpenAPI endpoint:

View File

@@ -183,6 +183,69 @@ func ExampleWithFuncSpec() {
_ = generatorFunc
}
// ExampleWithUIHandler shows how to serve OpenAPI documentation with a web UI
func ExampleWithUIHandler(db *gorm.DB) {
// Create handler and configure OpenAPI generator
handler := restheadspec.NewHandlerWithGORM(db)
registry := modelregistry.NewModelRegistry()
handler.SetOpenAPIGenerator(func() (string, error) {
generator := NewGenerator(GeneratorConfig{
Title: "My API",
Description: "API documentation with interactive UI",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: registry,
IncludeRestheadSpec: true,
})
return generator.GenerateJSON()
})
// Setup routes
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler, nil)
// Add UI handlers for different frameworks
// Swagger UI at /docs (most popular)
SetupUIRoute(router, "/docs", UIConfig{
UIType: SwaggerUI,
SpecURL: "/openapi",
Title: "My API - Swagger UI",
Theme: "light",
})
// RapiDoc at /rapidoc (modern alternative)
SetupUIRoute(router, "/rapidoc", UIConfig{
UIType: RapiDoc,
SpecURL: "/openapi",
Title: "My API - RapiDoc",
})
// Redoc at /redoc (clean and responsive)
SetupUIRoute(router, "/redoc", UIConfig{
UIType: Redoc,
SpecURL: "/openapi",
Title: "My API - Redoc",
})
// Scalar at /scalar (modern and sleek)
SetupUIRoute(router, "/scalar", UIConfig{
UIType: Scalar,
SpecURL: "/openapi",
Title: "My API - Scalar",
Theme: "dark",
})
// Now you can access:
// http://localhost:8080/docs - Swagger UI
// http://localhost:8080/rapidoc - RapiDoc
// http://localhost:8080/redoc - Redoc
// http://localhost:8080/scalar - Scalar
// http://localhost:8080/openapi - Raw OpenAPI JSON
_ = router
}
// ExampleCustomization shows advanced customization options
func ExampleCustomization() {
// Create registry and register models with descriptions using struct tags

294
pkg/openapi/ui_handler.go Normal file
View File

@@ -0,0 +1,294 @@
package openapi
import (
"fmt"
"html/template"
"net/http"
"strings"
"github.com/gorilla/mux"
)
// UIType represents the type of OpenAPI UI to serve
type UIType string
const (
// SwaggerUI is the most popular OpenAPI UI
SwaggerUI UIType = "swagger-ui"
// RapiDoc is a modern, customizable OpenAPI UI
RapiDoc UIType = "rapidoc"
// Redoc is a clean, responsive OpenAPI UI
Redoc UIType = "redoc"
// Scalar is a modern and sleek OpenAPI UI
Scalar UIType = "scalar"
)
// UIConfig holds configuration for the OpenAPI UI handler
type UIConfig struct {
// UIType specifies which UI framework to use (default: SwaggerUI)
UIType UIType
// SpecURL is the URL to the OpenAPI spec JSON (default: "/openapi")
SpecURL string
// Title is the page title (default: "API Documentation")
Title string
// FaviconURL is the URL to the favicon (optional)
FaviconURL string
// CustomCSS allows injecting custom CSS (optional)
CustomCSS string
// Theme for the UI (light/dark, depends on UI type)
Theme string
}
// UIHandler creates an HTTP handler that serves an OpenAPI UI
func UIHandler(config UIConfig) http.HandlerFunc {
// Set defaults
if config.UIType == "" {
config.UIType = SwaggerUI
}
if config.SpecURL == "" {
config.SpecURL = "/openapi"
}
if config.Title == "" {
config.Title = "API Documentation"
}
if config.Theme == "" {
config.Theme = "light"
}
return func(w http.ResponseWriter, r *http.Request) {
var htmlContent string
var err error
switch config.UIType {
case SwaggerUI:
htmlContent, err = generateSwaggerUI(config)
case RapiDoc:
htmlContent, err = generateRapiDoc(config)
case Redoc:
htmlContent, err = generateRedoc(config)
case Scalar:
htmlContent, err = generateScalar(config)
default:
http.Error(w, "Unsupported UI type", http.StatusBadRequest)
return
}
if err != nil {
http.Error(w, fmt.Sprintf("Failed to generate UI: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, err = w.Write([]byte(htmlContent))
if err != nil {
http.Error(w, fmt.Sprintf("Failed to write response: %v", err), http.StatusInternalServerError)
return
}
}
}
// templateData wraps UIConfig to properly handle CSS in templates
type templateData struct {
UIConfig
SafeCustomCSS template.CSS
}
// generateSwaggerUI generates the HTML for Swagger UI
func generateSwaggerUI(config UIConfig) (string, error) {
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css">
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
<style>
html { box-sizing: border-box; overflow: -moz-scrollbars-vertical; overflow-y: scroll; }
*, *:before, *:after { box-sizing: inherit; }
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-standalone-preset.js"></script>
<script>
window.onload = function() {
const ui = SwaggerUIBundle({
url: "{{.SpecURL}}",
dom_id: '#swagger-ui',
deepLinking: true,
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIStandalonePreset
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
],
layout: "StandaloneLayout",
{{if eq .Theme "dark"}}
syntaxHighlight: {
activate: true,
theme: "monokai"
}
{{end}}
});
window.ui = ui;
};
</script>
</body>
</html>`
t, err := template.New("swagger").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// generateRapiDoc generates the HTML for RapiDoc
func generateRapiDoc(config UIConfig) (string, error) {
theme := "light"
if config.Theme == "dark" {
theme = "dark"
}
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
<script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
</head>
<body>
<rapi-doc
spec-url="{{.SpecURL}}"
theme="` + theme + `"
render-style="read"
show-header="true"
show-info="true"
allow-try="true"
allow-server-selection="true"
allow-authentication="true"
api-key-name="Authorization"
api-key-location="header"
></rapi-doc>
</body>
</html>`
t, err := template.New("rapidoc").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// generateRedoc generates the HTML for Redoc
func generateRedoc(config UIConfig) (string, error) {
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
<style>
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<redoc spec-url="{{.SpecURL}}" {{if eq .Theme "dark"}}theme='{"colors": {"primary": {"main": "#dd5522"}}}'{{end}}></redoc>
<script src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
</body>
</html>`
t, err := template.New("redoc").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// generateScalar generates the HTML for Scalar
func generateScalar(config UIConfig) (string, error) {
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
<style>
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<script id="api-reference" data-url="{{.SpecURL}}" {{if eq .Theme "dark"}}data-theme="dark"{{end}}></script>
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
</body>
</html>`
t, err := template.New("scalar").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// SetupUIRoute adds the OpenAPI UI route to a mux router
// This is a convenience function for the most common use case
func SetupUIRoute(router *mux.Router, path string, config UIConfig) {
router.Handle(path, UIHandler(config))
}

View File

@@ -0,0 +1,308 @@
package openapi
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
)
func TestUIHandler_SwaggerUI(t *testing.T) {
config := UIConfig{
UIType: SwaggerUI,
SpecURL: "/openapi",
Title: "Test API Docs",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for Swagger UI specific content
if !strings.Contains(body, "swagger-ui") {
t.Error("Expected Swagger UI content")
}
if !strings.Contains(body, "SwaggerUIBundle") {
t.Error("Expected SwaggerUIBundle script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
if !strings.Contains(body, "swagger-ui-dist") {
t.Error("Expected Swagger UI CDN link")
}
}
func TestUIHandler_RapiDoc(t *testing.T) {
config := UIConfig{
UIType: RapiDoc,
SpecURL: "/api/spec",
Title: "RapiDoc Test",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for RapiDoc specific content
if !strings.Contains(body, "rapi-doc") {
t.Error("Expected rapi-doc element")
}
if !strings.Contains(body, "rapidoc-min.js") {
t.Error("Expected RapiDoc script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
}
func TestUIHandler_Redoc(t *testing.T) {
config := UIConfig{
UIType: Redoc,
SpecURL: "/spec.json",
Title: "Redoc Test",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for Redoc specific content
if !strings.Contains(body, "<redoc") {
t.Error("Expected redoc element")
}
if !strings.Contains(body, "redoc.standalone.js") {
t.Error("Expected Redoc script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
}
func TestUIHandler_Scalar(t *testing.T) {
config := UIConfig{
UIType: Scalar,
SpecURL: "/openapi.json",
Title: "Scalar Test",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for Scalar specific content
if !strings.Contains(body, "api-reference") {
t.Error("Expected api-reference element")
}
if !strings.Contains(body, "@scalar/api-reference") {
t.Error("Expected Scalar script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
}
func TestUIHandler_DefaultValues(t *testing.T) {
// Test with empty config to check defaults
config := UIConfig{}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Should default to Swagger UI
if !strings.Contains(body, "swagger-ui") {
t.Error("Expected default to Swagger UI")
}
// Should default to /openapi spec URL
if !strings.Contains(body, "/openapi") {
t.Error("Expected default spec URL '/openapi'")
}
// Should default to "API Documentation" title
if !strings.Contains(body, "API Documentation") {
t.Error("Expected default title 'API Documentation'")
}
}
func TestUIHandler_CustomCSS(t *testing.T) {
customCSS := ".custom-class { color: red; }"
config := UIConfig{
UIType: SwaggerUI,
CustomCSS: customCSS,
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
body := w.Body.String()
if !strings.Contains(body, customCSS) {
t.Errorf("Expected custom CSS to be included. Body:\n%s", body)
}
}
func TestUIHandler_Favicon(t *testing.T) {
faviconURL := "https://example.com/favicon.ico"
config := UIConfig{
UIType: SwaggerUI,
FaviconURL: faviconURL,
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
body := w.Body.String()
if !strings.Contains(body, faviconURL) {
t.Error("Expected favicon URL to be included")
}
}
func TestUIHandler_DarkTheme(t *testing.T) {
config := UIConfig{
UIType: SwaggerUI,
Theme: "dark",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
body := w.Body.String()
// SwaggerUI uses monokai theme for dark mode
if !strings.Contains(body, "monokai") {
t.Error("Expected dark theme configuration for Swagger UI")
}
}
func TestUIHandler_InvalidUIType(t *testing.T) {
config := UIConfig{
UIType: "invalid-ui-type",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid UI type, got %d", resp.StatusCode)
}
}
func TestUIHandler_ContentType(t *testing.T) {
config := UIConfig{
UIType: SwaggerUI,
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
contentType := w.Header().Get("Content-Type")
if !strings.Contains(contentType, "text/html") {
t.Errorf("Expected Content-Type to contain 'text/html', got '%s'", contentType)
}
if !strings.Contains(contentType, "charset=utf-8") {
t.Errorf("Expected Content-Type to contain 'charset=utf-8', got '%s'", contentType)
}
}
func TestSetupUIRoute(t *testing.T) {
router := mux.NewRouter()
config := UIConfig{
UIType: SwaggerUI,
}
SetupUIRoute(router, "/api-docs", config)
// Test that the route was added and works
req := httptest.NewRequest("GET", "/api-docs", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Verify it returns HTML
body := w.Body.String()
if !strings.Contains(body, "swagger-ui") {
t.Error("Expected Swagger UI content")
}
}

703
pkg/resolvespec/README.md Normal file
View File

@@ -0,0 +1,703 @@
# ResolveSpec - Body-Based REST API
ResolveSpec provides a REST API where query options are passed in the JSON request body. This approach offers GraphQL-like flexibility while maintaining RESTful principles, making it ideal for complex queries and operations.
## Features
* **Body-Based Querying**: All query options passed via JSON request body
* **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
* **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
* **Offset Pagination**: Traditional limit/offset pagination support
* **Advanced Filtering**: Multiple operators, AND/OR logic, and custom SQL
* **Relationship Preloading**: Load related entities with custom column selection and filters
* **Recursive CRUD**: Automatically handle nested object graphs with foreign key resolution
* **Computed Columns**: Define virtual columns with SQL expressions
* **Database-Agnostic**: Works with GORM, Bun, or custom database adapters
* **Router-Agnostic**: Integrates with any HTTP router through standard interfaces
* **Type-Safe**: Strong type validation and conversion
## Quick Start
### Setup with GORM
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
import "github.com/gorilla/mux"
// Create handler
handler := resolvespec.NewHandlerWithGORM(db)
// IMPORTANT: Register models BEFORE setting up routes
handler.registry.RegisterModel("core.users", &User{})
handler.registry.RegisterModel("core.posts", &Post{})
// Setup routes
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
// Start server
http.ListenAndServe(":8080", router)
```
### Setup with Bun ORM
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
import "github.com/uptrace/bun"
// Create handler with Bun
handler := resolvespec.NewHandlerWithBun(bunDB)
// Register models
handler.registry.RegisterModel("core.users", &User{})
// Setup routes (same as GORM)
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
```
## Basic Usage
### Simple Read Request
```http
POST /core/users HTTP/1.1
Content-Type: application/json
```
### With Preloading
```http
POST /core/users HTTP/1.1
Content-Type: application/json
```
## Request Structure
### Request Format
```json
{
"operation": "read|create|update|delete",
"data": {
// For create/update operations
},
"options": {
"columns": [...],
"preload": [...],
"filters": [...],
"sort": [...],
"limit": number,
"offset": number,
"cursor_forward": "string",
"cursor_backward": "string",
"customOperators": [...],
"computedColumns": [...]
}
}
```
### Operations
| Operation | Description | Requires Data | Requires ID |
|-----------|-------------|---------------|-------------|
| `read` | Fetch records | No | Optional (single record) |
| `create` | Create new record(s) | Yes | No |
| `update` | Update existing record(s) | Yes | Yes (in URL) |
| `delete` | Delete record(s) | No | Yes (in URL) |
### Options Fields
| Field | Type | Description | Example |
|-------|------|-------------|---------|
| `columns` | `[]string` | Columns to select | `["id", "name", "email"]` |
| `preload` | `[]PreloadConfig` | Relations to load | See [Preloading](#preloading) |
| `filters` | `[]Filter` | Filter conditions | See [Filtering](#filtering) |
| `sort` | `[]Sort` | Sort criteria | `[{"column": "created_at", "direction": "desc"}]` |
| `limit` | `int` | Max records to return | `50` |
| `offset` | `int` | Number of records to skip | `100` |
| `cursor_forward` | `string` | Cursor for next page | `"12345"` |
| `cursor_backward` | `string` | Cursor for previous page | `"12300"` |
| `customOperators` | `[]CustomOperator` | Custom SQL conditions | See [Custom Operators](#custom-operators) |
| `computedColumns` | `[]ComputedColumn` | Virtual columns | See [Computed Columns](#computed-columns) |
## Filtering
### Available Operators
| Operator | Description | Example |
|----------|-------------|---------|
| `eq` | Equal | `{"column": "status", "operator": "eq", "value": "active"}` |
| `neq` | Not Equal | `{"column": "status", "operator": "neq", "value": "deleted"}` |
| `gt` | Greater Than | `{"column": "age", "operator": "gt", "value": 18}` |
| `gte` | Greater Than or Equal | `{"column": "age", "operator": "gte", "value": 18}` |
| `lt` | Less Than | `{"column": "price", "operator": "lt", "value": 100}` |
| `lte` | Less Than or Equal | `{"column": "price", "operator": "lte", "value": 100}` |
| `like` | LIKE pattern | `{"column": "name", "operator": "like", "value": "%john%"}` |
| `ilike` | Case-insensitive LIKE | `{"column": "email", "operator": "ilike", "value": "%@example.com"}` |
| `in` | IN clause | `{"column": "status", "operator": "in", "value": ["active", "pending"]}` |
| `contains` | Contains string | `{"column": "description", "operator": "contains", "value": "important"}` |
| `startswith` | Starts with string | `{"column": "name", "operator": "startswith", "value": "John"}` |
| `endswith` | Ends with string | `{"column": "email", "operator": "endswith", "value": "@example.com"}` |
| `between` | Between (exclusive) | `{"column": "age", "operator": "between", "value": [18, 65]}` |
| `betweeninclusive` | Between (inclusive) | `{"column": "price", "operator": "betweeninclusive", "value": [10, 100]}` |
| `empty` | IS NULL or empty | `{"column": "deleted_at", "operator": "empty"}` |
| `notempty` | IS NOT NULL | `{"column": "email", "operator": "notempty"}` |
### Complex Filtering Example
```json
{
"operation": "read",
"options": {
"filters": [
{
"column": "status",
"operator": "eq",
"value": "active"
},
{
"column": "age",
"operator": "gte",
"value": 18
},
{
"column": "email",
"operator": "ilike",
"value": "%@company.com"
}
]
}
}
```
## Preloading
Load related entities with custom configuration:
```json
{
"operation": "read",
"options": {
"columns": ["id", "name", "email"],
"preload": [
{
"relation": "posts",
"columns": ["id", "title", "created_at"],
"filters": [
{
"column": "status",
"operator": "eq",
"value": "published"
}
],
"sort": [
{
"column": "created_at",
"direction": "desc"
}
],
"limit": 5
},
{
"relation": "profile",
"columns": ["bio", "website"]
}
]
}
}
```
## Cursor Pagination
Efficient pagination for large datasets:
### First Request (No Cursor)
```json
{
"operation": "read",
"options": {
"sort": [
{
"column": "created_at",
"direction": "desc"
},
{
"column": "id",
"direction": "asc"
}
],
"limit": 50
}
}
```
### Next Page (Forward Cursor)
```json
{
"operation": "read",
"options": {
"sort": [
{
"column": "created_at",
"direction": "desc"
},
{
"column": "id",
"direction": "asc"
}
],
"limit": 50,
"cursor_forward": "12345"
}
}
```
### Previous Page (Backward Cursor)
```json
{
"operation": "read",
"options": {
"sort": [
{
"column": "created_at",
"direction": "desc"
},
{
"column": "id",
"direction": "asc"
}
],
"limit": 50,
"cursor_backward": "12300"
}
}
```
**Benefits over offset pagination**:
* Consistent results when data changes
* Better performance for large offsets
* Prevents "skipped" or duplicate records
* Works with complex sort expressions
## Recursive CRUD Operations
Automatically handle nested object graphs with intelligent foreign key resolution.
### Creating Nested Objects
```json
{
"operation": "create",
"data": {
"name": "John Doe",
"email": "john@example.com",
"posts": [
{
"title": "My First Post",
"content": "Hello World",
"tags": [
{"name": "tech"},
{"name": "programming"}
]
},
{
"title": "Second Post",
"content": "More content"
}
],
"profile": {
"bio": "Software Developer",
"website": "https://example.com"
}
}
}
```
### Per-Record Operation Control with `_request`
Control individual operations for each nested record:
```json
{
"operation": "update",
"data": {
"name": "John Updated",
"posts": [
{
"_request": "insert",
"title": "New Post",
"content": "Fresh content"
},
{
"_request": "update",
"id": 456,
"title": "Updated Post Title"
},
{
"_request": "delete",
"id": 789
}
]
}
}
```
**Supported `_request` values**:
* `insert` - Create a new related record
* `update` - Update an existing related record
* `delete` - Delete a related record
* `upsert` - Create if doesn't exist, update if exists
**How It Works**:
1. Automatic foreign key resolution - parent IDs propagate to children
2. Recursive processing - handles nested relationships at any depth
3. Transaction safety - all operations execute atomically
4. Relationship detection - automatically detects belongsTo, hasMany, hasOne, many2many
5. Flexible operations - mix create, update, and delete in one request
## Computed Columns
Define virtual columns using SQL expressions:
```json
{
"operation": "read",
"options": {
"columns": ["id", "first_name", "last_name"],
"computedColumns": [
{
"name": "full_name",
"expression": "CONCAT(first_name, ' ', last_name)"
},
{
"name": "age_years",
"expression": "EXTRACT(YEAR FROM AGE(birth_date))"
}
]
}
}
```
## Custom Operators
Add custom SQL conditions when needed:
```json
{
"operation": "read",
"options": {
"customOperators": [
{
"condition": "LOWER(email) LIKE ?",
"values": ["%@example.com"]
},
{
"condition": "created_at > NOW() - INTERVAL '7 days'"
}
]
}
}
```
## Lifecycle Hooks
Register hooks for all CRUD operations:
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// Create handler
handler := resolvespec.NewHandlerWithGORM(db)
// Register a before-read hook (e.g., for authorization)
handler.Hooks().Register(resolvespec.BeforeRead, func(ctx *resolvespec.HookContext) error {
// Check permissions
if !userHasPermission(ctx.Context, ctx.Entity) {
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
}
// Modify query options
if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
ctx.Options.Limit = ptr(100) // Enforce max limit
}
return nil
})
// Register an after-read hook (e.g., for data transformation)
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
// Transform or filter results
if users, ok := ctx.Result.([]User); ok {
for i := range users {
users[i].Email = maskEmail(users[i].Email)
}
}
return nil
})
// Register a before-create hook (e.g., for validation)
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
// Validate data
if user, ok := ctx.Data.(*User); ok {
if user.Email == "" {
return fmt.Errorf("email is required")
}
// Add timestamps
user.CreatedAt = time.Now()
}
return nil
})
```
**Available Hook Types**:
* `BeforeRead`, `AfterRead`
* `BeforeCreate`, `AfterCreate`
* `BeforeUpdate`, `AfterUpdate`
* `BeforeDelete`, `AfterDelete`
**HookContext** provides:
* `Context`: Request context
* `Handler`: Access to handler, database, and registry
* `Schema`, `Entity`, `TableName`: Request info
* `Model`: The registered model type
* `Options`: Parsed request options (filters, sorting, etc.)
* `ID`: Record ID (for single-record operations)
* `Data`: Request data (for create/update)
* `Result`: Operation result (for after hooks)
* `Writer`: Response writer (allows hooks to modify response)
## Model Registration
```go
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
Profile *Profile `json:"profile,omitempty" gorm:"foreignKey:UserID"`
}
type Post struct {
ID uint `json:"id" gorm:"primaryKey"`
UserID uint `json:"user_id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
}
// Schema.Table format
handler.registry.RegisterModel("core.users", &User{})
handler.registry.RegisterModel("core.posts", &Post{})
```
## Complete Example
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/gorilla/mux"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
}
type Post struct {
ID uint `json:"id" gorm:"primaryKey"`
UserID uint `json:"user_id"`
Title string `json:"title"`
Content string `json:"content"`
Status string `json:"status"`
}
func main() {
// Connect to database
db, err := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
if err != nil {
log.Fatal(err)
}
// Create handler
handler := resolvespec.NewHandlerWithGORM(db)
// Register models
handler.registry.RegisterModel("core.users", &User{})
handler.registry.RegisterModel("core.posts", &Post{})
// Add hooks
handler.Hooks().Register(resolvespec.BeforeRead, func(ctx *resolvespec.HookContext) error {
log.Printf("Reading %s", ctx.Entity)
return nil
})
// Setup routes
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
// Start server
log.Println("Server starting on :8080")
log.Fatal(http.ListenAndServe(":8080", router))
}
```
## Testing
ResolveSpec is designed for testability:
```go
import (
"bytes"
"encoding/json"
"net/http/httptest"
"testing"
)
func TestUserRead(t *testing.T) {
handler := resolvespec.NewHandlerWithGORM(testDB)
handler.registry.RegisterModel("core.users", &User{})
reqBody := map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"columns": []string{"id", "name"},
"limit": 10,
},
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/core/users", bytes.NewReader(body))
rec := httptest.NewRecorder()
// Test your handler...
}
```
## Router Integration
### Gorilla Mux
```go
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
```
### BunRouter
```go
router := bunrouter.New()
resolvespec.SetupBunRouterWithResolveSpec(router, handler)
```
### Custom Routers
```go
// Implement custom integration using common.Request and common.ResponseWriter
router.POST("/:schema/:entity", func(w http.ResponseWriter, r *http.Request) {
params := extractParams(r) // Your param extraction logic
reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
})
```
## Response Format
### Success Response
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 50,
"limit": 10,
"offset": 0
}
}
```
### Error Response
```json
{
"success": false,
"error": {
"code": "validation_error",
"message": "Invalid request",
"details": "..."
}
}
```
## See Also
* [Main README](../../README.md) - ResolveSpec overview
* [RestHeadSpec Package](../restheadspec/README.md) - Header-based API
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
## License
This package is part of ResolveSpec and is licensed under the MIT License.
```
## Response Format
### Success Response
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 50,
"limit": 10,
"offset": 0
}
}
```
### Error Response
```json
{
"success": false,
"error": {
"code": "validation_error",
"message": "Invalid request",
"details": "..."
}
}
```
## See Also
* [Main README](../../README.md) - ResolveSpec overview
* [RestHeadSpec Package](../restheadspec/README.md) - Header-based API
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
## License
This package is part of ResolveSpec and is licensed under the MIT License.

View File

@@ -0,0 +1,118 @@
package resolvespec
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// queryCacheKey represents the components used to build a cache key for query total count
type queryCacheKey struct {
TableName string `json:"table_name"`
Filters []common.FilterOption `json:"filters"`
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
CursorBackward string `json:"cursor_backward,omitempty"`
}
// cachedTotal represents a cached total count
type cachedTotal struct {
Total int `json:"total"`
}
// buildQueryCacheKey builds a cache key from query parameters for total count caching
func buildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
key := queryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
}
return hashString(string(jsonData))
}
// buildExtendedQueryCacheKey builds a cache key for extended query options with cursor pagination
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, cursorFwd, cursorBwd string) string {
key := queryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
CursorForward: cursorFwd,
CursorBackward: cursorBwd,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%s_%s",
tableName, filters, sort, customWhere, customOr, cursorFwd, cursorBwd))
}
return hashString(string(jsonData))
}
// hashString computes SHA256 hash of a string
func hashString(s string) string {
h := sha256.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func getQueryTotalCacheKey(hash string) string {
return fmt.Sprintf("query_total:%s", hash)
}
// buildCacheTags creates cache tags from schema and table name
func buildCacheTags(schema, tableName string) []string {
return []string{
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
}
}
// setQueryTotalCache stores a query total in the cache with schema and table tags
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
c := cache.GetDefaultCache()
cacheData := cachedTotal{Total: total}
tags := buildCacheTags(schema, tableName)
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
}
// invalidateCacheForTags removes all cached items matching the specified tags
func invalidateCacheForTags(ctx context.Context, tags []string) error {
c := cache.GetDefaultCache()
// Invalidate for each tag
for _, tag := range tags {
if err := c.DeleteByTag(ctx, tag); err != nil {
return err
}
}
return nil
}

View File

@@ -331,19 +331,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Use extended cache key if cursors are present
var cacheKeyHash string
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
cacheKeyHash = cache.BuildExtendedQueryCacheKey(
cacheKeyHash = buildExtendedQueryCacheKey(
tableName,
options.Filters,
options.Sort,
"", // No custom SQL WHERE in resolvespec
"", // No custom SQL OR in resolvespec
nil, // No expand options in resolvespec
false, // distinct not used here
"", // No custom SQL WHERE in resolvespec
"", // No custom SQL OR in resolvespec
options.CursorForward,
options.CursorBackward,
)
} else {
cacheKeyHash = cache.BuildQueryCacheKey(
cacheKeyHash = buildQueryCacheKey(
tableName,
options.Filters,
options.Sort,
@@ -351,10 +349,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
"", // No custom SQL OR in resolvespec
)
}
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
cacheKey := getQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
var cachedTotal cache.CachedTotal
var cachedTotal cachedTotal
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err == nil {
total = cachedTotal.Total
@@ -371,10 +369,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
total = count
logger.Debug("Total records (from query): %d", total)
// Store in cache
// Store in cache with schema and table tags
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
@@ -464,6 +461,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, result.Data, nil)
return
}
@@ -480,6 +482,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, v, nil)
case []map[string]interface{}:
@@ -518,6 +525,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
logger.Info("Successfully created %d records with nested data", len(results))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, results, nil)
return
}
@@ -541,6 +553,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
logger.Info("Successfully created %d records", len(v))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, v, nil)
case []interface{}:
@@ -584,6 +601,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
logger.Info("Successfully created %d records with nested data", len(results))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, results, nil)
return
}
@@ -611,6 +633,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
logger.Info("Successfully created %d records", len(v))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, list, nil)
default:
@@ -661,6 +688,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return
}
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, result.Data, nil)
return
}
@@ -697,6 +729,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
}
logger.Info("Successfully updated %d records", result.RowsAffected())
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, data, nil)
case []map[string]interface{}:
@@ -735,6 +772,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return
}
logger.Info("Successfully updated %d records with nested data", len(results))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, results, nil)
return
}
@@ -758,6 +800,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return
}
logger.Info("Successfully updated %d records", len(updates))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, updates, nil)
case []interface{}:
@@ -800,6 +847,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return
}
logger.Info("Successfully updated %d records with nested data", len(results))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, results, nil)
return
}
@@ -827,6 +879,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return
}
logger.Info("Successfully updated %d records", len(list))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, list, nil)
default:
@@ -873,6 +930,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
logger.Info("Successfully deleted %d records", len(v))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, map[string]interface{}{"deleted": len(v)}, nil)
return
@@ -914,6 +976,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
logger.Info("Successfully deleted %d records", deletedCount)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
return
@@ -940,6 +1007,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
logger.Info("Successfully deleted %d records", deletedCount)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
return
@@ -998,6 +1070,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
logger.Info("Successfully deleted record with ID: %s", id)
// Return the deleted record data
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, recordToDelete, nil)
}

445
pkg/restheadspec/README.md Normal file
View File

@@ -0,0 +1,445 @@
# RestHeadSpec - Header-Based REST API
RestHeadSpec provides a REST API where all query options are passed via HTTP headers instead of the request body. This provides cleaner separation between data and metadata, making it ideal for GET requests and RESTful architectures.
## Features
* **Header-Based Querying**: All query options via HTTP headers
* **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
* **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
* **Advanced Filtering**: Field filters, search operators, AND/OR logic
* **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
* **Single Record as Object**: Automatically return single-element arrays as objects (default)
* **Base64 Support**: Base64-encoded header values for complex queries
* **Type-Aware Filtering**: Automatic type detection and conversion
* **CORS Support**: Comprehensive CORS headers for cross-origin requests
* **OPTIONS Method**: Full OPTIONS support for CORS preflight
## Quick Start
### Setup with GORM
```go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/gorilla/mux"
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// IMPORTANT: Register models BEFORE setting up routes
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Setup routes
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler, nil)
// Start server
http.ListenAndServe(":8080", router)
```
### Setup with Bun ORM
```go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/uptrace/bun"
// Create handler with Bun
handler := restheadspec.NewHandlerWithBun(bunDB)
// Register models
handler.Registry.RegisterModel("public.users", &User{})
// Setup routes (same as GORM)
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler, nil)
```
## Basic Usage
### Simple GET Request
```http
GET /public/users HTTP/1.1
Host: api.example.com
X-Select-Fields: id,name,email
X-FieldFilter-Status: active
X-Sort: -created_at
X-Limit: 50
```
### With Preloading
```http
GET /public/users HTTP/1.1
X-Select-Fields: id,name,email,department_id
X-Preload: department:id,name
X-FieldFilter-Status: active
X-Limit: 50
```
## Common Headers
| Header | Description | Example |
|--------|-------------|---------|
| `X-Select-Fields` | Columns to include | `id,name,email` |
| `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` |
| `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` |
| `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` |
| `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` |
| `X-Preload` | Preload relations | `posts:id,title` |
| `X-Sort` | Sort columns | `-created_at,+name` |
| `X-Limit` | Limit results | `50` |
| `X-Offset` | Offset for pagination | `100` |
| `X-Clean-JSON` | Remove null/empty fields | `true` |
| `X-Single-Record-As-Object` | Return single records as objects | `false` |
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
For complete header documentation, see [HEADERS.md](HEADERS.md).
## Lifecycle Hooks
RestHeadSpec supports lifecycle hooks for all CRUD operations:
```go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// Register a before-read hook (e.g., for authorization)
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
// Check permissions
if !userHasPermission(ctx.Context, ctx.Entity) {
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
}
// Modify query options
ctx.Options.Limit = ptr(100) // Enforce max limit
return nil
})
// Register an after-read hook (e.g., for data transformation)
handler.Hooks.Register(restheadspec.AfterRead, func(ctx *restheadspec.HookContext) error {
// Transform or filter results
if users, ok := ctx.Result.([]User); ok {
for i := range users {
users[i].Email = maskEmail(users[i].Email)
}
}
return nil
})
// Register a before-create hook (e.g., for validation)
handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookContext) error {
// Validate data
if user, ok := ctx.Data.(*User); ok {
if user.Email == "" {
return fmt.Errorf("email is required")
}
// Add timestamps
user.CreatedAt = time.Now()
}
return nil
})
```
**Available Hook Types**:
* `BeforeRead`, `AfterRead`
* `BeforeCreate`, `AfterCreate`
* `BeforeUpdate`, `AfterUpdate`
* `BeforeDelete`, `AfterDelete`
**HookContext** provides:
* `Context`: Request context
* `Handler`: Access to handler, database, and registry
* `Schema`, `Entity`, `TableName`: Request info
* `Model`: The registered model type
* `Options`: Parsed request options (filters, sorting, etc.)
* `ID`: Record ID (for single-record operations)
* `Data`: Request data (for create/update)
* `Result`: Operation result (for after hooks)
* `Writer`: Response writer (allows hooks to modify response)
## Cursor Pagination
RestHeadSpec supports efficient cursor-based pagination for large datasets:
```http
GET /public/posts HTTP/1.1
X-Sort: -created_at,+id
X-Limit: 50
X-Cursor-Forward: <cursor_token>
```
**How it works**:
1. First request returns results + cursor token in response
2. Subsequent requests use `X-Cursor-Forward` or `X-Cursor-Backward`
3. Cursor maintains consistent ordering even with data changes
4. Supports complex multi-column sorting
**Benefits over offset pagination**:
* Consistent results when data changes
* Better performance for large offsets
* Prevents "skipped" or duplicate records
* Works with complex sort expressions
**Example with hooks**:
```go
// Enable cursor pagination in a hook
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
// For large tables, enforce cursor pagination
if ctx.Entity == "posts" && ctx.Options.Offset != nil && *ctx.Options.Offset > 1000 {
return fmt.Errorf("use cursor pagination for large offsets")
}
return nil
})
```
## Response Formats
RestHeadSpec supports multiple response formats:
**1. Simple Format** (`X-SimpleApi: true`):
```json
[
{ "id": 1, "name": "John" },
{ "id": 2, "name": "Jane" }
]
```
**2. Detail Format** (`X-DetailApi: true`, default):
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 100,
"limit": 50,
"offset": 0
}
}
```
**3. Syncfusion Format** (`X-Syncfusion: true`):
```json
{
"result": [...],
"count": 100
}
```
## Single Record as Object (Default Behavior)
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses.
**Default behavior (enabled)**:
```http
GET /public/users/123
```
```json
{
"success": true,
"data": { "id": 123, "name": "John", "email": "john@example.com" }
}
```
**To disable** (force arrays):
```http
GET /public/users/123
X-Single-Record-As-Object: false
```
```json
{
"success": true,
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
}
```
**How it works**:
* When a query returns exactly **one record**, it's returned as an object
* When a query returns **multiple records**, they're returned as an array
* Set `X-Single-Record-As-Object: false` to always receive arrays
* Works with all response formats (simple, detail, syncfusion)
* Applies to both read operations and create/update returning clauses
## CORS & OPTIONS Support
RestHeadSpec includes comprehensive CORS support for cross-origin requests:
**OPTIONS Method**:
```http
OPTIONS /public/users HTTP/1.1
```
Returns metadata with appropriate CORS headers:
```http
Access-Control-Allow-Origin: *
Access-Control-Allow-Methods: GET, POST, OPTIONS
Access-Control-Allow-Headers: Content-Type, Authorization, X-Select-Fields, X-FieldFilter-*, ...
Access-Control-Max-Age: 86400
Access-Control-Allow-Credentials: true
```
**Key Features**:
* OPTIONS returns model metadata (same as GET metadata endpoint)
* All HTTP methods include CORS headers automatically
* OPTIONS requests don't require authentication (CORS preflight)
* Supports all HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.)
* 24-hour max age to reduce preflight requests
**Configuration**:
```go
import "github.com/bitechdev/ResolveSpec/pkg/common"
// Get default CORS config
corsConfig := common.DefaultCORSConfig()
// Customize if needed
corsConfig.AllowedOrigins = []string{"https://example.com"}
corsConfig.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
```
## Advanced Features
### Base64 Encoding
For complex header values, use base64 encoding:
```http
GET /public/users HTTP/1.1
X-Select-Fields-Base64: aWQsbmFtZSxlbWFpbA==
```
### AND/OR Logic
Combine multiple filters with AND/OR logic:
```http
GET /public/users HTTP/1.1
X-FieldFilter-Status: active
X-SearchOp-Gte-Age: 18
X-Filter-Logic: AND
```
### Complex Preloading
Load nested relationships:
```http
GET /public/users HTTP/1.1
X-Preload: posts:id,title,comments:id,text,author:name
```
## Model Registration
```go
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
}
// Schema.Table format
handler.Registry.RegisterModel("public.users", &User{})
```
## Complete Example
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/gorilla/mux"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
}
func main() {
// Connect to database
db, err := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
if err != nil {
log.Fatal(err)
}
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// Register models
handler.Registry.RegisterModel("public.users", &User{})
// Add hooks
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
log.Printf("Reading %s", ctx.Entity)
return nil
})
// Setup routes
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler, nil)
// Start server
log.Println("Server starting on :8080")
log.Fatal(http.ListenAndServe(":8080", router))
}
```
## Testing
RestHeadSpec is designed for testability:
```go
import (
"net/http/httptest"
"testing"
)
func TestUserRead(t *testing.T) {
handler := restheadspec.NewHandlerWithGORM(testDB)
handler.Registry.RegisterModel("public.users", &User{})
req := httptest.NewRequest("GET", "/public/users", nil)
req.Header.Set("X-Select-Fields", "id,name")
req.Header.Set("X-Limit", "10")
rec := httptest.NewRecorder()
// Test your handler...
}
```
## See Also
* [HEADERS.md](HEADERS.md) - Complete header reference
* [Main README](../../README.md) - ResolveSpec overview
* [ResolveSpec Package](../resolvespec/README.md) - Body-based API
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
## License
This package is part of ResolveSpec and is licensed under the MIT License.

View File

@@ -1,4 +1,4 @@
package cache
package restheadspec
import (
"context"
@@ -7,56 +7,42 @@ import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// QueryCacheKey represents the components used to build a cache key for query total count
type QueryCacheKey struct {
// expandOptionKey represents expand options for cache key
type expandOptionKey struct {
Relation string `json:"relation"`
Where string `json:"where,omitempty"`
}
// queryCacheKey represents the components used to build a cache key for query total count
type queryCacheKey struct {
TableName string `json:"table_name"`
Filters []common.FilterOption `json:"filters"`
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
Expand []ExpandOptionKey `json:"expand,omitempty"`
Expand []expandOptionKey `json:"expand,omitempty"`
Distinct bool `json:"distinct,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
CursorBackward string `json:"cursor_backward,omitempty"`
}
// ExpandOptionKey represents expand options for cache key
type ExpandOptionKey struct {
Relation string `json:"relation"`
Where string `json:"where,omitempty"`
// cachedTotal represents a cached total count
type cachedTotal struct {
Total int `json:"total"`
}
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
// This is used to cache the total count of records matching a query
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
}
return hashString(string(jsonData))
}
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// Includes expand, distinct, and cursor pagination options
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
key := QueryCacheKey{
key := queryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
@@ -69,11 +55,11 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
// Convert expand options to cache key format
if len(expandOpts) > 0 {
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
key.Expand = make([]expandOptionKey, 0, len(expandOpts))
for _, exp := range expandOpts {
// Type assert to get the expand option fields we care about for caching
if expMap, ok := exp.(map[string]interface{}); ok {
expKey := ExpandOptionKey{}
expKey := expandOptionKey{}
if rel, ok := expMap["relation"].(string); ok {
expKey.Relation = rel
}
@@ -83,7 +69,6 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
key.Expand = append(key.Expand, expKey)
}
}
// Sort expand options for consistent hashing (already sorted by relation name above)
}
// Serialize to JSON for consistent hashing
@@ -104,24 +89,38 @@ func hashString(s string) string {
return hex.EncodeToString(h.Sum(nil))
}
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func GetQueryTotalCacheKey(hash string) string {
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func getQueryTotalCacheKey(hash string) string {
return fmt.Sprintf("query_total:%s", hash)
}
// CachedTotal represents a cached total count
type CachedTotal struct {
Total int `json:"total"`
// buildCacheTags creates cache tags from schema and table name
func buildCacheTags(schema, tableName string) []string {
return []string{
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
}
}
// InvalidateCacheForTable removes all cached totals for a specific table
// This should be called when data in the table changes (insert/update/delete)
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
cache := GetDefaultCache()
// setQueryTotalCache stores a query total in the cache with schema and table tags
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
c := cache.GetDefaultCache()
cacheData := cachedTotal{Total: total}
tags := buildCacheTags(schema, tableName)
// Build a pattern to match all query totals for this table
// Note: This requires pattern matching support in the provider
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
return cache.DeleteByPattern(ctx, pattern)
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
}
// invalidateCacheForTags removes all cached items matching the specified tags
func invalidateCacheForTags(ctx context.Context, tags []string) error {
c := cache.GetDefaultCache()
// Invalidate for each tag
for _, tag := range tags {
if err := c.DeleteByTag(ctx, tag); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,193 @@
package restheadspec
import (
"encoding/json"
"net/http"
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test that normalizeResultArray returns empty array when no records found without ID
func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
handler := &Handler{}
tests := []struct {
name string
input interface{}
shouldBeEmptyArr bool
}{
{
name: "nil should return empty array",
input: nil,
shouldBeEmptyArr: true,
},
{
name: "empty slice should return empty array",
input: []*EmptyTestModel{},
shouldBeEmptyArr: true,
},
{
name: "single element should return the element",
input: []*EmptyTestModel{{ID: 1, Name: "test"}},
shouldBeEmptyArr: false,
},
{
name: "multiple elements should return the slice",
input: []*EmptyTestModel{
{ID: 1, Name: "test1"},
{ID: 2, Name: "test2"},
},
shouldBeEmptyArr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.normalizeResultArray(tt.input)
// For cases that should return empty array
if tt.shouldBeEmptyArr {
emptyArr, ok := result.([]interface{})
if !ok {
t.Errorf("Expected empty array []interface{}{}, got %T: %v", result, result)
return
}
if len(emptyArr) != 0 {
t.Errorf("Expected empty array with length 0, got length %d", len(emptyArr))
}
// Verify it serializes to [] and not null
jsonBytes, err := json.Marshal(result)
if err != nil {
t.Errorf("Failed to marshal result: %v", err)
return
}
if string(jsonBytes) != "[]" {
t.Errorf("Expected JSON '[]', got '%s'", string(jsonBytes))
}
}
})
}
}
// Test that sendFormattedResponse adds X-No-Data-Found header
func TestSendFormattedResponse_NoDataFoundHeader(t *testing.T) {
handler := &Handler{}
// Mock ResponseWriter
mockWriter := &MockTestResponseWriter{
headers: make(map[string]string),
}
metadata := &common.Metadata{
Total: 0,
Count: 0,
Filtered: 0,
Limit: 10,
Offset: 0,
}
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{},
}
// Test with empty data
emptyData := []interface{}{}
handler.sendFormattedResponse(mockWriter, emptyData, metadata, options)
// Check if X-No-Data-Found header was set
if mockWriter.headers["X-No-Data-Found"] != "true" {
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
}
// Verify the body is an empty array
if mockWriter.body == nil {
t.Error("Expected body to be set, got nil")
} else {
bodyBytes, err := json.Marshal(mockWriter.body)
if err != nil {
t.Errorf("Failed to marshal body: %v", err)
}
// The body should be wrapped in a Response object with "data" field
bodyStr := string(bodyBytes)
if !strings.Contains(bodyStr, `"data":[]`) && !strings.Contains(bodyStr, `"result":[]`) {
t.Errorf("Expected body to contain empty array, got: %s", bodyStr)
}
}
}
// Test that sendResponseWithOptions adds X-No-Data-Found header
func TestSendResponseWithOptions_NoDataFoundHeader(t *testing.T) {
handler := &Handler{}
// Mock ResponseWriter
mockWriter := &MockTestResponseWriter{
headers: make(map[string]string),
}
metadata := &common.Metadata{}
options := &ExtendedRequestOptions{}
// Test with nil data
handler.sendResponseWithOptions(mockWriter, nil, metadata, options)
// Check if X-No-Data-Found header was set
if mockWriter.headers["X-No-Data-Found"] != "true" {
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
}
// Check status code is 200
if mockWriter.statusCode != 200 {
t.Errorf("Expected status code 200, got %d", mockWriter.statusCode)
}
// Verify the body is an empty array
if mockWriter.body == nil {
t.Error("Expected body to be set, got nil")
} else {
bodyBytes, err := json.Marshal(mockWriter.body)
if err != nil {
t.Errorf("Failed to marshal body: %v", err)
}
bodyStr := string(bodyBytes)
if bodyStr != "[]" {
t.Errorf("Expected body to be '[]', got: %s", bodyStr)
}
}
}
// MockTestResponseWriter for testing
type MockTestResponseWriter struct {
headers map[string]string
statusCode int
body interface{}
}
func (m *MockTestResponseWriter) SetHeader(key, value string) {
m.headers[key] = value
}
func (m *MockTestResponseWriter) WriteHeader(statusCode int) {
m.statusCode = statusCode
}
func (m *MockTestResponseWriter) Write(data []byte) (int, error) {
return len(data), nil
}
func (m *MockTestResponseWriter) WriteJSON(data interface{}) error {
m.body = data
return nil
}
func (m *MockTestResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return nil
}
// EmptyTestModel for testing
type EmptyTestModel struct {
ID int64 `json:"id"`
Name string `json:"name"`
}

View File

@@ -482,8 +482,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (AND condition)
if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
// First add table prefixes to unqualified columns (but skip columns inside function calls)
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedWhere != "" {
query = query.Where(sanitizedWhere)
}
@@ -492,8 +494,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (OR condition)
if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr)
}
@@ -529,7 +532,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
var total int
if !options.SkipCount {
// Try to get from cache first (unless SkipCache is true)
var cachedTotal *cache.CachedTotal
var cachedTotalData *cachedTotal
var cacheKey string
if !options.SkipCache {
@@ -543,7 +546,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
}
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
cacheKeyHash := buildExtendedQueryCacheKey(
tableName,
options.Filters,
options.Sort,
@@ -554,22 +557,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
options.CursorForward,
options.CursorBackward,
)
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
cacheKey = getQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
cachedTotal = &cache.CachedTotal{}
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
cachedTotalData = &cachedTotal{}
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotalData)
if err == nil {
total = cachedTotal.Total
total = cachedTotalData.Total
logger.Debug("Total records (from cache): %d", total)
} else {
logger.Debug("Cache miss for query total")
cachedTotal = nil
cachedTotalData = nil
}
}
// If not in cache or cache skip, execute count query
if cachedTotal == nil {
if cachedTotalData == nil {
count, err := query.Count(ctx)
if err != nil {
logger.Error("Error counting records: %v", err)
@@ -579,11 +582,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
total = count
logger.Debug("Total records (from query): %d", total)
// Store in cache (if caching is enabled)
// Store in cache with schema and table tags (if caching is enabled)
if !options.SkipCache && cacheKey != "" {
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := &cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
@@ -662,6 +664,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
return
}
// Check if a specific ID was requested but no record was found
resultCount := reflection.Len(modelPtr)
if id != "" && resultCount == 0 {
logger.Warn("Record not found for ID: %s", id)
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil)
return
}
limit := 0
if options.Limit != nil {
limit = *options.Limit
@@ -676,7 +686,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
metadata := &common.Metadata{
Total: int64(total),
Count: int64(reflection.Len(modelPtr)),
Count: int64(resultCount),
Filtered: int64(total),
Limit: limit,
Offset: offset,
@@ -851,7 +861,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
if len(preload.Where) > 0 {
// Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: allPreloads}
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
// First add table prefixes to unqualified columns
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
// Then sanitize and allow preload table prefixes
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
@@ -1141,6 +1154,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
}
logger.Info("Successfully created %d record(s)", len(mergedResults))
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponseWithOptions(w, responseData, nil, &options)
}
@@ -1312,6 +1330,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
}
logger.Info("Successfully updated record with ID: %v", targetID)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponseWithOptions(w, mergedData, nil, &options)
}
@@ -1380,6 +1403,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
logger.Info("Successfully deleted %d records", deletedCount)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
return
@@ -1448,6 +1476,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
logger.Info("Successfully deleted %d records", deletedCount)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
return
@@ -1502,6 +1535,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
logger.Info("Successfully deleted %d records", deletedCount)
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
return
@@ -1603,6 +1641,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
}
// Return the deleted record data
// Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
}
h.sendResponse(w, recordToDelete, nil)
}
@@ -2099,14 +2142,30 @@ func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metada
// sendResponseWithOptions sends a response with optional formatting
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
w.SetHeader("Content-Type", "application/json")
// Handle nil data - convert to empty array
if data == nil {
data = []interface{}{}
}
// Calculate data length after nil conversion
dataLen := reflection.Len(data)
// Add X-No-Data-Found header when no records were found
if dataLen == 0 {
w.SetHeader("X-No-Data-Found", "true")
}
w.WriteHeader(http.StatusOK)
// Normalize single-record arrays to objects if requested
if options != nil && options.SingleRecordAsObject {
data = h.normalizeResultArray(data)
}
// Return data as-is without wrapping in common.Response
w.SetHeader("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := w.WriteJSON(data); err != nil {
logger.Error("Failed to write JSON response: %v", err)
}
@@ -2116,7 +2175,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
if data == nil {
return nil
return []interface{}{}
}
// Use reflection to check if data is a slice or array
@@ -2125,18 +2184,50 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
dataValue = dataValue.Elem()
}
// Check if it's a slice or array with exactly one element
if (dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array) && dataValue.Len() == 1 {
// Return the single element
return dataValue.Index(0).Interface()
// Check if it's a slice or array
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
if dataValue.Len() == 1 {
// Return the single element
return dataValue.Index(0).Interface()
} else if dataValue.Len() == 0 {
// Keep empty array as empty array, don't convert to empty object
return []interface{}{}
}
}
if dataValue.Kind() == reflect.String {
str := dataValue.String()
if str == "" || str == "null" {
return []interface{}{}
}
}
return data
}
// sendFormattedResponse sends response with formatting options
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
// Normalize single-record arrays to objects if requested
httpStatus := http.StatusOK
// Handle nil data - convert to empty array
if data == nil {
data = []interface{}{}
}
// Calculate data length after nil conversion
// Note: This is done BEFORE normalization because X-No-Data-Found indicates
// whether data was found in the database, not the final response format
dataLen := reflection.Len(data)
// Add X-No-Data-Found header when no records were found
if dataLen == 0 {
w.SetHeader("X-No-Data-Found", "true")
}
// Apply normalization after header is set
// normalizeResultArray may convert single-element arrays to objects,
// but the X-No-Data-Found header reflects the original query result
if options.SingleRecordAsObject {
data = h.normalizeResultArray(data)
}
@@ -2155,7 +2246,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
switch options.ResponseFormat {
case "simple":
// Simple format: just return the data array
w.WriteHeader(http.StatusOK)
w.WriteHeader(httpStatus)
if err := w.WriteJSON(data); err != nil {
logger.Error("Failed to write JSON response: %v", err)
}
@@ -2167,7 +2258,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
if metadata != nil {
response["count"] = metadata.Total
}
w.WriteHeader(http.StatusOK)
w.WriteHeader(httpStatus)
if err := w.WriteJSON(response); err != nil {
logger.Error("Failed to write JSON response: %v", err)
}
@@ -2178,7 +2269,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
Data: data,
Metadata: metadata,
}
w.WriteHeader(http.StatusOK)
w.WriteHeader(httpStatus)
if err := w.WriteJSON(response); err != nil {
logger.Error("Failed to write JSON response: %v", err)
}

View File

@@ -935,7 +935,16 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
// Add WHERE clause if SQL conditions specified
whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 {
whereConditions = append(whereConditions, xfile.SqlAnd...)
// Process each SQL condition: add table prefixes and sanitize
for _, sqlCond := range xfile.SqlAnd {
// First add table prefixes to unqualified columns
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName)
// Then sanitize the condition
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
if sanitizedCond != "" {
whereConditions = append(whereConditions, sanitizedCond)
}
}
}
if len(whereConditions) > 0 {
preloadOpt.Where = strings.Join(whereConditions, " AND ")

View File

@@ -1,3 +1,4 @@
//go:build integration
// +build integration
package restheadspec
@@ -21,12 +22,12 @@ import (
// Test models
type TestUser struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
Age int `json:"age"`
Active bool `gorm:"default:true" json:"active"`
CreatedAt time.Time `json:"created_at"`
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
Age int `json:"age"`
Active bool `gorm:"default:true" json:"active"`
CreatedAt time.Time `json:"created_at"`
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
}
@@ -35,13 +36,13 @@ func (TestUser) TableName() string {
}
type TestPost struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"not null" json:"user_id"`
Title string `gorm:"not null" json:"title"`
Content string `json:"content"`
Published bool `gorm:"default:false" json:"published"`
CreatedAt time.Time `json:"created_at"`
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"not null" json:"user_id"`
Title string `gorm:"not null" json:"title"`
Content string `json:"content"`
Published bool `gorm:"default:false" json:"published"`
CreatedAt time.Time `json:"created_at"`
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
}
@@ -54,7 +55,7 @@ type TestComment struct {
PostID uint `gorm:"not null" json:"post_id"`
Content string `gorm:"not null" json:"content"`
CreatedAt time.Time `json:"created_at"`
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
}
func (TestComment) TableName() string {
@@ -401,7 +402,7 @@ func TestIntegration_GetMetadata(t *testing.T) {
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
@@ -492,7 +493,7 @@ func TestIntegration_QueryParamsOverHeaders(t *testing.T) {
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
t.Errorf("Expected status 200, got %d", w.Code)
}

View File

@@ -296,7 +296,7 @@ func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName
}
func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (reflect.Value, error) {
defer logger.CatchPanic("ApplyColumnSecurity")
defer logger.CatchPanic("ApplyColumnSecurity")()
if m.ColumnSecurity == nil {
return records, fmt.Errorf("security not initialized")
@@ -437,7 +437,7 @@ func (m *SecurityList) LoadRowSecurity(ctx context.Context, pUserID int, pSchema
}
func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
defer logger.CatchPanic("GetRowSecurityTemplate")
defer logger.CatchPanic("GetRowSecurityTemplate")()
if m.RowSecurity == nil {
return RowSecurity{}, fmt.Errorf("security not initialized")

View File

@@ -1,233 +1,314 @@
# Server Package
Graceful HTTP server with request draining and shutdown coordination.
Production-ready HTTP server manager with graceful shutdown, request draining, and comprehensive TLS/HTTPS support.
## Features
**Multiple Server Management** - Run multiple HTTP/HTTPS servers concurrently
**Graceful Shutdown** - Handles SIGINT/SIGTERM with request draining
**Automatic Request Rejection** - New requests get 503 during shutdown
**Health & Readiness Endpoints** - Kubernetes-ready health checks
**Shutdown Callbacks** - Register cleanup functions (DB, cache, metrics)
**Comprehensive TLS Support**:
- Certificate files (production)
- Self-signed certificates (development/testing)
- Let's Encrypt / AutoTLS (automatic certificate management)
**GZIP Compression** - Optional response compression
**Panic Recovery** - Automatic panic recovery middleware
**Configurable Timeouts** - Read, write, idle, drain, and shutdown timeouts
## Quick Start
### Single Server
```go
import "github.com/bitechdev/ResolveSpec/pkg/server"
// Create server
srv := server.NewGracefulServer(server.Config{
Addr: ":8080",
Handler: router,
// Create server manager
mgr := server.NewManager()
// Add server
_, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: myRouter,
GZIP: true,
})
// Start server (blocks until shutdown signal)
if err := srv.ListenAndServe(); err != nil {
// Start and wait for shutdown signal
if err := mgr.ServeWithGracefulShutdown(); err != nil {
log.Fatal(err)
}
```
## Features
### Multiple Servers
✅ Graceful shutdown on SIGINT/SIGTERM
✅ Request draining (waits for in-flight requests)
✅ Automatic request rejection during shutdown
✅ Health and readiness endpoints
✅ Shutdown callbacks for cleanup
✅ Configurable timeouts
```go
mgr := server.NewManager()
// Public API
mgr.Add(server.Config{
Name: "public-api",
Port: 8080,
Handler: publicRouter,
})
// Admin API
mgr.Add(server.Config{
Name: "admin-api",
Port: 8081,
Handler: adminRouter,
})
// Start all and wait
mgr.ServeWithGracefulShutdown()
```
## HTTPS/TLS Configuration
### Option 1: Certificate Files (Production)
```go
mgr.Add(server.Config{
Name: "https-server",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
SSLCert: "/etc/ssl/certs/server.crt",
SSLKey: "/etc/ssl/private/server.key",
})
```
### Option 2: Self-Signed Certificate (Development)
```go
mgr.Add(server.Config{
Name: "dev-server",
Host: "localhost",
Port: 8443,
Handler: handler,
SelfSignedSSL: true, // Auto-generates certificate
})
```
### Option 3: Let's Encrypt / AutoTLS (Production)
```go
mgr.Add(server.Config{
Name: "prod-server",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com",
AutoTLSCacheDir: "./certs-cache", // Certificate cache directory
})
```
## Configuration
```go
config := server.Config{
// Server address
Addr: ":8080",
server.Config{
// Basic configuration
Name: "my-server", // Server name (required)
Host: "0.0.0.0", // Bind address
Port: 8080, // Port (required)
Handler: myRouter, // HTTP handler (required)
Description: "My API server", // Optional description
// HTTP handler
Handler: myRouter,
// Features
GZIP: true, // Enable GZIP compression
// Maximum time for graceful shutdown (default: 30s)
ShutdownTimeout: 30 * time.Second,
// TLS/HTTPS (choose one option)
SSLCert: "/path/to/cert.pem", // Certificate file
SSLKey: "/path/to/key.pem", // Key file
SelfSignedSSL: false, // Auto-generate self-signed cert
AutoTLS: false, // Let's Encrypt
AutoTLSDomains: []string{}, // Domains for AutoTLS
AutoTLSEmail: "", // Email for Let's Encrypt
AutoTLSCacheDir: "./certs-cache", // Cert cache directory
// Time to wait for in-flight requests (default: 25s)
DrainTimeout: 25 * time.Second,
// Request read timeout (default: 10s)
ReadTimeout: 10 * time.Second,
// Response write timeout (default: 10s)
WriteTimeout: 10 * time.Second,
// Idle connection timeout (default: 120s)
IdleTimeout: 120 * time.Second,
// Timeouts
ShutdownTimeout: 30 * time.Second, // Max shutdown time
DrainTimeout: 25 * time.Second, // Request drain timeout
ReadTimeout: 15 * time.Second, // Request read timeout
WriteTimeout: 15 * time.Second, // Response write timeout
IdleTimeout: 60 * time.Second, // Idle connection timeout
}
srv := server.NewGracefulServer(config)
```
## Shutdown Behavior
## Graceful Shutdown
**Signal received (SIGINT/SIGTERM):**
### Automatic (Recommended)
1. **Mark as shutting down** - New requests get 503
2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests
3. **Shutdown server** - Close listeners and connections
4. **Execute callbacks** - Run registered cleanup functions
```go
mgr := server.NewManager()
// Add servers...
// This blocks until SIGINT/SIGTERM
mgr.ServeWithGracefulShutdown()
```
Time Event
─────────────────────────────────────────
0s Signal received: SIGTERM
├─ Mark as shutting down
├─ Reject new requests (503)
└─ Start draining...
1s In-flight: 50 requests
2s In-flight: 32 requests
3s In-flight: 12 requests
4s In-flight: 3 requests
5s In-flight: 0 requests ✓
└─ All requests drained
### Manual Control
5s Execute shutdown callbacks
6s Shutdown complete
```go
mgr := server.NewManager()
// Add and start servers
mgr.StartAll()
// Later... stop gracefully
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := mgr.StopAllWithContext(ctx); err != nil {
log.Printf("Shutdown error: %v", err)
}
```
### Shutdown Callbacks
Register cleanup functions to run during shutdown:
```go
// Close database
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Closing database...")
return db.Close()
})
// Flush metrics
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Flushing metrics...")
return metrics.Flush(ctx)
})
// Close cache
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Closing cache...")
return cache.Close()
})
```
## Health Checks
### Health Endpoint
Returns 200 when healthy, 503 when shutting down:
### Adding Health Endpoints
```go
router.HandleFunc("/health", srv.HealthCheckHandler())
instance, _ := mgr.Add(server.Config{
Name: "api-server",
Port: 8080,
Handler: router,
})
// Add health endpoints to your router
router.HandleFunc("/health", instance.HealthCheckHandler())
router.HandleFunc("/ready", instance.ReadinessHandler())
```
**Response (healthy):**
### Health Endpoint
Returns server health status:
**Healthy (200 OK):**
```json
{"status":"healthy"}
```
**Response (shutting down):**
**Shutting Down (503 Service Unavailable):**
```json
{"status":"shutting_down"}
```
### Readiness Endpoint
Includes in-flight request count:
Returns readiness with in-flight request count:
```go
router.HandleFunc("/ready", srv.ReadinessHandler())
```
**Response:**
**Ready (200 OK):**
```json
{"ready":true,"in_flight_requests":12}
```
**During shutdown:**
**Not Ready (503 Service Unavailable):**
```json
{"ready":false,"reason":"shutting_down"}
```
## Shutdown Callbacks
## Shutdown Behavior
Register cleanup functions to run during shutdown:
When a shutdown signal (SIGINT/SIGTERM) is received:
```go
// Close database
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Closing database connection...")
return db.Close()
})
1. **Mark as shutting down** → New requests get 503
2. **Execute callbacks** → Run cleanup functions
3. **Drain requests** → Wait up to `DrainTimeout` for in-flight requests
4. **Shutdown servers** → Close listeners and connections
// Flush metrics
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Flushing metrics...")
return metricsProvider.Flush(ctx)
})
```
Time Event
─────────────────────────────────────────
0s Signal received: SIGTERM
├─ Mark servers as shutting down
├─ Reject new requests (503)
└─ Execute shutdown callbacks
// Close cache
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Closing cache...")
return cache.Close()
})
1s Callbacks complete
└─ Start draining requests...
2s In-flight: 50 requests
3s In-flight: 32 requests
4s In-flight: 12 requests
5s In-flight: 3 requests
6s In-flight: 0 requests ✓
└─ All requests drained
6s Shutdown servers
7s All servers stopped ✓
```
## Complete Example
## Server Management
### Get Server Instance
```go
package main
import (
"context"
"log"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/bitechdev/ResolveSpec/pkg/server"
"github.com/gorilla/mux"
)
func main() {
// Initialize metrics
metricsProvider := metrics.NewPrometheusProvider()
metrics.SetProvider(metricsProvider)
// Create router
router := mux.NewRouter()
// Apply middleware
rateLimiter := middleware.NewRateLimiter(100, 20)
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB)
sanitizer := middleware.DefaultSanitizer()
router.Use(rateLimiter.Middleware)
router.Use(sizeLimiter.Middleware)
router.Use(sanitizer.Middleware)
router.Use(metricsProvider.Middleware)
// API routes
router.HandleFunc("/api/data", dataHandler)
// Create graceful server
srv := server.NewGracefulServer(server.Config{
Addr: ":8080",
Handler: router,
ShutdownTimeout: 30 * time.Second,
DrainTimeout: 25 * time.Second,
})
// Health checks
router.HandleFunc("/health", srv.HealthCheckHandler())
router.HandleFunc("/ready", srv.ReadinessHandler())
// Metrics endpoint
router.Handle("/metrics", metricsProvider.Handler())
// Register shutdown callbacks
server.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Cleanup: Flushing metrics...")
return nil
})
server.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Cleanup: Closing database...")
// return db.Close()
return nil
})
// Start server (blocks until shutdown)
log.Printf("Starting server on :8080")
if err := srv.ListenAndServe(); err != nil {
log.Fatal(err)
}
// Wait for shutdown to complete
srv.Wait()
log.Println("Server stopped")
instance, err := mgr.Get("api-server")
if err != nil {
log.Fatal(err)
}
func dataHandler(w http.ResponseWriter, r *http.Request) {
// Your handler logic
time.Sleep(100 * time.Millisecond) // Simulate work
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"success"}`))
// Check status
fmt.Printf("Address: %s\n", instance.Addr())
fmt.Printf("Name: %s\n", instance.Name())
fmt.Printf("In-flight: %d\n", instance.InFlightRequests())
fmt.Printf("Shutting down: %v\n", instance.IsShuttingDown())
```
### List All Servers
```go
instances := mgr.List()
for _, instance := range instances {
fmt.Printf("Server: %s at %s\n", instance.Name(), instance.Addr())
}
```
### Remove Server
```go
// Stop and remove a server
if err := mgr.Remove("api-server"); err != nil {
log.Printf("Error removing server: %v", err)
}
```
### Restart All Servers
```go
// Gracefully restart all servers
if err := mgr.RestartAll(); err != nil {
log.Printf("Error restarting: %v", err)
}
```
@@ -250,23 +331,21 @@ spec:
ports:
- containerPort: 8080
# Liveness probe - is app running?
# Liveness probe
livenessProbe:
httpGet:
path: /health
port: 8080
initialDelaySeconds: 10
periodSeconds: 10
timeoutSeconds: 5
# Readiness probe - can app handle traffic?
# Readiness probe
readinessProbe:
httpGet:
path: /ready
port: 8080
initialDelaySeconds: 5
periodSeconds: 5
timeoutSeconds: 3
# Graceful shutdown
lifecycle:
@@ -274,26 +353,12 @@ spec:
exec:
command: ["/bin/sh", "-c", "sleep 5"]
# Environment
env:
- name: SHUTDOWN_TIMEOUT
value: "30"
```
### Service
```yaml
apiVersion: v1
kind: Service
metadata:
name: myapp
spec:
selector:
app: myapp
ports:
- port: 80
targetPort: 8080
type: LoadBalancer
# Allow time for graceful shutdown
terminationGracePeriodSeconds: 35
```
## Docker Compose
@@ -312,8 +377,70 @@ services:
interval: 10s
timeout: 5s
retries: 3
start_period: 10s
stop_grace_period: 35s # Slightly longer than shutdown timeout
stop_grace_period: 35s
```
## Complete Example
```go
package main
import (
"context"
"log"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/server"
)
func main() {
// Create server manager
mgr := server.NewManager()
// Register shutdown callbacks
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
log.Println("Cleanup: Closing database...")
// return db.Close()
return nil
})
// Create router
router := http.NewServeMux()
router.HandleFunc("/api/data", dataHandler)
// Add server
instance, err := mgr.Add(server.Config{
Name: "api-server",
Host: "0.0.0.0",
Port: 8080,
Handler: router,
GZIP: true,
ShutdownTimeout: 30 * time.Second,
DrainTimeout: 25 * time.Second,
})
if err != nil {
log.Fatal(err)
}
// Add health endpoints
router.HandleFunc("/health", instance.HealthCheckHandler())
router.HandleFunc("/ready", instance.ReadinessHandler())
// Start and wait for shutdown
log.Println("Starting server on :8080")
if err := mgr.ServeWithGracefulShutdown(); err != nil {
log.Printf("Server stopped: %v", err)
}
log.Println("Server shutdown complete")
}
func dataHandler(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond) // Simulate work
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"success"}`))
}
```
## Testing Graceful Shutdown
@@ -330,7 +457,7 @@ SERVER_PID=$!
# Wait for server to start
sleep 2
# Send some requests
# Send requests
for i in {1..10}; do
curl http://localhost:8080/api/data &
done
@@ -341,7 +468,7 @@ sleep 1
# Send shutdown signal
kill -TERM $SERVER_PID
# Try to send more requests (should get 503)
# Try more requests (should get 503)
curl -v http://localhost:8080/api/data
# Wait for server to stop
@@ -349,101 +476,13 @@ wait $SERVER_PID
echo "Server stopped gracefully"
```
### Expected Output
```
Starting server on :8080
Received signal: terminated, initiating graceful shutdown
Starting graceful shutdown...
Waiting for 8 in-flight requests to complete...
Waiting for 4 in-flight requests to complete...
Waiting for 1 in-flight requests to complete...
All requests drained in 2.3s
Cleanup: Flushing metrics...
Cleanup: Closing database...
Shutting down HTTP server...
Graceful shutdown complete
Server stopped
```
## Monitoring In-Flight Requests
```go
// Get current in-flight count
count := srv.InFlightRequests()
fmt.Printf("In-flight requests: %d\n", count)
// Check if shutting down
if srv.IsShuttingDown() {
fmt.Println("Server is shutting down")
}
```
## Advanced Usage
### Custom Shutdown Logic
```go
// Implement custom shutdown
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
log.Println("Shutdown signal received")
// Custom pre-shutdown logic
log.Println("Running custom cleanup...")
// Shutdown with callbacks
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.ShutdownWithCallbacks(ctx); err != nil {
log.Printf("Shutdown error: %v", err)
}
}()
// Start server
srv.server.ListenAndServe()
```
### Multiple Servers
```go
// HTTP server
httpSrv := server.NewGracefulServer(server.Config{
Addr: ":8080",
Handler: httpRouter,
})
// HTTPS server
httpsSrv := server.NewGracefulServer(server.Config{
Addr: ":8443",
Handler: httpsRouter,
})
// Start both
go httpSrv.ListenAndServe()
go httpsSrv.ListenAndServe()
// Shutdown both on signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
<-sigChan
ctx := context.Background()
httpSrv.Shutdown(ctx)
httpsSrv.Shutdown(ctx)
```
## Best Practices
1. **Set appropriate timeouts**
- `DrainTimeout` < `ShutdownTimeout`
- `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds`
2. **Register cleanup callbacks** for:
2. **Use shutdown callbacks** for:
- Database connections
- Message queues
- Metrics flushing
@@ -458,7 +497,12 @@ httpsSrv.Shutdown(ctx)
- Set `preStop` hook in Kubernetes (5-10s delay)
- Allows load balancer to deregister before shutdown
5. **Monitoring**
5. **HTTPS in production**
- Use AutoTLS for public-facing services
- Use certificate files for enterprise PKI
- Use self-signed only for development/testing
6. **Monitoring**
- Track in-flight requests in metrics
- Alert on slow drains
- Monitor shutdown duration
@@ -470,24 +514,63 @@ httpsSrv.Shutdown(ctx)
```go
// Increase drain timeout
config.DrainTimeout = 60 * time.Second
config.ShutdownTimeout = 65 * time.Second
```
### Requests Still Timing Out
### Requests Timing Out
```go
// Increase write timeout
config.WriteTimeout = 30 * time.Second
```
### Force Shutdown Not Working
The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed.
### Debugging Shutdown
### Certificate Issues
```go
// Verify certificate files exist and are readable
if _, err := os.Stat(config.SSLCert); err != nil {
log.Fatalf("Certificate not found: %v", err)
}
// For AutoTLS, ensure:
// - Port 443 is accessible
// - Domains resolve to server IP
// - Cache directory is writable
```
### Debug Logging
```go
// Enable debug logging
import "github.com/bitechdev/ResolveSpec/pkg/logger"
// Enable debug logging
logger.SetLevel("debug")
```
## API Reference
### Manager Methods
- `NewManager()` - Create new server manager
- `Add(cfg Config)` - Register server instance
- `Get(name string)` - Get server by name
- `Remove(name string)` - Stop and remove server
- `StartAll()` - Start all registered servers
- `StopAll()` - Stop all servers gracefully
- `StopAllWithContext(ctx)` - Stop with timeout
- `RestartAll()` - Restart all servers
- `List()` - Get all server instances
- `ServeWithGracefulShutdown()` - Start and block until shutdown
- `RegisterShutdownCallback(cb)` - Register cleanup function
### Instance Methods
- `Start()` - Start the server
- `Stop(ctx)` - Stop gracefully
- `Addr()` - Get server address
- `Name()` - Get server name
- `HealthCheckHandler()` - Get health handler
- `ReadinessHandler()` - Get readiness handler
- `InFlightRequests()` - Get in-flight count
- `IsShuttingDown()` - Check shutdown status
- `Wait()` - Block until shutdown complete

294
pkg/server/example_test.go Normal file
View File

@@ -0,0 +1,294 @@
package server_test
import (
"context"
"fmt"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/server"
)
// ExampleManager_basic demonstrates basic server manager usage
func ExampleManager_basic() {
// Create a server manager
mgr := server.NewManager()
// Define a simple handler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "Hello from server!")
})
// Add an HTTP server
_, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: handler,
GZIP: true, // Enable GZIP compression
})
if err != nil {
panic(err)
}
// Start all servers
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Server is now running...
// When done, stop gracefully
if err := mgr.StopAll(); err != nil {
panic(err)
}
}
// ExampleManager_https demonstrates HTTPS configurations
func ExampleManager_https() {
mgr := server.NewManager()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Secure connection!")
})
// Option 1: Use certificate files
_, err := mgr.Add(server.Config{
Name: "https-server-files",
Host: "localhost",
Port: 8443,
Handler: handler,
SSLCert: "/path/to/cert.pem",
SSLKey: "/path/to/key.pem",
})
if err != nil {
panic(err)
}
// Option 2: Self-signed certificate (for development)
_, err = mgr.Add(server.Config{
Name: "https-server-self-signed",
Host: "localhost",
Port: 8444,
Handler: handler,
SelfSignedSSL: true,
})
if err != nil {
panic(err)
}
// Option 3: Let's Encrypt / AutoTLS (for production)
_, err = mgr.Add(server.Config{
Name: "https-server-letsencrypt",
Host: "0.0.0.0",
Port: 443,
Handler: handler,
AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com",
AutoTLSCacheDir: "./certs-cache",
})
if err != nil {
panic(err)
}
// Start all servers
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Cleanup
mgr.StopAll()
}
// ExampleManager_gracefulShutdown demonstrates graceful shutdown with callbacks
func ExampleManager_gracefulShutdown() {
mgr := server.NewManager()
// Register shutdown callbacks for cleanup tasks
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
fmt.Println("Closing database connections...")
// Close your database here
return nil
})
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
fmt.Println("Flushing metrics...")
// Flush metrics here
return nil
})
// Add server with custom timeouts
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate some work
time.Sleep(100 * time.Millisecond)
fmt.Fprintln(w, "Done!")
})
_, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: handler,
ShutdownTimeout: 30 * time.Second, // Max time for shutdown
DrainTimeout: 25 * time.Second, // Time to wait for in-flight requests
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
})
if err != nil {
panic(err)
}
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
// This will automatically handle graceful shutdown with callbacks
if err := mgr.ServeWithGracefulShutdown(); err != nil {
fmt.Printf("Shutdown completed: %v\n", err)
}
}
// ExampleManager_healthChecks demonstrates health and readiness endpoints
func ExampleManager_healthChecks() {
mgr := server.NewManager()
// Create a router with health endpoints
mux := http.NewServeMux()
mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Data endpoint")
})
// Add server
instance, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: mux,
})
if err != nil {
panic(err)
}
// Add health and readiness endpoints
mux.HandleFunc("/health", instance.HealthCheckHandler())
mux.HandleFunc("/ready", instance.ReadinessHandler())
// Start the server
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Health check returns:
// - 200 OK with {"status":"healthy"} when healthy
// - 503 Service Unavailable with {"status":"shutting_down"} when shutting down
// Readiness check returns:
// - 200 OK with {"ready":true,"in_flight_requests":N} when ready
// - 503 Service Unavailable with {"ready":false,"reason":"shutting_down"} when shutting down
// Cleanup
mgr.StopAll()
}
// ExampleManager_multipleServers demonstrates running multiple servers
func ExampleManager_multipleServers() {
mgr := server.NewManager()
// Public API server
publicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Public API")
})
_, err := mgr.Add(server.Config{
Name: "public-api",
Host: "0.0.0.0",
Port: 8080,
Handler: publicHandler,
GZIP: true,
})
if err != nil {
panic(err)
}
// Admin API server (different port)
adminHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Admin API")
})
_, err = mgr.Add(server.Config{
Name: "admin-api",
Host: "localhost",
Port: 8081,
Handler: adminHandler,
})
if err != nil {
panic(err)
}
// Metrics server (internal only)
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Metrics data")
})
_, err = mgr.Add(server.Config{
Name: "metrics",
Host: "127.0.0.1",
Port: 9090,
Handler: metricsHandler,
})
if err != nil {
panic(err)
}
// Start all servers at once
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Get specific server instance
publicInstance, err := mgr.Get("public-api")
if err != nil {
panic(err)
}
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
// List all servers
instances := mgr.List()
fmt.Printf("Running %d servers\n", len(instances))
// Stop all servers gracefully (in parallel)
if err := mgr.StopAll(); err != nil {
panic(err)
}
}
// ExampleManager_monitoring demonstrates monitoring server state
func ExampleManager_monitoring() {
mgr := server.NewManager()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(50 * time.Millisecond) // Simulate work
fmt.Fprintln(w, "Done")
})
instance, err := mgr.Add(server.Config{
Name: "api-server",
Host: "localhost",
Port: 8080,
Handler: handler,
})
if err != nil {
panic(err)
}
if err := mgr.StartAll(); err != nil {
panic(err)
}
// Check server status
fmt.Printf("Server address: %s\n", instance.Addr())
fmt.Printf("Server name: %s\n", instance.Name())
fmt.Printf("Is shutting down: %v\n", instance.IsShuttingDown())
fmt.Printf("In-flight requests: %d\n", instance.InFlightRequests())
// Cleanup
mgr.StopAll()
// Wait for complete shutdown
instance.Wait()
}

137
pkg/server/interfaces.go Normal file
View File

@@ -0,0 +1,137 @@
package server
import (
"context"
"net/http"
"time"
)
// Config holds the configuration for a single web server instance.
type Config struct {
Name string
Host string
Port int
Description string
// Handler is the http.Handler (e.g., a router) to be served.
Handler http.Handler
// GZIP compression support
GZIP bool
// TLS/HTTPS configuration options (mutually exclusive)
// Option 1: Provide certificate and key files directly
SSLCert string
SSLKey string
// Option 2: Use self-signed certificate (for development/testing)
// Generates a self-signed certificate automatically if no SSLCert/SSLKey provided
SelfSignedSSL bool
// Option 3: Use Let's Encrypt / Certbot for automatic TLS
// AutoTLS enables automatic certificate management via Let's Encrypt
AutoTLS bool
// AutoTLSDomains specifies the domains for Let's Encrypt certificates
AutoTLSDomains []string
// AutoTLSCacheDir specifies where to cache certificates (default: "./certs-cache")
AutoTLSCacheDir string
// AutoTLSEmail is the email for Let's Encrypt registration (optional but recommended)
AutoTLSEmail string
// Graceful shutdown configuration
// ShutdownTimeout is the maximum time to wait for graceful shutdown
// Default: 30 seconds
ShutdownTimeout time.Duration
// DrainTimeout is the time to wait for in-flight requests to complete
// before forcing shutdown. Default: 25 seconds
DrainTimeout time.Duration
// ReadTimeout is the maximum duration for reading the entire request
// Default: 15 seconds
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out writes of the response
// Default: 15 seconds
WriteTimeout time.Duration
// IdleTimeout is the maximum amount of time to wait for the next request
// Default: 60 seconds
IdleTimeout time.Duration
}
// Instance defines the interface for a single server instance.
// It abstracts the underlying http.Server, allowing for easier management and testing.
type Instance interface {
// Start begins serving requests. This method should be non-blocking and
// run the server in a separate goroutine.
Start() error
// Stop gracefully shuts down the server without interrupting any active connections.
// It accepts a context to allow for a timeout.
Stop(ctx context.Context) error
// Addr returns the network address the server is listening on.
Addr() string
// Name returns the server instance name.
Name() string
// HealthCheckHandler returns a handler that responds to health checks.
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down.
HealthCheckHandler() http.HandlerFunc
// ReadinessHandler returns a handler for readiness checks.
// Includes in-flight request count.
ReadinessHandler() http.HandlerFunc
// InFlightRequests returns the current number of in-flight requests.
InFlightRequests() int64
// IsShuttingDown returns true if the server is shutting down.
IsShuttingDown() bool
// Wait blocks until shutdown is complete.
Wait()
}
// Manager defines the interface for a server manager.
// It is responsible for managing the lifecycle of multiple server instances.
type Manager interface {
// Add registers a new server instance based on the provided configuration.
// The server is not started until StartAll or Start is called on the instance.
Add(cfg Config) (Instance, error)
// Get returns a server instance by its name.
Get(name string) (Instance, error)
// Remove stops and removes a server instance by its name.
Remove(name string) error
// StartAll starts all registered server instances that are not already running.
StartAll() error
// StopAll gracefully shuts down all running server instances.
// Executes shutdown callbacks and drains in-flight requests.
StopAll() error
// StopAllWithContext gracefully shuts down all running server instances with a context.
StopAllWithContext(ctx context.Context) error
// RestartAll gracefully restarts all running server instances.
RestartAll() error
// List returns all registered server instances.
List() []Instance
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
// It handles SIGINT and SIGTERM signals and performs graceful shutdown with callbacks.
ServeWithGracefulShutdown() error
// RegisterShutdownCallback registers a callback to be called during shutdown.
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
RegisterShutdownCallback(cb ShutdownCallback)
}
// ShutdownCallback is a function called during graceful shutdown.
type ShutdownCallback func(context.Context) error

601
pkg/server/manager.go Normal file
View File

@@ -0,0 +1,601 @@
package server
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/klauspost/compress/gzhttp"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
)
// gracefulServer wraps http.Server with graceful shutdown capabilities (internal type)
type gracefulServer struct {
server *http.Server
shutdownTimeout time.Duration
drainTimeout time.Duration
inFlightRequests atomic.Int64
isShuttingDown atomic.Bool
shutdownOnce sync.Once
shutdownComplete chan struct{}
}
// trackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
func (gs *gracefulServer) trackRequestsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if shutting down
if gs.isShuttingDown.Load() {
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
return
}
// Increment in-flight counter
gs.inFlightRequests.Add(1)
defer gs.inFlightRequests.Add(-1)
// Serve the request
next.ServeHTTP(w, r)
})
}
// shutdown performs graceful shutdown with request draining
func (gs *gracefulServer) shutdown(ctx context.Context) error {
var shutdownErr error
gs.shutdownOnce.Do(func() {
logger.Info("Starting graceful shutdown...")
// Mark as shutting down (new requests will be rejected)
gs.isShuttingDown.Store(true)
// Create context with timeout
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
defer cancel()
// Wait for in-flight requests to complete (with drain timeout)
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
defer drainCancel()
shutdownErr = gs.drainRequests(drainCtx)
if shutdownErr != nil {
logger.Error("Error draining requests: %v", shutdownErr)
}
// Shutdown the server
logger.Info("Shutting down HTTP server...")
if err := gs.server.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down server: %v", err)
if shutdownErr == nil {
shutdownErr = err
}
}
logger.Info("Graceful shutdown complete")
close(gs.shutdownComplete)
})
return shutdownErr
}
// drainRequests waits for in-flight requests to complete
func (gs *gracefulServer) drainRequests(ctx context.Context) error {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
startTime := time.Now()
for {
inFlight := gs.inFlightRequests.Load()
if inFlight == 0 {
logger.Info("All requests drained in %v", time.Since(startTime))
return nil
}
select {
case <-ctx.Done():
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
case <-ticker.C:
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
}
}
}
// inFlightRequests returns the current number of in-flight requests
func (gs *gracefulServer) inFlightRequestsCount() int64 {
return gs.inFlightRequests.Load()
}
// isShutdown returns true if the server is shutting down
func (gs *gracefulServer) isShutdown() bool {
return gs.isShuttingDown.Load()
}
// wait blocks until shutdown is complete
func (gs *gracefulServer) wait() {
<-gs.shutdownComplete
}
// healthCheckHandler returns a handler that responds to health checks
func (gs *gracefulServer) healthCheckHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.isShutdown() {
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(`{"status":"healthy"}`))
if err != nil {
logger.Warn("Failed to write health check response: %v", err)
}
}
}
// readinessHandler returns a handler for readiness checks
func (gs *gracefulServer) readinessHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.isShutdown() {
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
inFlight := gs.inFlightRequestsCount()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
}
}
// serverManager manages a collection of server instances with graceful shutdown support.
type serverManager struct {
instances map[string]Instance
mu sync.RWMutex
shutdownCallbacks []ShutdownCallback
callbacksMu sync.Mutex
}
// NewManager creates a new server manager.
func NewManager() Manager {
return &serverManager{
instances: make(map[string]Instance),
shutdownCallbacks: make([]ShutdownCallback, 0),
}
}
// Add registers a new server instance.
func (sm *serverManager) Add(cfg Config) (Instance, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
if cfg.Name == "" {
return nil, fmt.Errorf("server name cannot be empty")
}
if _, exists := sm.instances[cfg.Name]; exists {
return nil, fmt.Errorf("server with name '%s' already exists", cfg.Name)
}
instance, err := newInstance(cfg)
if err != nil {
return nil, err
}
sm.instances[cfg.Name] = instance
return instance, nil
}
// Get returns a server instance by its name.
func (sm *serverManager) Get(name string) (Instance, error) {
sm.mu.RLock()
defer sm.mu.RUnlock()
instance, exists := sm.instances[name]
if !exists {
return nil, fmt.Errorf("server with name '%s' not found", name)
}
return instance, nil
}
// Remove stops and removes a server instance by its name.
func (sm *serverManager) Remove(name string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
instance, exists := sm.instances[name]
if !exists {
return fmt.Errorf("server with name '%s' not found", name)
}
// Stop the server if it's running. Prefer the server's configured shutdownTimeout
// when available, and fall back to a sensible default.
timeout := 10 * time.Second
if si, ok := instance.(*serverInstance); ok && si.gracefulServer != nil && si.gracefulServer.shutdownTimeout > 0 {
timeout = si.gracefulServer.shutdownTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := instance.Stop(ctx); err != nil {
logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err)
}
delete(sm.instances, name)
return nil
}
// StartAll starts all registered server instances.
func (sm *serverManager) StartAll() error {
sm.mu.RLock()
defer sm.mu.RUnlock()
var startErrors []error
for name, instance := range sm.instances {
if err := instance.Start(); err != nil {
startErrors = append(startErrors, fmt.Errorf("failed to start server '%s': %w", name, err))
}
}
if len(startErrors) > 0 {
return fmt.Errorf("encountered errors while starting servers: %v", startErrors)
}
return nil
}
// StopAll gracefully shuts down all running server instances.
func (sm *serverManager) StopAll() error {
return sm.StopAllWithContext(context.Background())
}
// StopAllWithContext gracefully shuts down all running server instances with a context.
func (sm *serverManager) StopAllWithContext(ctx context.Context) error {
sm.mu.RLock()
instancesToStop := make([]Instance, 0, len(sm.instances))
for _, instance := range sm.instances {
instancesToStop = append(instancesToStop, instance)
}
sm.mu.RUnlock()
logger.Info("Shutting down all servers...")
// Execute shutdown callbacks first
sm.callbacksMu.Lock()
callbacks := make([]ShutdownCallback, len(sm.shutdownCallbacks))
copy(callbacks, sm.shutdownCallbacks)
sm.callbacksMu.Unlock()
if len(callbacks) > 0 {
logger.Info("Executing %d shutdown callbacks...", len(callbacks))
for i, cb := range callbacks {
if err := cb(ctx); err != nil {
logger.Error("Shutdown callback %d failed: %v", i+1, err)
}
}
}
// Stop all instances in parallel
var shutdownErrors []error
var wg sync.WaitGroup
var errorsMu sync.Mutex
for _, instance := range instancesToStop {
wg.Add(1)
go func(inst Instance) {
defer wg.Done()
shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
if err := inst.Stop(shutdownCtx); err != nil {
errorsMu.Lock()
shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Name(), err))
errorsMu.Unlock()
}
}(instance)
}
wg.Wait()
if len(shutdownErrors) > 0 {
return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors)
}
logger.Info("All servers stopped gracefully.")
return nil
}
// RestartAll gracefully restarts all running server instances.
func (sm *serverManager) RestartAll() error {
logger.Info("Restarting all servers...")
if err := sm.StopAll(); err != nil {
return fmt.Errorf("failed to stop servers during restart: %w", err)
}
// Retry starting all servers with exponential backoff instead of a fixed sleep.
const (
maxAttempts = 5
initialBackoff = 100 * time.Millisecond
maxBackoff = 2 * time.Second
)
var lastErr error
backoff := initialBackoff
for attempt := 1; attempt <= maxAttempts; attempt++ {
if err := sm.StartAll(); err != nil {
lastErr = err
if attempt == maxAttempts {
break
}
logger.Warn("Attempt %d to start servers during restart failed: %v; retrying in %s", attempt, err, backoff)
time.Sleep(backoff)
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
logger.Info("All servers restarted successfully.")
return nil
}
return fmt.Errorf("failed to start servers during restart after %d attempts: %w", maxAttempts, lastErr)
}
// List returns all registered server instances.
func (sm *serverManager) List() []Instance {
sm.mu.RLock()
defer sm.mu.RUnlock()
instances := make([]Instance, 0, len(sm.instances))
for _, instance := range sm.instances {
instances = append(instances, instance)
}
return instances
}
// RegisterShutdownCallback registers a callback to be called during shutdown.
func (sm *serverManager) RegisterShutdownCallback(cb ShutdownCallback) {
sm.callbacksMu.Lock()
defer sm.callbacksMu.Unlock()
sm.shutdownCallbacks = append(sm.shutdownCallbacks, cb)
}
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
func (sm *serverManager) ServeWithGracefulShutdown() error {
// Start all servers
if err := sm.StartAll(); err != nil {
return fmt.Errorf("failed to start servers: %w", err)
}
logger.Info("All servers started. Waiting for shutdown signal...")
// Wait for interrupt signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
sig := <-sigChan
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
// Create context with timeout for shutdown
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return sm.StopAllWithContext(ctx)
}
// serverInstance is a concrete implementation of the Instance interface.
// It wraps gracefulServer to provide graceful shutdown capabilities.
type serverInstance struct {
cfg Config
gracefulServer *gracefulServer
certFile string // Path to certificate file (may be persistent for self-signed)
keyFile string // Path to key file (may be persistent for self-signed)
mu sync.RWMutex
running bool
serverErr chan error
}
// newInstance creates a new, unstarted server instance from a config.
func newInstance(cfg Config) (*serverInstance, error) {
if cfg.Handler == nil {
return nil, fmt.Errorf("handler cannot be nil")
}
// Set default timeouts
if cfg.ShutdownTimeout == 0 {
cfg.ShutdownTimeout = 30 * time.Second
}
if cfg.DrainTimeout == 0 {
cfg.DrainTimeout = 25 * time.Second
}
if cfg.ReadTimeout == 0 {
cfg.ReadTimeout = 15 * time.Second
}
if cfg.WriteTimeout == 0 {
cfg.WriteTimeout = 15 * time.Second
}
if cfg.IdleTimeout == 0 {
cfg.IdleTimeout = 60 * time.Second
}
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
var handler = cfg.Handler
// Wrap with GZIP handler if enabled
if cfg.GZIP {
gz, err := gzhttp.NewWrapper()
if err != nil {
return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err)
}
handler = gz(handler)
}
// Wrap with the panic recovery middleware
handler = middleware.PanicRecovery(handler)
// Configure TLS if any TLS option is enabled
tlsConfig, certFile, keyFile, err := configureTLS(cfg)
if err != nil {
return nil, fmt.Errorf("failed to configure TLS: %w", err)
}
// Create gracefulServer
gracefulSrv := &gracefulServer{
server: &http.Server{
Addr: addr,
Handler: handler,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
IdleTimeout: cfg.IdleTimeout,
TLSConfig: tlsConfig,
},
shutdownTimeout: cfg.ShutdownTimeout,
drainTimeout: cfg.DrainTimeout,
shutdownComplete: make(chan struct{}),
}
return &serverInstance{
cfg: cfg,
gracefulServer: gracefulSrv,
certFile: certFile,
keyFile: keyFile,
serverErr: make(chan error, 1),
}, nil
}
// Start begins serving requests in a new goroutine.
func (s *serverInstance) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.running {
return fmt.Errorf("server '%s' is already running", s.cfg.Name)
}
// Determine if we're using TLS
useTLS := s.cfg.SSLCert != "" || s.cfg.SSLKey != "" || s.cfg.SelfSignedSSL || s.cfg.AutoTLS
// Wrap handler with request tracking
s.gracefulServer.server.Handler = s.gracefulServer.trackRequestsMiddleware(s.gracefulServer.server.Handler)
go func() {
defer func() {
s.mu.Lock()
s.running = false
s.mu.Unlock()
logger.Info("Server '%s' stopped.", s.cfg.Name)
}()
var err error
protocol := "HTTP"
if useTLS {
protocol = "HTTPS"
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
// For AutoTLS, we need to use a TLS listener
if s.cfg.AutoTLS {
// Create listener
ln, lnErr := net.Listen("tcp", s.gracefulServer.server.Addr)
if lnErr != nil {
err = fmt.Errorf("failed to create listener: %w", lnErr)
} else {
// Wrap with TLS
tlsListener := tls.NewListener(ln, s.gracefulServer.server.TLSConfig)
err = s.gracefulServer.server.Serve(tlsListener)
}
} else {
// Use certificate files (regular SSL or self-signed)
err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile)
}
} else {
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
err = s.gracefulServer.server.ListenAndServe()
}
// If the server stopped for a reason other than a graceful shutdown, log and report the error.
if err != nil && err != http.ErrServerClosed {
logger.Error("Server '%s' failed: %v", s.cfg.Name, err)
select {
case s.serverErr <- err:
default:
}
}
}()
s.running = true
// A small delay to allow the goroutine to start and potentially fail on binding.
time.Sleep(50 * time.Millisecond)
// Check if the server failed to start
select {
case err := <-s.serverErr:
s.running = false
return err
default:
}
return nil
}
// Stop gracefully shuts down the server.
func (s *serverInstance) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.running {
return nil // Already stopped
}
logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name)
err := s.gracefulServer.shutdown(ctx)
if err == nil {
s.running = false
}
return err
}
// Addr returns the network address the server is listening on.
func (s *serverInstance) Addr() string {
return s.gracefulServer.server.Addr
}
// Name returns the server instance name.
func (s *serverInstance) Name() string {
return s.cfg.Name
}
// HealthCheckHandler returns a handler that responds to health checks.
func (s *serverInstance) HealthCheckHandler() http.HandlerFunc {
return s.gracefulServer.healthCheckHandler()
}
// ReadinessHandler returns a handler for readiness checks.
func (s *serverInstance) ReadinessHandler() http.HandlerFunc {
return s.gracefulServer.readinessHandler()
}
// InFlightRequests returns the current number of in-flight requests.
func (s *serverInstance) InFlightRequests() int64 {
return s.gracefulServer.inFlightRequestsCount()
}
// IsShuttingDown returns true if the server is shutting down.
func (s *serverInstance) IsShuttingDown() bool {
return s.gracefulServer.isShutdown()
}
// Wait blocks until shutdown is complete.
func (s *serverInstance) Wait() {
s.gracefulServer.wait()
}

399
pkg/server/manager_test.go Normal file
View File

@@ -0,0 +1,399 @@
package server
import (
"context"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// getFreePort asks the kernel for a free open port that is ready to use.
func getFreePort(t *testing.T) int {
t.Helper()
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
require.NoError(t, err)
l, err := net.ListenTCP("tcp", addr)
require.NoError(t, err)
defer l.Close()
return l.Addr().(*net.TCPAddr).Port
}
func TestServerManagerLifecycle(t *testing.T) {
// Initialize logger for test output
logger.Init(true)
// Create a new server manager
sm := NewManager()
// Define a simple test handler
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, World!"))
})
// Get a free port for the server to listen on to avoid conflicts
testPort := getFreePort(t)
// Add a new server configuration
serverConfig := Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: testHandler,
}
instance, err := sm.Add(serverConfig)
require.NoError(t, err, "should be able to add a new server")
require.NotNil(t, instance, "added instance should not be nil")
// --- Test StartAll ---
err = sm.StartAll()
require.NoError(t, err, "StartAll should not return an error")
// Give the server a moment to start up
time.Sleep(100 * time.Millisecond)
// --- Verify Server is Running ---
client := &http.Client{Timeout: 2 * time.Second}
url := fmt.Sprintf("http://localhost:%d", testPort)
resp, err := client.Get(url)
require.NoError(t, err, "should be able to make a request to the running server")
assert.Equal(t, http.StatusOK, resp.StatusCode, "expected status OK from the test server")
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
resp.Body.Close()
assert.Equal(t, "Hello, World!", string(body), "response body should match expected value")
// --- Test Get ---
retrievedInstance, err := sm.Get("TestServer")
require.NoError(t, err, "should be able to get server by name")
assert.Equal(t, instance.Addr(), retrievedInstance.Addr(), "retrieved instance should be the same")
// --- Test List ---
instanceList := sm.List()
require.Len(t, instanceList, 1, "list should contain one instance")
assert.Equal(t, instance.Addr(), instanceList[0].Addr(), "listed instance should be the same")
// --- Test StopAll ---
err = sm.StopAll()
require.NoError(t, err, "StopAll should not return an error")
// Give the server a moment to shut down
time.Sleep(100 * time.Millisecond)
// --- Verify Server is Stopped ---
_, err = client.Get(url)
require.Error(t, err, "should not be able to make a request to a stopped server")
// --- Test Remove ---
err = sm.Remove("TestServer")
require.NoError(t, err, "should be able to remove a server")
_, err = sm.Get("TestServer")
require.Error(t, err, "should not be able to get a removed server")
}
func TestManagerErrorCases(t *testing.T) {
logger.Init(true)
sm := NewManager()
testPort := getFreePort(t)
// --- Test Add Duplicate Name ---
config1 := Config{Name: "Duplicate", Host: "localhost", Port: testPort, Handler: http.NewServeMux()}
_, err := sm.Add(config1)
require.NoError(t, err)
config2 := Config{Name: "Duplicate", Host: "localhost", Port: getFreePort(t), Handler: http.NewServeMux()}
_, err = sm.Add(config2)
require.Error(t, err, "should not be able to add a server with a duplicate name")
// --- Test Get Non-existent ---
_, err = sm.Get("NonExistent")
require.Error(t, err, "should get an error for a non-existent server")
// --- Test Add with Nil Handler ---
config3 := Config{Name: "NilHandler", Host: "localhost", Port: getFreePort(t), Handler: nil}
_, err = sm.Add(config3)
require.Error(t, err, "should not be able to add a server with a nil handler")
}
func TestGracefulShutdown(t *testing.T) {
logger.Init(true)
sm := NewManager()
requestsHandled := 0
var requestsMu sync.Mutex
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestsMu.Lock()
requestsHandled++
requestsMu.Unlock()
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
})
testPort := getFreePort(t)
instance, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: handler,
DrainTimeout: 2 * time.Second,
})
require.NoError(t, err)
err = sm.StartAll()
require.NoError(t, err)
// Give server time to start
time.Sleep(100 * time.Millisecond)
// Send some concurrent requests
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
client := &http.Client{Timeout: 5 * time.Second}
url := fmt.Sprintf("http://localhost:%d", testPort)
resp, err := client.Get(url)
if err == nil {
resp.Body.Close()
}
}()
}
// Wait a bit for requests to start
time.Sleep(50 * time.Millisecond)
// Check in-flight requests
inFlight := instance.InFlightRequests()
assert.Greater(t, inFlight, int64(0), "Should have in-flight requests")
// Stop the server
err = sm.StopAll()
require.NoError(t, err)
// Wait for all requests to complete
wg.Wait()
// Verify all requests were handled
requestsMu.Lock()
handled := requestsHandled
requestsMu.Unlock()
assert.GreaterOrEqual(t, handled, 1, "At least some requests should have been handled")
// Verify no in-flight requests
assert.Equal(t, int64(0), instance.InFlightRequests(), "Should have no in-flight requests after shutdown")
}
func TestHealthAndReadinessEndpoints(t *testing.T) {
logger.Init(true)
sm := NewManager()
mux := http.NewServeMux()
testPort := getFreePort(t)
instance, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: mux,
})
require.NoError(t, err)
// Add health and readiness endpoints
mux.HandleFunc("/health", instance.HealthCheckHandler())
mux.HandleFunc("/ready", instance.ReadinessHandler())
err = sm.StartAll()
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
client := &http.Client{Timeout: 2 * time.Second}
baseURL := fmt.Sprintf("http://localhost:%d", testPort)
// Test health endpoint
resp, err := client.Get(baseURL + "/health")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
assert.Contains(t, string(body), "healthy")
// Test readiness endpoint
resp, err = client.Get(baseURL + "/ready")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, _ = io.ReadAll(resp.Body)
resp.Body.Close()
assert.Contains(t, string(body), "ready")
assert.Contains(t, string(body), "in_flight_requests")
// Stop the server
sm.StopAll()
}
func TestRequestRejectionDuringShutdown(t *testing.T) {
logger.Init(true)
sm := NewManager()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(50 * time.Millisecond)
w.WriteHeader(http.StatusOK)
})
testPort := getFreePort(t)
_, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: handler,
DrainTimeout: 1 * time.Second,
})
require.NoError(t, err)
err = sm.StartAll()
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
// Start shutdown in background
go func() {
time.Sleep(50 * time.Millisecond)
sm.StopAll()
}()
// Give shutdown time to start
time.Sleep(100 * time.Millisecond)
// Try to make a request after shutdown started
client := &http.Client{Timeout: 2 * time.Second}
url := fmt.Sprintf("http://localhost:%d", testPort)
resp, err := client.Get(url)
// The request should either fail (connection refused) or get 503
if err == nil {
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Should get 503 during shutdown")
resp.Body.Close()
}
}
func TestShutdownCallbacks(t *testing.T) {
logger.Init(true)
sm := NewManager()
callbackExecuted := false
var callbackMu sync.Mutex
sm.RegisterShutdownCallback(func(ctx context.Context) error {
callbackMu.Lock()
callbackExecuted = true
callbackMu.Unlock()
return nil
})
testPort := getFreePort(t)
_, err := sm.Add(Config{
Name: "TestServer",
Host: "localhost",
Port: testPort,
Handler: http.NewServeMux(),
})
require.NoError(t, err)
err = sm.StartAll()
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = sm.StopAll()
require.NoError(t, err)
callbackMu.Lock()
executed := callbackExecuted
callbackMu.Unlock()
assert.True(t, executed, "Shutdown callback should have been executed")
}
func TestSelfSignedSSLCertificateReuse(t *testing.T) {
logger.Init(true)
// Get expected cert directory location
cacheDir, err := os.UserCacheDir()
require.NoError(t, err)
certDir := filepath.Join(cacheDir, "resolvespec", "certs")
host := "localhost"
certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host))
keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host))
// Clean up any existing cert files from previous tests
os.Remove(certFile)
os.Remove(keyFile)
// First server creation - should generate new certificates
sm1 := NewManager()
testPort1 := getFreePort(t)
_, err = sm1.Add(Config{
Name: "SSLTestServer1",
Host: host,
Port: testPort1,
Handler: http.NewServeMux(),
SelfSignedSSL: true,
ShutdownTimeout: 5 * time.Second,
})
require.NoError(t, err)
// Verify certificates were created
_, err = os.Stat(certFile)
require.NoError(t, err, "certificate file should exist after first creation")
_, err = os.Stat(keyFile)
require.NoError(t, err, "key file should exist after first creation")
// Get modification time of cert file
info1, err := os.Stat(certFile)
require.NoError(t, err)
modTime1 := info1.ModTime()
// Wait a bit to ensure different modification times
time.Sleep(100 * time.Millisecond)
// Second server creation - should reuse existing certificates
sm2 := NewManager()
testPort2 := getFreePort(t)
_, err = sm2.Add(Config{
Name: "SSLTestServer2",
Host: host,
Port: testPort2,
Handler: http.NewServeMux(),
SelfSignedSSL: true,
ShutdownTimeout: 5 * time.Second,
})
require.NoError(t, err)
// Get modification time of cert file after second creation
info2, err := os.Stat(certFile)
require.NoError(t, err)
modTime2 := info2.ModTime()
// Verify the certificate was reused (same modification time)
assert.Equal(t, modTime1, modTime2, "certificate should be reused, not regenerated")
// Clean up
sm1.StopAll()
sm2.StopAll()
}

View File

@@ -1,296 +0,0 @@
package server
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// GracefulServer wraps http.Server with graceful shutdown capabilities
type GracefulServer struct {
server *http.Server
shutdownTimeout time.Duration
drainTimeout time.Duration
inFlightRequests atomic.Int64
isShuttingDown atomic.Bool
shutdownOnce sync.Once
shutdownComplete chan struct{}
}
// Config holds configuration for the graceful server
type Config struct {
// Addr is the server address (e.g., ":8080")
Addr string
// Handler is the HTTP handler
Handler http.Handler
// ShutdownTimeout is the maximum time to wait for graceful shutdown
// Default: 30 seconds
ShutdownTimeout time.Duration
// DrainTimeout is the time to wait for in-flight requests to complete
// before forcing shutdown. Default: 25 seconds
DrainTimeout time.Duration
// ReadTimeout is the maximum duration for reading the entire request
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out writes of the response
WriteTimeout time.Duration
// IdleTimeout is the maximum amount of time to wait for the next request
IdleTimeout time.Duration
}
// NewGracefulServer creates a new graceful server
func NewGracefulServer(config Config) *GracefulServer {
if config.ShutdownTimeout == 0 {
config.ShutdownTimeout = 30 * time.Second
}
if config.DrainTimeout == 0 {
config.DrainTimeout = 25 * time.Second
}
if config.ReadTimeout == 0 {
config.ReadTimeout = 10 * time.Second
}
if config.WriteTimeout == 0 {
config.WriteTimeout = 10 * time.Second
}
if config.IdleTimeout == 0 {
config.IdleTimeout = 120 * time.Second
}
gs := &GracefulServer{
server: &http.Server{
Addr: config.Addr,
Handler: config.Handler,
ReadTimeout: config.ReadTimeout,
WriteTimeout: config.WriteTimeout,
IdleTimeout: config.IdleTimeout,
},
shutdownTimeout: config.ShutdownTimeout,
drainTimeout: config.DrainTimeout,
shutdownComplete: make(chan struct{}),
}
return gs
}
// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if shutting down
if gs.isShuttingDown.Load() {
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
return
}
// Increment in-flight counter
gs.inFlightRequests.Add(1)
defer gs.inFlightRequests.Add(-1)
// Serve the request
next.ServeHTTP(w, r)
})
}
// ListenAndServe starts the server and handles graceful shutdown
func (gs *GracefulServer) ListenAndServe() error {
// Wrap handler with request tracking
gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler)
// Start server in goroutine
serverErr := make(chan error, 1)
go func() {
logger.Info("Starting server on %s", gs.server.Addr)
if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
serverErr <- err
}
close(serverErr)
}()
// Wait for interrupt signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
select {
case err := <-serverErr:
return err
case sig := <-sigChan:
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
return gs.Shutdown(context.Background())
}
}
// Shutdown performs graceful shutdown with request draining
func (gs *GracefulServer) Shutdown(ctx context.Context) error {
var shutdownErr error
gs.shutdownOnce.Do(func() {
logger.Info("Starting graceful shutdown...")
// Mark as shutting down (new requests will be rejected)
gs.isShuttingDown.Store(true)
// Create context with timeout
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
defer cancel()
// Wait for in-flight requests to complete (with drain timeout)
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
defer drainCancel()
shutdownErr = gs.drainRequests(drainCtx)
if shutdownErr != nil {
logger.Error("Error draining requests: %v", shutdownErr)
}
// Shutdown the server
logger.Info("Shutting down HTTP server...")
if err := gs.server.Shutdown(shutdownCtx); err != nil {
logger.Error("Error shutting down server: %v", err)
if shutdownErr == nil {
shutdownErr = err
}
}
logger.Info("Graceful shutdown complete")
close(gs.shutdownComplete)
})
return shutdownErr
}
// drainRequests waits for in-flight requests to complete
func (gs *GracefulServer) drainRequests(ctx context.Context) error {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
startTime := time.Now()
for {
inFlight := gs.inFlightRequests.Load()
if inFlight == 0 {
logger.Info("All requests drained in %v", time.Since(startTime))
return nil
}
select {
case <-ctx.Done():
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
case <-ticker.C:
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
}
}
}
// InFlightRequests returns the current number of in-flight requests
func (gs *GracefulServer) InFlightRequests() int64 {
return gs.inFlightRequests.Load()
}
// IsShuttingDown returns true if the server is shutting down
func (gs *GracefulServer) IsShuttingDown() bool {
return gs.isShuttingDown.Load()
}
// Wait blocks until shutdown is complete
func (gs *GracefulServer) Wait() {
<-gs.shutdownComplete
}
// HealthCheckHandler returns a handler that responds to health checks
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down
func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.IsShuttingDown() {
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(`{"status":"healthy"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
}
}
// ReadinessHandler returns a handler for readiness checks
// Includes in-flight request count
func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if gs.IsShuttingDown() {
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
return
}
inFlight := gs.InFlightRequests()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
}
}
// ShutdownCallback is a function called during shutdown
type ShutdownCallback func(context.Context) error
// shutdownCallbacks stores registered shutdown callbacks
var (
shutdownCallbacks []ShutdownCallback
shutdownCallbacksMu sync.Mutex
)
// RegisterShutdownCallback registers a callback to be called during shutdown
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
func RegisterShutdownCallback(cb ShutdownCallback) {
shutdownCallbacksMu.Lock()
defer shutdownCallbacksMu.Unlock()
shutdownCallbacks = append(shutdownCallbacks, cb)
}
// executeShutdownCallbacks runs all registered shutdown callbacks
func executeShutdownCallbacks(ctx context.Context) error {
shutdownCallbacksMu.Lock()
callbacks := make([]ShutdownCallback, len(shutdownCallbacks))
copy(callbacks, shutdownCallbacks)
shutdownCallbacksMu.Unlock()
var errors []error
for i, cb := range callbacks {
logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks))
if err := cb(ctx); err != nil {
logger.Error("Shutdown callback %d failed: %v", i+1, err)
errors = append(errors, err)
}
}
if len(errors) > 0 {
return fmt.Errorf("shutdown callbacks failed: %v", errors)
}
return nil
}
// ShutdownWithCallbacks performs shutdown and executes all registered callbacks
func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error {
// Execute callbacks first
if err := executeShutdownCallbacks(ctx); err != nil {
logger.Error("Error executing shutdown callbacks: %v", err)
}
// Then shutdown the server
return gs.Shutdown(ctx)
}

View File

@@ -1,231 +0,0 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
)
func TestGracefulServerTrackRequests(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}),
})
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
// Start some requests
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}()
}
// Wait a bit for requests to start
time.Sleep(10 * time.Millisecond)
// Check in-flight count
inFlight := srv.InFlightRequests()
if inFlight == 0 {
t.Error("Should have in-flight requests")
}
// Wait for all requests to complete
wg.Wait()
// Check that counter is back to zero
inFlight = srv.InFlightRequests()
if inFlight != 0 {
t.Errorf("In-flight requests should be 0, got %d", inFlight)
}
}
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
})
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
// Mark as shutting down
srv.isShuttingDown.Store(true)
// Try to make a request
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// Should get 503
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected 503, got %d", w.Code)
}
}
func TestHealthCheckHandler(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
})
handler := srv.HealthCheckHandler()
// Healthy
t.Run("Healthy", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.Code)
}
if w.Body.String() != `{"status":"healthy"}` {
t.Errorf("Unexpected body: %s", w.Body.String())
}
})
// Shutting down
t.Run("ShuttingDown", func(t *testing.T) {
srv.isShuttingDown.Store(true)
req := httptest.NewRequest("GET", "/health", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected 503, got %d", w.Code)
}
})
}
func TestReadinessHandler(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
})
handler := srv.ReadinessHandler()
// Ready with no in-flight requests
t.Run("Ready", func(t *testing.T) {
req := httptest.NewRequest("GET", "/ready", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.Code)
}
body := w.Body.String()
if body != `{"ready":true,"in_flight_requests":0}` {
t.Errorf("Unexpected body: %s", body)
}
})
// Not ready during shutdown
t.Run("NotReady", func(t *testing.T) {
srv.isShuttingDown.Store(true)
req := httptest.NewRequest("GET", "/ready", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("Expected 503, got %d", w.Code)
}
})
}
func TestShutdownCallbacks(t *testing.T) {
callbackExecuted := false
RegisterShutdownCallback(func(ctx context.Context) error {
callbackExecuted = true
return nil
})
ctx := context.Background()
err := executeShutdownCallbacks(ctx)
if err != nil {
t.Errorf("executeShutdownCallbacks() error = %v", err)
}
if !callbackExecuted {
t.Error("Shutdown callback was not executed")
}
// Reset for other tests
shutdownCallbacks = nil
}
func TestDrainRequests(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
DrainTimeout: 1 * time.Second,
})
// Simulate in-flight requests
srv.inFlightRequests.Add(3)
// Start draining in background
go func() {
time.Sleep(100 * time.Millisecond)
// Simulate requests completing
srv.inFlightRequests.Add(-3)
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err := srv.drainRequests(ctx)
if err != nil {
t.Errorf("drainRequests() error = %v", err)
}
if srv.InFlightRequests() != 0 {
t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests())
}
}
func TestDrainRequestsTimeout(t *testing.T) {
srv := NewGracefulServer(Config{
Addr: ":0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
DrainTimeout: 100 * time.Millisecond,
})
// Simulate in-flight requests that don't complete
srv.inFlightRequests.Add(5)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
err := srv.drainRequests(ctx)
if err == nil {
t.Error("drainRequests() should timeout with error")
}
// Cleanup
srv.inFlightRequests.Add(-5)
}
func TestGetClientIP(t *testing.T) {
// This test is in ratelimit_test.go since getClientIP is used by rate limiter
// Including here for completeness of server tests
}

View File

@@ -0,0 +1,439 @@
# StaticWeb - Interface-Driven Static File Server
StaticWeb is a flexible, interface-driven Go package for serving static files over HTTP. It supports multiple filesystem backends (local, zip, embedded) and provides pluggable policies for caching, MIME types, and fallback strategies.
## Features
- **Router-Agnostic**: Works with any HTTP router through standard `http.Handler`
- **Multiple Filesystem Providers**: Local directories, zip files, embedded filesystems
- **Pluggable Policies**: Customizable cache, MIME type, and fallback strategies
- **Thread-Safe**: Safe for concurrent use
- **Resource Management**: Proper lifecycle management with `Close()` methods
- **Extensible**: Easy to add new providers and policies
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/server/staticweb
```
## Quick Start
### Basic Usage
Serve files from a local directory:
```go
import "github.com/bitechdev/ResolveSpec/pkg/server/staticweb"
// Create service
service := staticweb.NewService(nil)
// Mount a local directory
provider, _ := staticweb.LocalProvider("./public")
service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: provider,
})
// Use with any router
router.PathPrefix("/").Handler(service.Handler())
```
### Single Page Application (SPA)
Serve an SPA with HTML fallback routing:
```go
service := staticweb.NewService(nil)
provider, _ := staticweb.LocalProvider("./dist")
service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: provider,
FallbackStrategy: staticweb.HTMLFallback("index.html"),
})
// API routes take precedence (registered first)
router.HandleFunc("/api/users", usersHandler)
router.HandleFunc("/api/posts", postsHandler)
// Static files handle all other routes
router.PathPrefix("/").Handler(service.Handler())
```
## Filesystem Providers
### Local Directory
Serve files from a local filesystem directory:
```go
provider, err := staticweb.LocalProvider("/var/www/static")
```
### Zip File
Serve files from a zip archive:
```go
provider, err := staticweb.ZipProvider("./static.zip")
```
### Embedded Filesystem
Serve files from Go's embedded filesystem:
```go
//go:embed assets
var assets embed.FS
// Direct embedded FS
provider, err := staticweb.EmbedProvider(&assets, "")
// Or from a zip file within embedded FS
provider, err := staticweb.EmbedProvider(&assets, "assets.zip")
```
## Cache Policies
### Simple Cache
Single TTL for all files:
```go
cachePolicy := staticweb.SimpleCache(3600) // 1 hour
```
### Extension-Based Cache
Different TTLs per file type:
```go
rules := map[string]int{
".html": 3600, // 1 hour
".js": 86400, // 1 day
".css": 86400, // 1 day
".png": 604800, // 1 week
}
cachePolicy := staticweb.ExtensionCache(rules, 3600) // default 1 hour
```
### No Cache
Disable caching entirely:
```go
cachePolicy := staticweb.NoCache()
```
## Fallback Strategies
### HTML Fallback
Serve index.html for non-asset requests (SPA routing):
```go
fallback := staticweb.HTMLFallback("index.html")
```
### Extension-Based Fallback
Skip fallback for known static assets:
```go
fallback := staticweb.DefaultExtensionFallback("index.html")
```
Custom extensions:
```go
staticExts := []string{".js", ".css", ".png", ".jpg"}
fallback := staticweb.ExtensionFallback(staticExts, "index.html")
```
## Configuration
### Service Configuration
```go
config := &staticweb.ServiceConfig{
DefaultCacheTime: 3600,
DefaultMIMETypes: map[string]string{
".webp": "image/webp",
".wasm": "application/wasm",
},
}
service := staticweb.NewService(config)
```
### Mount Configuration
```go
service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: provider,
CachePolicy: cachePolicy, // Optional
MIMEResolver: mimeResolver, // Optional
FallbackStrategy: fallbackStrategy, // Optional
})
```
## Advanced Usage
### Multiple Mount Points
Serve different directories at different URL prefixes with different policies:
```go
service := staticweb.NewService(nil)
// Long-lived assets
assetsProvider, _ := staticweb.LocalProvider("./assets")
service.Mount(staticweb.MountConfig{
URLPrefix: "/assets",
Provider: assetsProvider,
CachePolicy: staticweb.SimpleCache(604800), // 1 week
})
// Frequently updated HTML
htmlProvider, _ := staticweb.LocalProvider("./public")
service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: htmlProvider,
CachePolicy: staticweb.SimpleCache(300), // 5 minutes
})
```
### Custom MIME Types
```go
mimeResolver := staticweb.DefaultMIMEResolver()
mimeResolver.RegisterMIMEType(".webp", "image/webp")
mimeResolver.RegisterMIMEType(".wasm", "application/wasm")
service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: provider,
MIMEResolver: mimeResolver,
})
```
### Resource Cleanup
Always close the service when done:
```go
service := staticweb.NewService(nil)
defer service.Close()
// ... mount and use service ...
```
Or unmount individual mount points:
```go
service.Unmount("/static")
```
### Reloading/Refreshing Content
Reload providers to pick up changes from the underlying filesystem. This is particularly useful for zip files in development:
```go
// When zip file or directory contents change
err := service.Reload()
if err != nil {
log.Printf("Failed to reload: %v", err)
}
```
Providers that support reloading:
- **ZipFSProvider**: Reopens the zip file to pick up changes
- **LocalFSProvider**: Refreshes the directory view (automatically picks up changes)
- **EmbedFSProvider**: Not reloadable (embedded at compile time)
You can also reload individual providers:
```go
if reloadable, ok := provider.(staticweb.ReloadableProvider); ok {
err := reloadable.Reload()
if err != nil {
log.Printf("Failed to reload: %v", err)
}
}
```
**Development Workflow Example:**
```go
service := staticweb.NewService(nil)
provider, _ := staticweb.ZipProvider("./dist.zip")
service.Mount(staticweb.MountConfig{
URLPrefix: "/app",
Provider: provider,
})
// In development, reload when dist.zip is rebuilt
go func() {
watcher := fsnotify.NewWatcher()
watcher.Add("./dist.zip")
for range watcher.Events {
log.Println("Reloading static files...")
if err := service.Reload(); err != nil {
log.Printf("Reload failed: %v", err)
}
}
}()
```
## Router Integration
### Gorilla Mux
```go
router := mux.NewRouter()
router.HandleFunc("/api/users", usersHandler)
router.PathPrefix("/").Handler(service.Handler())
```
### Standard http.ServeMux
```go
http.Handle("/api/", apiHandler)
http.Handle("/", service.Handler())
```
### BunRouter
```go
router.GET("/api/users", usersHandler)
router.GET("/*path", bunrouter.HTTPHandlerFunc(service.Handler()))
```
## Architecture
### Core Interfaces
#### FileSystemProvider
Abstracts the source of files:
```go
type FileSystemProvider interface {
Open(name string) (fs.File, error)
Close() error
Type() string
}
```
Implementations:
- `LocalFSProvider` - Local directories
- `ZipFSProvider` - Zip archives
- `EmbedFSProvider` - Embedded filesystems
#### CachePolicy
Defines caching behavior:
```go
type CachePolicy interface {
GetCacheTime(path string) int
GetCacheHeaders(path string) map[string]string
}
```
Implementations:
- `SimpleCachePolicy` - Single TTL
- `ExtensionBasedCachePolicy` - Per-extension TTL
- `NoCachePolicy` - Disable caching
#### FallbackStrategy
Handles missing files:
```go
type FallbackStrategy interface {
ShouldFallback(path string) bool
GetFallbackPath(path string) string
}
```
Implementations:
- `NoFallback` - Return 404
- `HTMLFallbackStrategy` - SPA routing
- `ExtensionBasedFallback` - Skip known assets
#### MIMETypeResolver
Determines Content-Type:
```go
type MIMETypeResolver interface {
GetMIMEType(path string) string
RegisterMIMEType(extension, mimeType string)
}
```
Implementations:
- `DefaultMIMEResolver` - Common web types
- `ConfigurableMIMEResolver` - Custom mappings
## Testing
### Mock Providers
```go
import staticwebtesting "github.com/bitechdev/ResolveSpec/pkg/server/staticweb/testing"
provider := staticwebtesting.NewMockProvider(map[string][]byte{
"index.html": []byte("<html>test</html>"),
"app.js": []byte("console.log('test')"),
})
service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: provider,
})
```
### Test Helpers
```go
req := httptest.NewRequest("GET", "/index.html", nil)
rec := httptest.NewRecorder()
service.Handler().ServeHTTP(rec, req)
// Assert response
assert.Equal(t, 200, rec.Code)
```
## Future Features
The interface-driven design allows for easy extensibility:
### Planned Providers
- **HTTPFSProvider**: Fetch files from remote HTTP servers with local caching
- **S3FSProvider**: Serve files from S3-compatible storage
- **CompositeProvider**: Fallback chain across multiple providers
- **MemoryProvider**: In-memory filesystem for testing
### Planned Policies
- **TimedCachePolicy**: Different cache times by time of day
- **ConditionalCachePolicy**: Smart cache based on file size/type
- **RegexFallbackStrategy**: Pattern-based routing
## License
See the main repository for license information.
## Contributing
Contributions are welcome! The interface-driven design makes it easy to add new providers and policies without modifying existing code.

View File

@@ -0,0 +1,99 @@
package staticweb
import (
"embed"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb/policies"
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb/providers"
)
// ServiceConfig configures the static file service.
type ServiceConfig struct {
// DefaultCacheTime is the default cache duration in seconds.
// Used when a mount point doesn't specify a custom CachePolicy.
// Default: 172800 (48 hours)
DefaultCacheTime int
// DefaultMIMETypes is a map of file extensions to MIME types.
// These are added to the default MIME resolver.
// Extensions should include the leading dot (e.g., ".webp").
DefaultMIMETypes map[string]string
}
// DefaultServiceConfig returns a ServiceConfig with sensible defaults.
func DefaultServiceConfig() *ServiceConfig {
return &ServiceConfig{
DefaultCacheTime: 172800, // 48 hours
DefaultMIMETypes: make(map[string]string),
}
}
// Validate checks if the ServiceConfig is valid.
func (c *ServiceConfig) Validate() error {
if c.DefaultCacheTime < 0 {
return fmt.Errorf("DefaultCacheTime cannot be negative")
}
return nil
}
// Helper constructor functions for providers
// LocalProvider creates a FileSystemProvider for a local directory.
func LocalProvider(path string) (FileSystemProvider, error) {
return providers.NewLocalFSProvider(path)
}
// ZipProvider creates a FileSystemProvider for a zip file.
func ZipProvider(zipPath string) (FileSystemProvider, error) {
return providers.NewZipFSProvider(zipPath)
}
// EmbedProvider creates a FileSystemProvider for an embedded filesystem.
// If zipFile is empty, the embedded FS is used directly.
// If zipFile is specified, it's treated as a path to a zip file within the embedded FS.
// The embedFS parameter can be any fs.FS, but is typically *embed.FS.
func EmbedProvider(embedFS *embed.FS, zipFile string) (FileSystemProvider, error) {
return providers.NewEmbedFSProvider(embedFS, zipFile)
}
// Policy constructor functions
// SimpleCache creates a simple cache policy with the given TTL in seconds.
func SimpleCache(seconds int) CachePolicy {
return policies.NewSimpleCachePolicy(seconds)
}
// ExtensionCache creates an extension-based cache policy.
// rules maps file extensions (with leading dot) to cache times in seconds.
// defaultTime is used for files that don't match any rule.
func ExtensionCache(rules map[string]int, defaultTime int) CachePolicy {
return policies.NewExtensionBasedCachePolicy(rules, defaultTime)
}
// NoCache creates a cache policy that disables all caching.
func NoCache() CachePolicy {
return policies.NewNoCachePolicy()
}
// HTMLFallback creates a fallback strategy for SPAs that serves the given index file.
func HTMLFallback(indexFile string) FallbackStrategy {
return policies.NewHTMLFallbackStrategy(indexFile)
}
// ExtensionFallback creates an extension-based fallback strategy.
// staticExtensions is a list of file extensions that should NOT use fallback.
// fallbackPath is the file to serve when fallback is triggered.
func ExtensionFallback(staticExtensions []string, fallbackPath string) FallbackStrategy {
return policies.NewExtensionBasedFallback(staticExtensions, fallbackPath)
}
// DefaultExtensionFallback creates an extension-based fallback with common web asset extensions.
func DefaultExtensionFallback(fallbackPath string) FallbackStrategy {
return policies.NewDefaultExtensionBasedFallback(fallbackPath)
}
// DefaultMIMEResolver creates a MIME resolver with common web file types.
func DefaultMIMEResolver() MIMETypeResolver {
return policies.NewDefaultMIMEResolver()
}

View File

@@ -0,0 +1,60 @@
package staticweb_test
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb"
staticwebtesting "github.com/bitechdev/ResolveSpec/pkg/server/staticweb/testing"
)
// Example_reload demonstrates reloading content when files change.
func Example_reload() {
service := staticweb.NewService(nil)
// Create a provider
provider := staticwebtesting.NewMockProvider(map[string][]byte{
"version.txt": []byte("v1.0.0"),
})
service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: provider,
})
// Simulate updating the file
provider.AddFile("version.txt", []byte("v2.0.0"))
// Reload to pick up changes (in real usage with zip files)
err := service.Reload()
if err != nil {
fmt.Printf("Failed to reload: %v\n", err)
} else {
fmt.Println("Successfully reloaded static files")
}
// Output: Successfully reloaded static files
}
// Example_reloadZip demonstrates reloading a zip file provider.
func Example_reloadZip() {
service := staticweb.NewService(nil)
// In production, you would use:
// provider, _ := staticweb.ZipProvider("./dist.zip")
// For this example, we use a mock
provider := staticwebtesting.NewMockProvider(map[string][]byte{
"app.js": []byte("console.log('v1')"),
})
service.Mount(staticweb.MountConfig{
URLPrefix: "/app",
Provider: provider,
})
fmt.Println("Serving from zip file")
// When the zip file is updated, call Reload()
// service.Reload()
// Output: Serving from zip file
}

View File

@@ -0,0 +1,138 @@
package staticweb_test
import (
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb"
staticwebtesting "github.com/bitechdev/ResolveSpec/pkg/server/staticweb/testing"
"github.com/gorilla/mux"
)
// Example_basic demonstrates serving files from a local directory.
func Example_basic() {
service := staticweb.NewService(nil)
// Using mock provider for example purposes
provider := staticwebtesting.NewMockProvider(map[string][]byte{
"index.html": []byte("<html>test</html>"),
})
_ = service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: provider,
})
router := mux.NewRouter()
router.PathPrefix("/").Handler(service.Handler())
fmt.Println("Serving files from ./public at /static")
// Output: Serving files from ./public at /static
}
// Example_spa demonstrates an SPA with HTML fallback routing.
func Example_spa() {
service := staticweb.NewService(nil)
// Using mock provider for example purposes
provider := staticwebtesting.NewMockProvider(map[string][]byte{
"index.html": []byte("<html>app</html>"),
})
_ = service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: provider,
FallbackStrategy: staticweb.HTMLFallback("index.html"),
})
router := mux.NewRouter()
// API routes take precedence
router.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("users"))
})
// Static files handle all other routes
router.PathPrefix("/").Handler(service.Handler())
fmt.Println("SPA with fallback to index.html")
// Output: SPA with fallback to index.html
}
// Example_multiple demonstrates multiple mount points with different policies.
func Example_multiple() {
service := staticweb.NewService(&staticweb.ServiceConfig{
DefaultCacheTime: 3600,
})
// Assets with long cache (using mock for example)
assetsProvider := staticwebtesting.NewMockProvider(map[string][]byte{
"app.js": []byte("console.log('test')"),
})
service.Mount(staticweb.MountConfig{
URLPrefix: "/assets",
Provider: assetsProvider,
CachePolicy: staticweb.SimpleCache(604800), // 1 week
})
// HTML with short cache (using mock for example)
htmlProvider := staticwebtesting.NewMockProvider(map[string][]byte{
"index.html": []byte("<html>test</html>"),
})
service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: htmlProvider,
CachePolicy: staticweb.SimpleCache(300), // 5 minutes
})
fmt.Println("Multiple mount points configured")
// Output: Multiple mount points configured
}
// Example_zip demonstrates serving from a zip file (concept).
func Example_zip() {
service := staticweb.NewService(nil)
// For actual usage, you would use:
// provider, err := staticweb.ZipProvider("./static.zip")
// For this example, we use a mock
provider := staticwebtesting.NewMockProvider(map[string][]byte{
"file.txt": []byte("content"),
})
service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: provider,
})
fmt.Println("Serving from zip file")
// Output: Serving from zip file
}
// Example_extensionCache demonstrates extension-based caching.
func Example_extensionCache() {
service := staticweb.NewService(nil)
// Using mock provider for example purposes
provider := staticwebtesting.NewMockProvider(map[string][]byte{
"index.html": []byte("<html>test</html>"),
"app.js": []byte("console.log('test')"),
})
// Different cache times per file type
cacheRules := map[string]int{
".html": 3600, // 1 hour
".js": 86400, // 1 day
".css": 86400, // 1 day
".png": 604800, // 1 week
}
service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: provider,
CachePolicy: staticweb.ExtensionCache(cacheRules, 3600), // default 1 hour
})
fmt.Println("Extension-based caching configured")
// Output: Extension-based caching configured
}

View File

@@ -0,0 +1,130 @@
package staticweb
import (
"io/fs"
"net/http"
)
// FileSystemProvider abstracts the source of files (local, zip, embedded, future: http, s3)
// Implementations must be safe for concurrent use.
type FileSystemProvider interface {
// Open opens the named file.
// The name is always a slash-separated path relative to the filesystem root.
Open(name string) (fs.File, error)
// Close releases any resources held by the provider.
// After Close is called, the provider should not be used.
Close() error
// Type returns the provider type (e.g., "local", "zip", "embed", "http", "s3").
// This is primarily for debugging and logging purposes.
Type() string
}
// ReloadableProvider is an optional interface that providers can implement
// to support reloading/refreshing their content.
// This is useful for development workflows where the underlying files may change.
type ReloadableProvider interface {
FileSystemProvider
// Reload refreshes the provider's content from the underlying source.
// For zip files, this reopens the zip archive.
// For local directories, this refreshes the filesystem view.
// Returns an error if the reload fails.
Reload() error
}
// CachePolicy defines how files should be cached by browsers and proxies.
// Implementations must be safe for concurrent use.
type CachePolicy interface {
// GetCacheTime returns the cache duration in seconds for the given path.
// A value of 0 means no caching.
// A negative value can be used to indicate browser should revalidate.
GetCacheTime(path string) int
// GetCacheHeaders returns additional cache-related HTTP headers for the given path.
// Common headers include "Cache-Control", "Expires", "ETag", etc.
// Returns nil if no additional headers are needed.
GetCacheHeaders(path string) map[string]string
}
// MIMETypeResolver determines the Content-Type for files.
// Implementations must be safe for concurrent use.
type MIMETypeResolver interface {
// GetMIMEType returns the MIME type for the given file path.
// Returns empty string if the MIME type cannot be determined.
GetMIMEType(path string) string
// RegisterMIMEType registers a custom MIME type for the given file extension.
// The extension should include the leading dot (e.g., ".webp").
RegisterMIMEType(extension, mimeType string)
}
// FallbackStrategy handles requests for files that don't exist.
// This is commonly used for Single Page Applications (SPAs) that use client-side routing.
// Implementations must be safe for concurrent use.
type FallbackStrategy interface {
// ShouldFallback determines if a fallback should be attempted for the given path.
// Returns true if the request should be handled by fallback logic.
ShouldFallback(path string) bool
// GetFallbackPath returns the path to serve instead of the originally requested path.
// This is only called if ShouldFallback returns true.
GetFallbackPath(path string) string
}
// MountConfig configures a single mount point.
// A mount point connects a URL prefix to a filesystem provider with optional policies.
type MountConfig struct {
// URLPrefix is the URL path prefix where the filesystem should be mounted.
// Must start with "/" (e.g., "/static", "/", "/assets").
// Requests starting with this prefix will be handled by this mount point.
URLPrefix string
// Provider is the filesystem provider that supplies the files.
// Required.
Provider FileSystemProvider
// CachePolicy determines how files should be cached.
// If nil, the service's default cache policy is used.
CachePolicy CachePolicy
// MIMEResolver determines Content-Type headers for files.
// If nil, the service's default MIME resolver is used.
MIMEResolver MIMETypeResolver
// FallbackStrategy handles requests for missing files.
// If nil, no fallback is performed and 404 responses are returned.
FallbackStrategy FallbackStrategy
}
// StaticFileService manages multiple mount points and serves static files.
// The service is safe for concurrent use.
type StaticFileService interface {
// Mount adds a new mount point with the given configuration.
// Returns an error if the URLPrefix is already mounted or if the config is invalid.
Mount(config MountConfig) error
// Unmount removes the mount point at the given URL prefix.
// Returns an error if no mount point exists at that prefix.
// Automatically calls Close() on the provider to release resources.
Unmount(urlPrefix string) error
// ListMounts returns a sorted list of all mounted URL prefixes.
ListMounts() []string
// Reload reinitializes all filesystem providers.
// This can be used to pick up changes in the underlying filesystems.
// Not all providers may support reloading.
Reload() error
// Close releases all resources held by the service and all mounted providers.
// After Close is called, the service should not be used.
Close() error
// Handler returns an http.Handler that serves static files from all mount points.
// The handler performs longest-prefix matching to find the appropriate mount point.
// If no mount point matches, the handler returns without writing a response,
// allowing other handlers (like API routes) to process the request.
Handler() http.Handler
}

View File

@@ -0,0 +1,235 @@
package staticweb
import (
"fmt"
"io"
"io/fs"
"net/http"
"path"
"strings"
)
// mountPoint represents a mounted filesystem at a specific URL prefix.
type mountPoint struct {
urlPrefix string
provider FileSystemProvider
cachePolicy CachePolicy
mimeResolver MIMETypeResolver
fallbackStrategy FallbackStrategy
fileServer http.Handler
}
// newMountPoint creates a new mount point with the given configuration.
func newMountPoint(config MountConfig, defaults *ServiceConfig) (*mountPoint, error) {
if config.URLPrefix == "" {
return nil, fmt.Errorf("URLPrefix cannot be empty")
}
if !strings.HasPrefix(config.URLPrefix, "/") {
return nil, fmt.Errorf("URLPrefix must start with /")
}
if config.Provider == nil {
return nil, fmt.Errorf("provider cannot be nil")
}
mp := &mountPoint{
urlPrefix: config.URLPrefix,
provider: config.Provider,
cachePolicy: config.CachePolicy,
mimeResolver: config.MIMEResolver,
fallbackStrategy: config.FallbackStrategy,
}
// Apply defaults if policies are not specified
if mp.cachePolicy == nil && defaults != nil {
mp.cachePolicy = defaultCachePolicy(defaults.DefaultCacheTime)
}
if mp.mimeResolver == nil {
mp.mimeResolver = defaultMIMEResolver()
}
// Create an http.FileServer for serving files
mp.fileServer = http.FileServer(http.FS(config.Provider))
return mp, nil
}
// ServeHTTP handles HTTP requests for files in this mount point.
func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Strip the URL prefix to get the file path
filePath := strings.TrimPrefix(r.URL.Path, m.urlPrefix)
if filePath == "" {
filePath = "/"
}
// Clean the path
filePath = path.Clean(filePath)
// Try to open the file
file, err := m.provider.Open(strings.TrimPrefix(filePath, "/"))
if err != nil {
// File doesn't exist - check if we should use fallback
if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) {
fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath)
file, err = m.provider.Open(strings.TrimPrefix(fallbackPath, "/"))
if err == nil {
// Successfully opened fallback file
defer file.Close()
m.serveFile(w, r, fallbackPath, file)
return
}
}
// No fallback or fallback failed - return 404
http.NotFound(w, r)
return
}
defer file.Close()
// Serve the file
m.serveFile(w, r, filePath, file)
}
// serveFile serves a single file with appropriate headers.
func (m *mountPoint) serveFile(w http.ResponseWriter, r *http.Request, filePath string, file fs.File) {
// Get file info
stat, err := file.Stat()
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// If it's a directory, try to serve index.html
if stat.IsDir() {
indexPath := path.Join(filePath, "index.html")
indexFile, err := m.provider.Open(strings.TrimPrefix(indexPath, "/"))
if err != nil {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
defer indexFile.Close()
indexStat, err := indexFile.Stat()
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
filePath = indexPath
stat = indexStat
file = indexFile
}
// Set Content-Type header using MIME resolver
if m.mimeResolver != nil {
if mimeType := m.mimeResolver.GetMIMEType(filePath); mimeType != "" {
w.Header().Set("Content-Type", mimeType)
}
}
// Apply cache policy
if m.cachePolicy != nil {
headers := m.cachePolicy.GetCacheHeaders(filePath)
for key, value := range headers {
w.Header().Set(key, value)
}
}
// Serve the content
if seeker, ok := file.(interface {
io.ReadSeeker
}); ok {
http.ServeContent(w, r, stat.Name(), stat.ModTime(), seeker)
} else {
// If the file doesn't support seeking, we need to read it all into memory
data, err := fs.ReadFile(m.provider, strings.TrimPrefix(filePath, "/"))
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
http.ServeContent(w, r, stat.Name(), stat.ModTime(), strings.NewReader(string(data)))
}
}
// Close releases resources held by the mount point.
func (m *mountPoint) Close() error {
if m.provider != nil {
return m.provider.Close()
}
return nil
}
// defaultCachePolicy creates a default simple cache policy.
func defaultCachePolicy(cacheTime int) CachePolicy {
// Import the policies package type - we'll need to use the concrete type
// For now, create a simple inline implementation
return &simpleCachePolicy{cacheTime: cacheTime}
}
// simpleCachePolicy is a simple inline implementation of CachePolicy
type simpleCachePolicy struct {
cacheTime int
}
func (p *simpleCachePolicy) GetCacheTime(path string) int {
return p.cacheTime
}
func (p *simpleCachePolicy) GetCacheHeaders(path string) map[string]string {
if p.cacheTime <= 0 {
return map[string]string{
"Cache-Control": "no-cache, no-store, must-revalidate",
"Pragma": "no-cache",
"Expires": "0",
}
}
return map[string]string{
"Cache-Control": fmt.Sprintf("public, max-age=%d", p.cacheTime),
}
}
// defaultMIMEResolver creates a default MIME resolver.
func defaultMIMEResolver() MIMETypeResolver {
// Import the policies package type - we'll need to use the concrete type
// For now, create a simple inline implementation
return &simpleMIMEResolver{
types: map[string]string{
".js": "application/javascript",
".mjs": "application/javascript",
".cjs": "application/javascript",
".css": "text/css",
".html": "text/html",
".htm": "text/html",
".json": "application/json",
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".svg": "image/svg+xml",
".ico": "image/x-icon",
".txt": "text/plain",
},
}
}
// simpleMIMEResolver is a simple inline implementation of MIMETypeResolver
type simpleMIMEResolver struct {
types map[string]string
}
func (r *simpleMIMEResolver) GetMIMEType(filePath string) string {
ext := strings.ToLower(path.Ext(filePath))
if mimeType, ok := r.types[ext]; ok {
return mimeType
}
return ""
}
func (r *simpleMIMEResolver) RegisterMIMEType(extension, mimeType string) {
if !strings.HasPrefix(extension, ".") {
extension = "." + extension
}
r.types[strings.ToLower(extension)] = mimeType
}

View File

@@ -0,0 +1,592 @@
# StaticWeb Package Interface-Driven Refactoring Plan
## Overview
Refactor `pkg/server/staticweb` to be interface-driven, router-agnostic, and maintainable. This is a breaking change that replaces the existing API with a cleaner design.
## User Requirements
- ✅ Work with any server (not just Gorilla mux)
- ✅ Serve static files from zip files or directories
- ✅ Support embedded, local filesystems
- ✅ Interface-driven and maintainable architecture
- ✅ Struct-based configuration
- ✅ Breaking changes acceptable
- 🔮 Remote HTTP/HTTPS and S3 support (future feature)
## Design Principles
1. **Interface-first**: Define clear interfaces for all abstractions
2. **Composition over inheritance**: Combine small, focused components
3. **Router-agnostic**: Return standard `http.Handler` for universal compatibility
4. **Configurable policies**: Extract hardcoded behavior into pluggable strategies
5. **Resource safety**: Proper lifecycle management with `Close()` methods
6. **Testability**: Mock-friendly interfaces with clear boundaries
7. **Extensibility**: Easy to add new providers (HTTP, S3, etc.) in future
## Architecture Overview
### Core Interfaces (pkg/server/staticweb/interfaces.go)
```go
// FileSystemProvider abstracts file sources (local, zip, embedded, future: http, s3)
type FileSystemProvider interface {
Open(name string) (fs.File, error)
Close() error
Type() string
}
// CachePolicy defines caching behavior
type CachePolicy interface {
GetCacheTime(path string) int
GetCacheHeaders(path string) map[string]string
}
// MIMETypeResolver determines content types
type MIMETypeResolver interface {
GetMIMEType(path string) string
RegisterMIMEType(extension, mimeType string)
}
// FallbackStrategy handles missing files
type FallbackStrategy interface {
ShouldFallback(path string) bool
GetFallbackPath(path string) string
}
// StaticFileService manages mount points
type StaticFileService interface {
Mount(config MountConfig) error
Unmount(urlPrefix string) error
ListMounts() []string
Reload() error
Close() error
Handler() http.Handler // Router-agnostic integration
}
```
### Configuration (pkg/server/staticweb/config.go)
```go
// MountConfig configures a single mount point
type MountConfig struct {
URLPrefix string
Provider FileSystemProvider
CachePolicy CachePolicy // Optional, uses default if nil
MIMEResolver MIMETypeResolver // Optional, uses default if nil
FallbackStrategy FallbackStrategy // Optional, no fallback if nil
}
// ServiceConfig configures the service
type ServiceConfig struct {
DefaultCacheTime int // Default: 48 hours
DefaultMIMETypes map[string]string // Additional MIME types
}
```
## Implementation Plan
### Step 1: Create Core Interfaces
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/interfaces.go` (NEW)
Define all interfaces listed above. This establishes the contract for all components.
### Step 2: Implement Default Policies
**Directory**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/policies/` (NEW)
**File**: `policies/cache.go`
- `SimpleCachePolicy` - Single TTL for all files
- `ExtensionBasedCachePolicy` - Different TTL per file extension
- `NoCachePolicy` - Disables caching
**File**: `policies/mime.go`
- `DefaultMIMEResolver` - Standard web MIME types + stdlib
- `ConfigurableMIMEResolver` - User-defined mappings
- Migrate hardcoded MIME types from `InitMimeTypes()` (lines 60-76)
**File**: `policies/fallback.go`
- `NoFallback` - Returns 404 for missing files
- `HTMLFallbackStrategy` - SPA routing (serves index.html)
- `ExtensionBasedFallback` - Current behavior (checks extensions)
- Migrate logic from `StaticHTMLFallbackHandler()` (lines 241-285)
### Step 3: Implement FileSystem Providers
**Directory**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/providers/` (NEW)
**File**: `providers/local.go`
```go
type LocalFSProvider struct {
path string
fs fs.FS
}
func NewLocalFSProvider(path string) (*LocalFSProvider, error)
func (l *LocalFSProvider) Open(name string) (fs.File, error)
func (l *LocalFSProvider) Close() error
func (l *LocalFSProvider) Type() string
```
**File**: `providers/zip.go`
```go
type ZipFSProvider struct {
zipPath string
zipReader *zip.ReadCloser
zipFS *zipfs.ZipFS
mu sync.RWMutex
}
func NewZipFSProvider(zipPath string) (*ZipFSProvider, error)
func (z *ZipFSProvider) Open(name string) (fs.File, error)
func (z *ZipFSProvider) Close() error
func (z *ZipFSProvider) Type() string
```
- Integrates with existing `pkg/server/zipfs/zipfs.go`
- Manages zip file lifecycle properly
**File**: `providers/embed.go`
```go
type EmbedFSProvider struct {
embedFS *embed.FS
zipFile string // Optional: path within embedded FS to zip file
zipReader *zip.ReadCloser
fs fs.FS
mu sync.RWMutex
}
func NewEmbedFSProvider(embedFS *embed.FS, zipFile string) (*EmbedFSProvider, error)
func (e *EmbedFSProvider) Open(name string) (fs.File, error)
func (e *EmbedFSProvider) Close() error
func (e *EmbedFSProvider) Type() string
```
### Step 4: Implement Mount Point
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/mount.go` (NEW)
```go
type mountPoint struct {
urlPrefix string
provider FileSystemProvider
cachePolicy CachePolicy
mimeResolver MIMETypeResolver
fallbackStrategy FallbackStrategy
}
func newMountPoint(config MountConfig, defaults *ServiceConfig) (*mountPoint, error)
func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request)
func (m *mountPoint) Close() error
```
**Key behaviors**:
- Strips URL prefix before passing to provider
- Applies cache headers via `CachePolicy`
- Sets Content-Type via `MIMETypeResolver`
- Falls back via `FallbackStrategy` if file not found
- Integrates with `http.FileServer()` for actual serving
### Step 5: Implement Service
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/service.go` (NEW)
```go
type service struct {
mounts map[string]*mountPoint // urlPrefix -> mount
config *ServiceConfig
mu sync.RWMutex
}
func NewService(config *ServiceConfig) StaticFileService
func (s *service) Mount(config MountConfig) error
func (s *service) Unmount(urlPrefix string) error
func (s *service) ListMounts() []string
func (s *service) Reload() error
func (s *service) Close() error
func (s *service) Handler() http.Handler
```
**Handler Implementation**:
- Performs longest-prefix matching to find mount point
- Delegates to mount point's `ServeHTTP()`
- Returns silently if no match (allows API routes to handle)
- Thread-safe with `sync.RWMutex`
### Step 6: Create Configuration Helpers
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/config.go` (NEW)
```go
// Default configurations
func DefaultServiceConfig() *ServiceConfig
func DefaultCachePolicy() CachePolicy
func DefaultMIMEResolver() MIMETypeResolver
// Helper constructors
func LocalProvider(path string) FileSystemProvider
func ZipProvider(zipPath string) FileSystemProvider
func EmbedProvider(embedFS *embed.FS, zipFile string) FileSystemProvider
// Policy constructors
func SimpleCache(seconds int) CachePolicy
func ExtensionCache(rules map[string]int) CachePolicy
func HTMLFallback(indexFile string) FallbackStrategy
func ExtensionFallback(staticExtensions []string) FallbackStrategy
```
### Step 7: Update/Remove Existing Files
**REMOVE**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/staticweb.go`
- All functionality migrated to new interface-based design
- No backward compatibility needed per user request
**KEEP**: `/home/hein/hein/dev/ResolveSpec/pkg/server/zipfs/zipfs.go`
- Still used by `ZipFSProvider`
- Already implements `fs.FS` interface correctly
### Step 8: Create Examples and Tests
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/example_test.go` (NEW)
```go
func ExampleService_basic() { /* Serve local directory */ }
func ExampleService_spa() { /* SPA with fallback */ }
func ExampleService_multiple() { /* Multiple mount points */ }
func ExampleService_zip() { /* Serve from zip file */ }
func ExampleService_embedded() { /* Serve from embedded zip */ }
```
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/service_test.go` (NEW)
- Test mount/unmount operations
- Test longest-prefix matching
- Test concurrent access
- Test resource cleanup
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/providers/providers_test.go` (NEW)
- Test each provider implementation
- Test resource cleanup
- Test error handling
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/testing/mocks.go` (NEW)
- `MockFileSystemProvider` - In-memory file storage
- `MockCachePolicy` - Configurable cache behavior
- `MockMIMEResolver` - Custom MIME mappings
- Test helpers for common scenarios
### Step 9: Create Documentation
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/README.md` (NEW)
Document:
- Quick start examples
- Interface overview
- Provider implementations
- Policy customization
- Router integration patterns
- Migration guide from old API
- Future features roadmap
## Key Improvements Over Current Implementation
### 1. Router-Agnostic Design
**Before**: Coupled to Gorilla mux via `RegisterRoutes(*mux.Router)`
**After**: Returns `http.Handler`, works with any router
```go
// Works with Gorilla Mux
muxRouter.PathPrefix("/").Handler(service.Handler())
// Works with standard http.ServeMux
http.Handle("/", service.Handler())
// Works with any http.Handler-compatible router
```
### 2. Configurable Behaviors
**Before**: Hardcoded MIME types, cache times, file extensions
**After**: Pluggable policies
```go
// Custom cache per file type
cachePolicy := ExtensionCache(map[string]int{
".html": 3600, // 1 hour
".js": 86400, // 1 day
".css": 86400, // 1 day
".png": 604800, // 1 week
})
// Custom fallback logic
fallback := HTMLFallback("index.html")
service.Mount(MountConfig{
URLPrefix: "/",
Provider: LocalProvider("./dist"),
CachePolicy: cachePolicy,
FallbackStrategy: fallback,
})
```
### 3. Better Resource Management
**Before**: Manual zip file cleanup, easy to leak resources
**After**: Proper lifecycle with `Close()` on all components
```go
defer service.Close() // Cleans up all providers
```
### 4. Testability
**Before**: Hard to test, coupled to filesystem
**After**: Mock providers for testing
```go
mockProvider := testing.NewInMemoryProvider(map[string][]byte{
"index.html": []byte("<html>test</html>"),
})
service.Mount(MountConfig{
URLPrefix: "/",
Provider: mockProvider,
})
```
### 5. Extensibility
**Before**: Need to modify code to support new file sources
**After**: Implement `FileSystemProvider` interface
```go
// Future: Add HTTP provider without changing core code
type HTTPFSProvider struct { /* ... */ }
func (h *HTTPFSProvider) Open(name string) (fs.File, error) { /* ... */ }
func (h *HTTPFSProvider) Close() error { /* ... */ }
func (h *HTTPFSProvider) Type() string { return "http" }
```
## Future Features (To Implement Later)
### Remote HTTP/HTTPS Provider
**File**: `providers/http.go` (FUTURE)
Serve static files from remote HTTP servers with local caching:
```go
type HTTPFSProvider struct {
baseURL string
httpClient *http.Client
cache LocalCache // Optional disk/memory cache
cacheTTL time.Duration
mu sync.RWMutex
}
// Example usage
service.Mount(MountConfig{
URLPrefix: "/cdn",
Provider: HTTPProvider("https://cdn.example.com/assets"),
})
```
**Features**:
- Fetch files from remote URLs on-demand
- Local cache to reduce remote requests
- Configurable TTL and cache eviction
- HEAD request support for metadata
- Retry logic and timeout handling
- Support for authentication headers
### S3-Compatible Provider
**File**: `providers/s3.go` (FUTURE)
Serve static files from S3, MinIO, or S3-compatible storage:
```go
type S3FSProvider struct {
bucket string
prefix string
region string
client *s3.Client
cache LocalCache
mu sync.RWMutex
}
// Example usage
service.Mount(MountConfig{
URLPrefix: "/media",
Provider: S3Provider("my-bucket", "static/", "us-east-1"),
})
```
**Features**:
- List and fetch objects from S3 buckets
- Support for AWS S3, MinIO, DigitalOcean Spaces, etc.
- IAM role or credential-based authentication
- Optional local caching layer
- Efficient metadata retrieval
- Support for presigned URLs
### Other Future Providers
- **GitProvider**: Serve files from Git repositories
- **MemoryProvider**: In-memory filesystem for testing/temporary files
- **ProxyProvider**: Proxy to another static file server
- **CompositeProvider**: Fallback chain across multiple providers
### Advanced Cache Policies (FUTURE)
- **TimedCachePolicy**: Different cache times by time of day
- **UserAgentCachePolicy**: Cache based on client type
- **ConditionalCachePolicy**: Smart cache based on file size/type
- **DistributedCachePolicy**: Shared cache across service instances
### Advanced Fallback Strategies (FUTURE)
- **RegexFallbackStrategy**: Pattern-based routing
- **I18nFallbackStrategy**: Language-based file resolution
- **VersionedFallbackStrategy**: A/B testing support
## Critical Files Summary
### Files to CREATE (in order):
1. `pkg/server/staticweb/interfaces.go` - Core contracts
2. `pkg/server/staticweb/config.go` - Configuration structs and helpers
3. `pkg/server/staticweb/policies/cache.go` - Cache policy implementations
4. `pkg/server/staticweb/policies/mime.go` - MIME resolver implementations
5. `pkg/server/staticweb/policies/fallback.go` - Fallback strategy implementations
6. `pkg/server/staticweb/providers/local.go` - Local directory provider
7. `pkg/server/staticweb/providers/zip.go` - Zip file provider
8. `pkg/server/staticweb/providers/embed.go` - Embedded filesystem provider
9. `pkg/server/staticweb/mount.go` - Mount point implementation
10. `pkg/server/staticweb/service.go` - Main service implementation
11. `pkg/server/staticweb/testing/mocks.go` - Test helpers
12. `pkg/server/staticweb/service_test.go` - Service tests
13. `pkg/server/staticweb/providers/providers_test.go` - Provider tests
14. `pkg/server/staticweb/example_test.go` - Example code
15. `pkg/server/staticweb/README.md` - Documentation
### Files to REMOVE:
1. `pkg/server/staticweb/staticweb.go` - Replaced by new design
### Files to KEEP:
1. `pkg/server/zipfs/zipfs.go` - Used by ZipFSProvider
### Files for FUTURE (not in this refactoring):
1. `pkg/server/staticweb/providers/http.go` - HTTP/HTTPS remote provider
2. `pkg/server/staticweb/providers/s3.go` - S3-compatible storage provider
3. `pkg/server/staticweb/providers/composite.go` - Fallback chain provider
## Example Usage After Refactoring
### Basic Static Site
```go
service := staticweb.NewService(nil) // Use defaults
err := service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: staticweb.LocalProvider("./public"),
})
muxRouter.PathPrefix("/").Handler(service.Handler())
```
### SPA with API Routes
```go
service := staticweb.NewService(nil)
service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: staticweb.LocalProvider("./dist"),
FallbackStrategy: staticweb.HTMLFallback("index.html"),
})
// API routes take precedence (registered first)
muxRouter.HandleFunc("/api/users", usersHandler)
muxRouter.HandleFunc("/api/posts", postsHandler)
// Static files handle all other routes
muxRouter.PathPrefix("/").Handler(service.Handler())
```
### Multiple Mount Points with Different Policies
```go
service := staticweb.NewService(&staticweb.ServiceConfig{
DefaultCacheTime: 3600,
})
// Assets with long cache
service.Mount(staticweb.MountConfig{
URLPrefix: "/assets",
Provider: staticweb.LocalProvider("./assets"),
CachePolicy: staticweb.SimpleCache(604800), // 1 week
})
// HTML with short cache
service.Mount(staticweb.MountConfig{
URLPrefix: "/",
Provider: staticweb.LocalProvider("./public"),
CachePolicy: staticweb.SimpleCache(300), // 5 minutes
})
router.PathPrefix("/").Handler(service.Handler())
```
### Embedded Files from Zip
```go
//go:embed assets.zip
var assetsZip embed.FS
service := staticweb.NewService(nil)
service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: staticweb.EmbedProvider(&assetsZip, "assets.zip"),
})
router.PathPrefix("/").Handler(service.Handler())
```
### Future: CDN Fallback (when HTTP provider is implemented)
```go
// Primary CDN with local fallback
service.Mount(staticweb.MountConfig{
URLPrefix: "/static",
Provider: staticweb.CompositeProvider(
staticweb.HTTPProvider("https://cdn.example.com/assets"),
staticweb.LocalProvider("./public/assets"),
),
})
```
## Testing Strategy
### Unit Tests
- Each provider implementation independently
- Each policy implementation independently
- Mount point request handling
- Service mount/unmount operations
### Integration Tests
- Full request flow through service
- Multiple mount points
- Longest-prefix matching
- Resource cleanup
### Example Tests
- Executable examples in `example_test.go`
- Demonstrate common usage patterns
## Migration Impact
### Breaking Changes
- Complete API redesign (acceptable per user)
- Package not currently used in codebase (no migration needed)
- New consumers will use new API from start
### Future Extensibility
The interface-driven design allows future additions without breaking changes:
- Add HTTPFSProvider by implementing `FileSystemProvider`
- Add S3FSProvider by implementing `FileSystemProvider`
- Add custom cache policies by implementing `CachePolicy`
- Add custom fallback strategies by implementing `FallbackStrategy`
## Implementation Order
1. Interfaces (foundation)
2. Configuration (API surface)
3. Policies (pluggable behavior)
4. Providers (filesystem abstraction)
5. Mount Point (request handling)
6. Service (orchestration)
7. Tests (validation)
8. Documentation (usage)
9. Remove old code (cleanup)
This order ensures each layer builds on tested, working components.
---

View File

@@ -0,0 +1,103 @@
package policies
import (
"fmt"
"path"
"strings"
)
// SimpleCachePolicy implements a basic cache policy with a single TTL for all files.
type SimpleCachePolicy struct {
cacheTime int // Cache duration in seconds
}
// NewSimpleCachePolicy creates a new SimpleCachePolicy with the given cache time in seconds.
func NewSimpleCachePolicy(cacheTimeSeconds int) *SimpleCachePolicy {
return &SimpleCachePolicy{
cacheTime: cacheTimeSeconds,
}
}
// GetCacheTime returns the cache duration for any file.
func (p *SimpleCachePolicy) GetCacheTime(filePath string) int {
return p.cacheTime
}
// GetCacheHeaders returns the Cache-Control header for the given file.
func (p *SimpleCachePolicy) GetCacheHeaders(filePath string) map[string]string {
if p.cacheTime <= 0 {
return map[string]string{
"Cache-Control": "no-cache, no-store, must-revalidate",
"Pragma": "no-cache",
"Expires": "0",
}
}
return map[string]string{
"Cache-Control": fmt.Sprintf("public, max-age=%d", p.cacheTime),
}
}
// ExtensionBasedCachePolicy implements a cache policy that varies by file extension.
type ExtensionBasedCachePolicy struct {
rules map[string]int // Extension -> cache time in seconds
defaultTime int // Default cache time for unmatched extensions
}
// NewExtensionBasedCachePolicy creates a new ExtensionBasedCachePolicy.
// rules maps file extensions (with leading dot, e.g., ".js") to cache times in seconds.
// defaultTime is used for files that don't match any rule.
func NewExtensionBasedCachePolicy(rules map[string]int, defaultTime int) *ExtensionBasedCachePolicy {
return &ExtensionBasedCachePolicy{
rules: rules,
defaultTime: defaultTime,
}
}
// GetCacheTime returns the cache duration based on the file extension.
func (p *ExtensionBasedCachePolicy) GetCacheTime(filePath string) int {
ext := strings.ToLower(path.Ext(filePath))
if cacheTime, ok := p.rules[ext]; ok {
return cacheTime
}
return p.defaultTime
}
// GetCacheHeaders returns cache headers based on the file extension.
func (p *ExtensionBasedCachePolicy) GetCacheHeaders(filePath string) map[string]string {
cacheTime := p.GetCacheTime(filePath)
if cacheTime <= 0 {
return map[string]string{
"Cache-Control": "no-cache, no-store, must-revalidate",
"Pragma": "no-cache",
"Expires": "0",
}
}
return map[string]string{
"Cache-Control": fmt.Sprintf("public, max-age=%d", cacheTime),
}
}
// NoCachePolicy implements a cache policy that disables all caching.
type NoCachePolicy struct{}
// NewNoCachePolicy creates a new NoCachePolicy.
func NewNoCachePolicy() *NoCachePolicy {
return &NoCachePolicy{}
}
// GetCacheTime always returns 0 (no caching).
func (p *NoCachePolicy) GetCacheTime(filePath string) int {
return 0
}
// GetCacheHeaders returns headers that disable caching.
func (p *NoCachePolicy) GetCacheHeaders(filePath string) map[string]string {
return map[string]string{
"Cache-Control": "no-cache, no-store, must-revalidate",
"Pragma": "no-cache",
"Expires": "0",
}
}

View File

@@ -0,0 +1,159 @@
package policies
import (
"path"
"strings"
)
// NoFallback implements a fallback strategy that never falls back.
// All requests for missing files will result in 404 responses.
type NoFallback struct{}
// NewNoFallback creates a new NoFallback strategy.
func NewNoFallback() *NoFallback {
return &NoFallback{}
}
// ShouldFallback always returns false.
func (f *NoFallback) ShouldFallback(filePath string) bool {
return false
}
// GetFallbackPath returns an empty string (never called since ShouldFallback returns false).
func (f *NoFallback) GetFallbackPath(filePath string) string {
return ""
}
// HTMLFallbackStrategy implements a fallback strategy for Single Page Applications (SPAs).
// It serves a specified HTML file (typically index.html) for non-file requests.
type HTMLFallbackStrategy struct {
indexFile string
}
// NewHTMLFallbackStrategy creates a new HTMLFallbackStrategy.
// indexFile is the path to the HTML file to serve (e.g., "index.html", "/index.html").
func NewHTMLFallbackStrategy(indexFile string) *HTMLFallbackStrategy {
return &HTMLFallbackStrategy{
indexFile: indexFile,
}
}
// ShouldFallback returns true for requests that don't look like static assets.
func (f *HTMLFallbackStrategy) ShouldFallback(filePath string) bool {
// Always fall back unless it looks like a static asset
return !f.isStaticAsset(filePath)
}
// GetFallbackPath returns the index file path.
func (f *HTMLFallbackStrategy) GetFallbackPath(filePath string) string {
return f.indexFile
}
// isStaticAsset checks if the path looks like a static asset (has a file extension).
func (f *HTMLFallbackStrategy) isStaticAsset(filePath string) bool {
return path.Ext(filePath) != ""
}
// ExtensionBasedFallback implements a fallback strategy that skips fallback for known static file extensions.
// This is the behavior from the original StaticHTMLFallbackHandler.
type ExtensionBasedFallback struct {
staticExtensions map[string]bool
fallbackPath string
}
// NewExtensionBasedFallback creates a new ExtensionBasedFallback strategy.
// staticExtensions is a list of file extensions (with leading dot) that should NOT use fallback.
// fallbackPath is the file to serve when fallback is triggered.
func NewExtensionBasedFallback(staticExtensions []string, fallbackPath string) *ExtensionBasedFallback {
extMap := make(map[string]bool)
for _, ext := range staticExtensions {
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
extMap[strings.ToLower(ext)] = true
}
return &ExtensionBasedFallback{
staticExtensions: extMap,
fallbackPath: fallbackPath,
}
}
// NewDefaultExtensionBasedFallback creates an ExtensionBasedFallback with common web asset extensions.
// This matches the behavior of the original StaticHTMLFallbackHandler.
func NewDefaultExtensionBasedFallback(fallbackPath string) *ExtensionBasedFallback {
return NewExtensionBasedFallback([]string{
".js", ".css", ".png", ".svg", ".ico", ".json",
".jpg", ".jpeg", ".gif", ".woff", ".woff2", ".ttf", ".eot",
}, fallbackPath)
}
// ShouldFallback returns true if the file path doesn't have a static asset extension.
func (f *ExtensionBasedFallback) ShouldFallback(filePath string) bool {
ext := strings.ToLower(path.Ext(filePath))
// If it's a known static extension, don't fallback
if f.staticExtensions[ext] {
return false
}
// Otherwise, try fallback
return true
}
// GetFallbackPath returns the configured fallback path.
func (f *ExtensionBasedFallback) GetFallbackPath(filePath string) string {
return f.fallbackPath
}
// HTMLExtensionFallback implements a fallback strategy that appends .html to paths.
// This tries to serve {path}.html for missing files.
type HTMLExtensionFallback struct {
staticExtensions map[string]bool
}
// NewHTMLExtensionFallback creates a new HTMLExtensionFallback strategy.
func NewHTMLExtensionFallback(staticExtensions []string) *HTMLExtensionFallback {
extMap := make(map[string]bool)
for _, ext := range staticExtensions {
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
extMap[strings.ToLower(ext)] = true
}
return &HTMLExtensionFallback{
staticExtensions: extMap,
}
}
// ShouldFallback returns true if the path doesn't have a static extension or .html.
func (f *HTMLExtensionFallback) ShouldFallback(filePath string) bool {
ext := strings.ToLower(path.Ext(filePath))
// If it's a known static extension, don't fallback
if f.staticExtensions[ext] {
return false
}
// If it already has .html, don't fallback
if ext == ".html" || ext == ".htm" {
return false
}
return true
}
// GetFallbackPath returns the path with .html appended.
func (f *HTMLExtensionFallback) GetFallbackPath(filePath string) string {
cleanPath := path.Clean(filePath)
if !strings.HasSuffix(filePath, "/") {
cleanPath = strings.TrimRight(cleanPath, "/")
}
if !strings.HasSuffix(strings.ToLower(cleanPath), ".html") {
return cleanPath + ".html"
}
return cleanPath
}

View File

@@ -0,0 +1,245 @@
package policies
import (
"mime"
"path"
"strings"
"sync"
)
// DefaultMIMEResolver implements a MIME type resolver using Go's standard mime package
// and a set of common web file type mappings.
type DefaultMIMEResolver struct {
customTypes map[string]string
mu sync.RWMutex
}
// NewDefaultMIMEResolver creates a new DefaultMIMEResolver with common web MIME types.
func NewDefaultMIMEResolver() *DefaultMIMEResolver {
resolver := &DefaultMIMEResolver{
customTypes: make(map[string]string),
}
// JavaScript & TypeScript
resolver.RegisterMIMEType(".js", "application/javascript")
resolver.RegisterMIMEType(".mjs", "application/javascript")
resolver.RegisterMIMEType(".cjs", "application/javascript")
resolver.RegisterMIMEType(".ts", "text/typescript")
resolver.RegisterMIMEType(".tsx", "text/tsx")
resolver.RegisterMIMEType(".jsx", "text/jsx")
// CSS & Styling
resolver.RegisterMIMEType(".css", "text/css")
resolver.RegisterMIMEType(".scss", "text/x-scss")
resolver.RegisterMIMEType(".sass", "text/x-sass")
resolver.RegisterMIMEType(".less", "text/x-less")
// HTML & XML
resolver.RegisterMIMEType(".html", "text/html")
resolver.RegisterMIMEType(".htm", "text/html")
resolver.RegisterMIMEType(".xml", "application/xml")
resolver.RegisterMIMEType(".xhtml", "application/xhtml+xml")
// Images - Raster
resolver.RegisterMIMEType(".png", "image/png")
resolver.RegisterMIMEType(".jpg", "image/jpeg")
resolver.RegisterMIMEType(".jpeg", "image/jpeg")
resolver.RegisterMIMEType(".gif", "image/gif")
resolver.RegisterMIMEType(".webp", "image/webp")
resolver.RegisterMIMEType(".avif", "image/avif")
resolver.RegisterMIMEType(".bmp", "image/bmp")
resolver.RegisterMIMEType(".tiff", "image/tiff")
resolver.RegisterMIMEType(".tif", "image/tiff")
resolver.RegisterMIMEType(".ico", "image/x-icon")
resolver.RegisterMIMEType(".cur", "image/x-icon")
// Images - Vector
resolver.RegisterMIMEType(".svg", "image/svg+xml")
resolver.RegisterMIMEType(".svgz", "image/svg+xml")
// Fonts
resolver.RegisterMIMEType(".woff", "font/woff")
resolver.RegisterMIMEType(".woff2", "font/woff2")
resolver.RegisterMIMEType(".ttf", "font/ttf")
resolver.RegisterMIMEType(".otf", "font/otf")
resolver.RegisterMIMEType(".eot", "application/vnd.ms-fontobject")
// Audio
resolver.RegisterMIMEType(".mp3", "audio/mpeg")
resolver.RegisterMIMEType(".wav", "audio/wav")
resolver.RegisterMIMEType(".ogg", "audio/ogg")
resolver.RegisterMIMEType(".oga", "audio/ogg")
resolver.RegisterMIMEType(".m4a", "audio/mp4")
resolver.RegisterMIMEType(".aac", "audio/aac")
resolver.RegisterMIMEType(".flac", "audio/flac")
resolver.RegisterMIMEType(".opus", "audio/opus")
resolver.RegisterMIMEType(".weba", "audio/webm")
// Video
resolver.RegisterMIMEType(".mp4", "video/mp4")
resolver.RegisterMIMEType(".webm", "video/webm")
resolver.RegisterMIMEType(".ogv", "video/ogg")
resolver.RegisterMIMEType(".avi", "video/x-msvideo")
resolver.RegisterMIMEType(".mpeg", "video/mpeg")
resolver.RegisterMIMEType(".mpg", "video/mpeg")
resolver.RegisterMIMEType(".mov", "video/quicktime")
resolver.RegisterMIMEType(".wmv", "video/x-ms-wmv")
resolver.RegisterMIMEType(".flv", "video/x-flv")
resolver.RegisterMIMEType(".mkv", "video/x-matroska")
resolver.RegisterMIMEType(".m4v", "video/mp4")
// Data & Configuration
resolver.RegisterMIMEType(".json", "application/json")
resolver.RegisterMIMEType(".xml", "application/xml")
resolver.RegisterMIMEType(".yml", "application/yaml")
resolver.RegisterMIMEType(".yaml", "application/yaml")
resolver.RegisterMIMEType(".toml", "application/toml")
resolver.RegisterMIMEType(".ini", "text/plain")
resolver.RegisterMIMEType(".conf", "text/plain")
resolver.RegisterMIMEType(".config", "text/plain")
// Documents
resolver.RegisterMIMEType(".pdf", "application/pdf")
resolver.RegisterMIMEType(".doc", "application/msword")
resolver.RegisterMIMEType(".docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document")
resolver.RegisterMIMEType(".xls", "application/vnd.ms-excel")
resolver.RegisterMIMEType(".xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
resolver.RegisterMIMEType(".ppt", "application/vnd.ms-powerpoint")
resolver.RegisterMIMEType(".pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation")
resolver.RegisterMIMEType(".odt", "application/vnd.oasis.opendocument.text")
resolver.RegisterMIMEType(".ods", "application/vnd.oasis.opendocument.spreadsheet")
resolver.RegisterMIMEType(".odp", "application/vnd.oasis.opendocument.presentation")
// Archives
resolver.RegisterMIMEType(".zip", "application/zip")
resolver.RegisterMIMEType(".tar", "application/x-tar")
resolver.RegisterMIMEType(".gz", "application/gzip")
resolver.RegisterMIMEType(".bz2", "application/x-bzip2")
resolver.RegisterMIMEType(".7z", "application/x-7z-compressed")
resolver.RegisterMIMEType(".rar", "application/vnd.rar")
// Text files
resolver.RegisterMIMEType(".txt", "text/plain")
resolver.RegisterMIMEType(".md", "text/markdown")
resolver.RegisterMIMEType(".markdown", "text/markdown")
resolver.RegisterMIMEType(".csv", "text/csv")
resolver.RegisterMIMEType(".log", "text/plain")
// Source code (for syntax highlighting in browsers)
resolver.RegisterMIMEType(".c", "text/x-c")
resolver.RegisterMIMEType(".cpp", "text/x-c++")
resolver.RegisterMIMEType(".h", "text/x-c")
resolver.RegisterMIMEType(".hpp", "text/x-c++")
resolver.RegisterMIMEType(".go", "text/x-go")
resolver.RegisterMIMEType(".py", "text/x-python")
resolver.RegisterMIMEType(".java", "text/x-java")
resolver.RegisterMIMEType(".rs", "text/x-rust")
resolver.RegisterMIMEType(".rb", "text/x-ruby")
resolver.RegisterMIMEType(".php", "text/x-php")
resolver.RegisterMIMEType(".sh", "text/x-shellscript")
resolver.RegisterMIMEType(".bash", "text/x-shellscript")
resolver.RegisterMIMEType(".sql", "text/x-sql")
resolver.RegisterMIMEType(".template.sql", "text/plain")
resolver.RegisterMIMEType(".upg", "text/plain")
// Web Assembly
resolver.RegisterMIMEType(".wasm", "application/wasm")
// Manifest & Service Worker
resolver.RegisterMIMEType(".webmanifest", "application/manifest+json")
resolver.RegisterMIMEType(".manifest", "text/cache-manifest")
// 3D Models
resolver.RegisterMIMEType(".gltf", "model/gltf+json")
resolver.RegisterMIMEType(".glb", "model/gltf-binary")
resolver.RegisterMIMEType(".obj", "model/obj")
resolver.RegisterMIMEType(".stl", "model/stl")
// Other common web assets
resolver.RegisterMIMEType(".map", "application/json") // Source maps
resolver.RegisterMIMEType(".swf", "application/x-shockwave-flash")
resolver.RegisterMIMEType(".apk", "application/vnd.android.package-archive")
resolver.RegisterMIMEType(".dmg", "application/x-apple-diskimage")
resolver.RegisterMIMEType(".exe", "application/x-msdownload")
resolver.RegisterMIMEType(".iso", "application/x-iso9660-image")
return resolver
}
// GetMIMEType returns the MIME type for the given file path.
// It first checks custom registered types, then falls back to Go's mime.TypeByExtension.
func (r *DefaultMIMEResolver) GetMIMEType(filePath string) string {
ext := strings.ToLower(path.Ext(filePath))
// Check custom types first
r.mu.RLock()
if mimeType, ok := r.customTypes[ext]; ok {
r.mu.RUnlock()
return mimeType
}
r.mu.RUnlock()
// Fall back to standard library
if mimeType := mime.TypeByExtension(ext); mimeType != "" {
return mimeType
}
// Return empty string if unknown
return ""
}
// RegisterMIMEType registers a custom MIME type for the given file extension.
func (r *DefaultMIMEResolver) RegisterMIMEType(extension, mimeType string) {
if !strings.HasPrefix(extension, ".") {
extension = "." + extension
}
r.mu.Lock()
r.customTypes[strings.ToLower(extension)] = mimeType
r.mu.Unlock()
}
// ConfigurableMIMEResolver implements a MIME type resolver with user-defined mappings only.
// It does not use any default mappings.
type ConfigurableMIMEResolver struct {
types map[string]string
mu sync.RWMutex
}
// NewConfigurableMIMEResolver creates a new ConfigurableMIMEResolver with the given mappings.
func NewConfigurableMIMEResolver(types map[string]string) *ConfigurableMIMEResolver {
resolver := &ConfigurableMIMEResolver{
types: make(map[string]string),
}
for ext, mimeType := range types {
resolver.RegisterMIMEType(ext, mimeType)
}
return resolver
}
// GetMIMEType returns the MIME type for the given file path.
func (r *ConfigurableMIMEResolver) GetMIMEType(filePath string) string {
ext := strings.ToLower(path.Ext(filePath))
r.mu.RLock()
defer r.mu.RUnlock()
if mimeType, ok := r.types[ext]; ok {
return mimeType
}
return ""
}
// RegisterMIMEType registers a MIME type for the given file extension.
func (r *ConfigurableMIMEResolver) RegisterMIMEType(extension, mimeType string) {
if !strings.HasPrefix(extension, ".") {
extension = "." + extension
}
r.mu.Lock()
r.types[strings.ToLower(extension)] = mimeType
r.mu.Unlock()
}

View File

@@ -0,0 +1,119 @@
package providers
import (
"archive/zip"
"bytes"
"embed"
"fmt"
"io"
"io/fs"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/server/zipfs"
)
// EmbedFSProvider serves files from an embedded filesystem.
// It supports both direct embedded directories and embedded zip files.
type EmbedFSProvider struct {
embedFS *embed.FS
zipFile string // Optional: path within embedded FS to zip file
zipReader *zip.Reader
fs fs.FS
mu sync.RWMutex
}
// NewEmbedFSProvider creates a new EmbedFSProvider.
// If zipFile is empty, the embedded FS is used directly.
// If zipFile is specified, it's treated as a path to a zip file within the embedded FS.
func NewEmbedFSProvider(embedFS fs.FS, zipFile string) (*EmbedFSProvider, error) {
if embedFS == nil {
return nil, fmt.Errorf("embedded filesystem cannot be nil")
}
// Try to cast to *embed.FS for tracking purposes
var embedFSPtr *embed.FS
if efs, ok := embedFS.(*embed.FS); ok {
embedFSPtr = efs
}
provider := &EmbedFSProvider{
embedFS: embedFSPtr,
zipFile: zipFile,
}
// If zipFile is specified, open it as a zip archive
if zipFile != "" {
// Read the zip file from the embedded FS
// We need to check if the FS supports ReadFile
var data []byte
var err error
if readFileFS, ok := embedFS.(interface{ ReadFile(string) ([]byte, error) }); ok {
data, err = readFileFS.ReadFile(zipFile)
} else {
// Fall back to Open and reading
file, openErr := embedFS.Open(zipFile)
if openErr != nil {
return nil, fmt.Errorf("failed to open embedded zip file: %w", openErr)
}
defer file.Close()
data, err = io.ReadAll(file)
}
if err != nil {
return nil, fmt.Errorf("failed to read embedded zip file: %w", err)
}
// Create a zip reader from the data
reader := bytes.NewReader(data)
zipReader, err := zip.NewReader(reader, int64(len(data)))
if err != nil {
return nil, fmt.Errorf("failed to create zip reader: %w", err)
}
provider.zipReader = zipReader
provider.fs = zipfs.NewZipFS(zipReader)
} else {
// Use the embedded FS directly
provider.fs = embedFS
}
return provider, nil
}
// Open opens the named file from the embedded filesystem.
func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
p.mu.RLock()
defer p.mu.RUnlock()
if p.fs == nil {
return nil, fmt.Errorf("embedded filesystem is closed")
}
return p.fs.Open(name)
}
// Close releases any resources held by the provider.
// For embedded filesystems, this is mostly a no-op since Go manages the lifecycle.
func (p *EmbedFSProvider) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
// Clear references to allow garbage collection
p.fs = nil
p.zipReader = nil
return nil
}
// Type returns "embed" or "embed-zip" depending on the configuration.
func (p *EmbedFSProvider) Type() string {
if p.zipFile != "" {
return "embed-zip"
}
return "embed"
}
// ZipFile returns the path to the zip file within the embedded FS, if any.
func (p *EmbedFSProvider) ZipFile() string {
return p.zipFile
}

View File

@@ -0,0 +1,80 @@
package providers
import (
"fmt"
"io/fs"
"os"
"path/filepath"
)
// LocalFSProvider serves files from a local directory.
type LocalFSProvider struct {
path string
fs fs.FS
}
// NewLocalFSProvider creates a new LocalFSProvider for the given directory path.
// The path must be an absolute path to an existing directory.
func NewLocalFSProvider(path string) (*LocalFSProvider, error) {
// Validate that the path exists and is a directory
info, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("failed to stat directory: %w", err)
}
if !info.IsDir() {
return nil, fmt.Errorf("path is not a directory: %s", path)
}
// Convert to absolute path
absPath, err := filepath.Abs(path)
if err != nil {
return nil, fmt.Errorf("failed to get absolute path: %w", err)
}
return &LocalFSProvider{
path: absPath,
fs: os.DirFS(absPath),
}, nil
}
// Open opens the named file from the local directory.
func (p *LocalFSProvider) Open(name string) (fs.File, error) {
return p.fs.Open(name)
}
// Close releases any resources held by the provider.
// For local filesystem, this is a no-op since os.DirFS doesn't hold resources.
func (p *LocalFSProvider) Close() error {
return nil
}
// Type returns "local".
func (p *LocalFSProvider) Type() string {
return "local"
}
// Path returns the absolute path to the directory being served.
func (p *LocalFSProvider) Path() string {
return p.path
}
// Reload refreshes the filesystem view.
// For local directories, os.DirFS automatically picks up changes,
// so this recreates the DirFS to ensure a fresh view.
func (p *LocalFSProvider) Reload() error {
// Verify the directory still exists
info, err := os.Stat(p.path)
if err != nil {
return fmt.Errorf("failed to stat directory: %w", err)
}
if !info.IsDir() {
return fmt.Errorf("path is no longer a directory: %s", p.path)
}
// Recreate the DirFS
p.fs = os.DirFS(p.path)
return nil
}

View File

@@ -0,0 +1,393 @@
package providers
import (
"archive/zip"
"bytes"
"io"
"io/fs"
"os"
"path/filepath"
"testing"
"time"
)
func TestLocalFSProvider(t *testing.T) {
// Create a temporary directory with test files
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
if err := os.WriteFile(testFile, []byte("test content"), 0644); err != nil {
t.Fatal(err)
}
provider, err := NewLocalFSProvider(tmpDir)
if err != nil {
t.Fatalf("Failed to create provider: %v", err)
}
defer provider.Close()
// Test opening a file
file, err := provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer file.Close()
// Read the file
data, err := io.ReadAll(file)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if string(data) != "test content" {
t.Errorf("Expected 'test content', got %q", string(data))
}
// Test type
if provider.Type() != "local" {
t.Errorf("Expected type 'local', got %q", provider.Type())
}
}
func TestZipFSProvider(t *testing.T) {
// Create a temporary zip file
tmpDir := t.TempDir()
zipPath := filepath.Join(tmpDir, "test.zip")
// Create zip file with test content
zipFile, err := os.Create(zipPath)
if err != nil {
t.Fatal(err)
}
zipWriter := zip.NewWriter(zipFile)
fileWriter, err := zipWriter.Create("test.txt")
if err != nil {
t.Fatal(err)
}
_, err = fileWriter.Write([]byte("zip content"))
if err != nil {
t.Fatal(err)
}
if err := zipWriter.Close(); err != nil {
t.Fatal(err)
}
if err := zipFile.Close(); err != nil {
t.Fatal(err)
}
// Test the provider
provider, err := NewZipFSProvider(zipPath)
if err != nil {
t.Fatalf("Failed to create provider: %v", err)
}
defer provider.Close()
// Test opening a file
file, err := provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer file.Close()
// Read the file
data, err := io.ReadAll(file)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if string(data) != "zip content" {
t.Errorf("Expected 'zip content', got %q", string(data))
}
// Test type
if provider.Type() != "zip" {
t.Errorf("Expected type 'zip', got %q", provider.Type())
}
}
func TestZipFSProviderReload(t *testing.T) {
// Create a temporary zip file
tmpDir := t.TempDir()
zipPath := filepath.Join(tmpDir, "test.zip")
// Helper to create zip with content
createZip := func(content string) {
zipFile, err := os.Create(zipPath)
if err != nil {
t.Fatal(err)
}
defer zipFile.Close()
zipWriter := zip.NewWriter(zipFile)
fileWriter, err := zipWriter.Create("test.txt")
if err != nil {
t.Fatal(err)
}
_, err = fileWriter.Write([]byte(content))
if err != nil {
t.Fatal(err)
}
if err := zipWriter.Close(); err != nil {
t.Fatal(err)
}
}
// Create initial zip
createZip("original content")
// Test the provider
provider, err := NewZipFSProvider(zipPath)
if err != nil {
t.Fatalf("Failed to create provider: %v", err)
}
defer provider.Close()
// Read initial content
file, err := provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
data, _ := io.ReadAll(file)
file.Close()
if string(data) != "original content" {
t.Errorf("Expected 'original content', got %q", string(data))
}
// Update the zip file
createZip("updated content")
// Reload the provider
if err := provider.Reload(); err != nil {
t.Fatalf("Failed to reload: %v", err)
}
// Read updated content
file, err = provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file after reload: %v", err)
}
data, _ = io.ReadAll(file)
file.Close()
if string(data) != "updated content" {
t.Errorf("Expected 'updated content', got %q", string(data))
}
}
func TestLocalFSProviderReload(t *testing.T) {
// Create a temporary directory with test files
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
if err := os.WriteFile(testFile, []byte("original"), 0644); err != nil {
t.Fatal(err)
}
provider, err := NewLocalFSProvider(tmpDir)
if err != nil {
t.Fatalf("Failed to create provider: %v", err)
}
defer provider.Close()
// Read initial content
file, err := provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
data, _ := io.ReadAll(file)
file.Close()
if string(data) != "original" {
t.Errorf("Expected 'original', got %q", string(data))
}
// Update the file
if err := os.WriteFile(testFile, []byte("updated"), 0644); err != nil {
t.Fatal(err)
}
// Reload the provider
if err := provider.Reload(); err != nil {
t.Fatalf("Failed to reload: %v", err)
}
// Read updated content
file, err = provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file after reload: %v", err)
}
data, _ = io.ReadAll(file)
file.Close()
if string(data) != "updated" {
t.Errorf("Expected 'updated', got %q", string(data))
}
}
func TestEmbedFSProvider(t *testing.T) {
// Test with a mock embed.FS
mockFS := &mockEmbedFS{
files: map[string][]byte{
"test.txt": []byte("test content"),
},
}
provider, err := NewEmbedFSProvider(mockFS, "")
if err != nil {
t.Fatalf("Failed to create provider: %v", err)
}
defer provider.Close()
// Test type
if provider.Type() != "embed" {
t.Errorf("Expected type 'embed', got %q", provider.Type())
}
// Test opening a file
file, err := provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer file.Close()
// Read the file
data, err := io.ReadAll(file)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if string(data) != "test content" {
t.Errorf("Expected 'test content', got %q", string(data))
}
}
func TestEmbedFSProviderWithZip(t *testing.T) {
// Create an embedded-like FS with a zip file
// For simplicity, we'll use a mock embed.FS
tmpDir := t.TempDir()
zipPath := filepath.Join(tmpDir, "test.zip")
// Create zip file
zipFile, err := os.Create(zipPath)
if err != nil {
t.Fatal(err)
}
zipWriter := zip.NewWriter(zipFile)
fileWriter, err := zipWriter.Create("test.txt")
if err != nil {
t.Fatal(err)
}
_, err = fileWriter.Write([]byte("embedded zip content"))
if err != nil {
t.Fatal(err)
}
zipWriter.Close()
zipFile.Close()
// Read the zip file
zipData, err := os.ReadFile(zipPath)
if err != nil {
t.Fatal(err)
}
// Create a mock embed.FS
mockFS := &mockEmbedFS{
files: map[string][]byte{
"test.zip": zipData,
},
}
provider, err := NewEmbedFSProvider(mockFS, "test.zip")
if err != nil {
t.Fatalf("Failed to create provider: %v", err)
}
defer provider.Close()
// Test opening a file
file, err := provider.Open("test.txt")
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer file.Close()
// Read the file
data, err := io.ReadAll(file)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if string(data) != "embedded zip content" {
t.Errorf("Expected 'embedded zip content', got %q", string(data))
}
// Test type
if provider.Type() != "embed-zip" {
t.Errorf("Expected type 'embed-zip', got %q", provider.Type())
}
}
// mockEmbedFS is a mock embed.FS for testing
type mockEmbedFS struct {
files map[string][]byte
}
func (m *mockEmbedFS) Open(name string) (fs.File, error) {
data, ok := m.files[name]
if !ok {
return nil, os.ErrNotExist
}
return &mockFile{
name: name,
reader: bytes.NewReader(data),
size: int64(len(data)),
}, nil
}
func (m *mockEmbedFS) ReadFile(name string) ([]byte, error) {
data, ok := m.files[name]
if !ok {
return nil, os.ErrNotExist
}
return data, nil
}
type mockFile struct {
name string
reader *bytes.Reader
size int64
}
func (f *mockFile) Stat() (fs.FileInfo, error) {
return &mockFileInfo{name: f.name, size: f.size}, nil
}
func (f *mockFile) Read(p []byte) (int, error) {
return f.reader.Read(p)
}
func (f *mockFile) Close() error {
return nil
}
type mockFileInfo struct {
name string
size int64
}
func (fi *mockFileInfo) Name() string { return fi.name }
func (fi *mockFileInfo) Size() int64 { return fi.size }
func (fi *mockFileInfo) Mode() fs.FileMode { return 0644 }
func (fi *mockFileInfo) ModTime() time.Time { return time.Now() }
func (fi *mockFileInfo) IsDir() bool { return false }
func (fi *mockFileInfo) Sys() interface{} { return nil }

View File

@@ -0,0 +1,102 @@
package providers
import (
"archive/zip"
"fmt"
"io/fs"
"path/filepath"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/server/zipfs"
)
// ZipFSProvider serves files from a zip file.
type ZipFSProvider struct {
zipPath string
zipReader *zip.ReadCloser
zipFS *zipfs.ZipFS
mu sync.RWMutex
}
// NewZipFSProvider creates a new ZipFSProvider for the given zip file path.
func NewZipFSProvider(zipPath string) (*ZipFSProvider, error) {
// Convert to absolute path
absPath, err := filepath.Abs(zipPath)
if err != nil {
return nil, fmt.Errorf("failed to get absolute path: %w", err)
}
// Open the zip file
zipReader, err := zip.OpenReader(absPath)
if err != nil {
return nil, fmt.Errorf("failed to open zip file: %w", err)
}
return &ZipFSProvider{
zipPath: absPath,
zipReader: zipReader,
zipFS: zipfs.NewZipFS(&zipReader.Reader),
}, nil
}
// Open opens the named file from the zip archive.
func (p *ZipFSProvider) Open(name string) (fs.File, error) {
p.mu.RLock()
defer p.mu.RUnlock()
if p.zipFS == nil {
return nil, fmt.Errorf("zip filesystem is closed")
}
return p.zipFS.Open(name)
}
// Close releases resources held by the zip reader.
func (p *ZipFSProvider) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.zipReader != nil {
err := p.zipReader.Close()
p.zipReader = nil
p.zipFS = nil
return err
}
return nil
}
// Type returns "zip".
func (p *ZipFSProvider) Type() string {
return "zip"
}
// Path returns the absolute path to the zip file being served.
func (p *ZipFSProvider) Path() string {
return p.zipPath
}
// Reload reopens the zip file to pick up any changes.
// This is useful in development when the zip file is updated.
func (p *ZipFSProvider) Reload() error {
p.mu.Lock()
defer p.mu.Unlock()
// Close the existing zip reader if open
if p.zipReader != nil {
if err := p.zipReader.Close(); err != nil {
return fmt.Errorf("failed to close old zip reader: %w", err)
}
}
// Reopen the zip file
zipReader, err := zip.OpenReader(p.zipPath)
if err != nil {
return fmt.Errorf("failed to reopen zip file: %w", err)
}
p.zipReader = zipReader
p.zipFS = zipfs.NewZipFS(&zipReader.Reader)
return nil
}

Some files were not shown because too many files have changed in this diff Show More