mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-03 10:24:26 +00:00
Compare commits
114 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ebd03d10ad | ||
|
|
4ee6ef0955 | ||
|
|
6f05f15ff6 | ||
|
|
443a672fcb | ||
|
|
c2fcc5aaff | ||
|
|
6664a4e2d2 | ||
| 037bd4c05e | |||
|
|
e77468a239 | ||
|
|
82d84435f2 | ||
|
|
b99b08430e | ||
|
|
fae9a082bd | ||
|
|
191822b91c | ||
|
|
a6a17d019f | ||
|
|
a7cc42044b | ||
|
|
8cdc353029 | ||
|
|
6528e94297 | ||
|
|
f711bf38d2 | ||
|
|
44356d8750 | ||
|
|
caf85cf558 | ||
|
|
2e1547ec65 | ||
|
|
49cdc6f17b | ||
|
|
0bd653820c | ||
|
|
9209193157 | ||
|
|
b8c44c5a99 | ||
|
|
28fd88fff1 | ||
|
|
be38341383 | ||
|
|
fab744b878 | ||
|
|
5ad2bd3a78 | ||
|
|
333fe158e9 | ||
|
|
2a2d351ad4 | ||
|
|
e918c49b84 | ||
|
|
41e4956510 | ||
|
|
8e8c3c6de6 | ||
|
|
aa9b7312f6 | ||
|
|
dca43b0e05 | ||
|
|
6f368bbce5 | ||
|
|
8704cee941 | ||
|
|
4ce5afe0ac | ||
|
|
7b98ea2145 | ||
|
|
897cb2ae0d | ||
|
|
01420e6b63 | ||
|
|
645907d355 | ||
|
|
e81d7b48cc | ||
|
|
8f5a725a09 | ||
|
|
3d5d7b788e | ||
|
|
eaecef686e | ||
|
|
e0d21b17ec | ||
|
|
7e1718e864 | ||
|
|
16d416030e | ||
|
|
bf8500714a | ||
|
|
4f8edd6469 | ||
|
|
ccf8522f88 | ||
|
|
92a83e9cc6 | ||
|
|
4cb35a78b0 | ||
|
|
e10e2e1c27 | ||
|
|
64f56325d4 | ||
|
|
5e6032c91d | ||
|
|
bc2fdc143b | ||
|
|
267e84fd84 | ||
|
|
8adc386863 | ||
|
|
feb023ec48 | ||
|
|
de50141a04 | ||
|
|
c226dc349f | ||
|
|
d4a6f9c4c2 | ||
| 8f83e8fdc1 | |||
|
|
90df4a157c | ||
|
|
2dd404af96 | ||
|
|
17c472b206 | ||
|
|
ed67caf055 | ||
| 4d1b8b6982 | |||
|
|
63ed62a9a3 | ||
|
|
0525323a47 | ||
|
|
c3443f702e | ||
|
|
45c463c117 | ||
|
|
84d673ce14 | ||
|
|
02fbdbd651 | ||
|
|
97988e3b5e | ||
|
|
c9838ad9d2 | ||
|
|
c5c0608f63 | ||
|
|
39c3f05d21 | ||
|
|
4ecd1ac17e | ||
|
|
2b1aea0338 | ||
|
|
1e749efeb3 | ||
|
|
09be676096 | ||
|
|
e8350a70be | ||
|
|
5937b9eab5 | ||
|
|
7c861c708e | ||
|
|
77f39af2f9 | ||
|
|
fbc1471581 | ||
|
|
9351093e2a | ||
|
|
932f12ab0a | ||
|
|
1b2b0d8f0b | ||
|
|
b22792bad6 | ||
|
|
e8111c01aa | ||
|
|
5862016031 | ||
|
|
2f18dde29c | ||
|
|
31ad217818 | ||
|
|
7ef1d6424a | ||
|
|
c50eeac5bf | ||
|
|
6d88f2668a | ||
|
|
8a9423df6d | ||
|
|
4cc943b9d3 | ||
|
|
68dee78a34 | ||
|
|
efb9e5d9d5 | ||
|
|
490ae37c6d | ||
|
|
99307e31e6 | ||
|
|
e3f7869c6d | ||
|
|
c696d502c5 | ||
|
|
4ed1fba6ad | ||
|
|
1d0407a16d | ||
|
|
99001c749d | ||
|
|
1f7a57f8e3 | ||
|
|
a95c28a0bf | ||
|
|
e1abd5ebc1 |
82
.github/workflows/make_tag.yml
vendored
Normal file
82
.github/workflows/make_tag.yml
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
# This workflow will build a golang project
|
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
|
||||
|
||||
name: Create Go Release (Tag Versioning)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
semver:
|
||||
description: "New Version"
|
||||
required: true
|
||||
default: "patch"
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
jobs:
|
||||
tag_and_commit:
|
||||
name: "Tag and Commit ${{ github.event.inputs.semver }}"
|
||||
runs-on: linux
|
||||
permissions:
|
||||
contents: write # 'write' access to repository contents
|
||||
pull-requests: write # 'write' access to pull requests
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Git
|
||||
run: |
|
||||
git config --global user.name "Hein"
|
||||
git config --global user.email "hein.puth@gmail.com"
|
||||
|
||||
- name: Fetch latest tag
|
||||
id: latest_tag
|
||||
run: |
|
||||
git fetch --tags
|
||||
latest_tag=$(git describe --tags `git rev-list --tags --max-count=1`)
|
||||
echo "::set-output name=tag::$latest_tag"
|
||||
|
||||
- name: Determine new tag version
|
||||
id: new_tag
|
||||
run: |
|
||||
current_tag=${{ steps.latest_tag.outputs.tag }}
|
||||
version=$(echo $current_tag | cut -c 2-) # remove the leading 'v'
|
||||
IFS='.' read -r -a version_parts <<< "$version"
|
||||
major=${version_parts[0]}
|
||||
minor=${version_parts[1]}
|
||||
patch=${version_parts[2]}
|
||||
case "${{ github.event.inputs.semver }}" in
|
||||
"patch")
|
||||
((patch++))
|
||||
;;
|
||||
"minor")
|
||||
((minor++))
|
||||
patch=0
|
||||
;;
|
||||
"release")
|
||||
((major++))
|
||||
minor=0
|
||||
patch=0
|
||||
;;
|
||||
*)
|
||||
echo "Invalid semver input"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
new_tag="v$major.$minor.$patch"
|
||||
echo "::set-output name=tag::$new_tag"
|
||||
|
||||
- name: Create tag
|
||||
run: |
|
||||
git tag -a ${{ steps.new_tag.outputs.tag }} -m "Tagging ${{ steps.new_tag.outputs.tag }} for release"
|
||||
|
||||
- name: Push changes
|
||||
uses: ad-m/github-push-action@master
|
||||
with:
|
||||
github_token: ${{ secrets.BITECH_GITHUB_TOKEN }}
|
||||
force: true
|
||||
tags: true
|
||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -17,11 +17,13 @@ jobs:
|
||||
- name: Run unit tests
|
||||
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
|
||||
- name: Generate coverage report
|
||||
continue-on-error: true
|
||||
run: |
|
||||
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
- name: Upload coverage
|
||||
uses: actions/upload-artifact@v5
|
||||
continue-on-error: true
|
||||
with:
|
||||
name: coverage-report
|
||||
path: coverage.html
|
||||
@@ -55,27 +57,34 @@ jobs:
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
|
||||
- name: Run resolvespec integration tests
|
||||
continue-on-error: true
|
||||
env:
|
||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
||||
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
|
||||
- name: Run restheadspec integration tests
|
||||
continue-on-error: true
|
||||
env:
|
||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
|
||||
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
|
||||
- name: Generate integration coverage
|
||||
continue-on-error: true
|
||||
env:
|
||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
||||
run: |
|
||||
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
|
||||
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
|
||||
|
||||
- name: Upload resolvespec integration coverage
|
||||
uses: actions/upload-artifact@v5
|
||||
continue-on-error: true
|
||||
with:
|
||||
name: resolvespec-integration-coverage-report
|
||||
path: coverage-resolvespec-integration.html
|
||||
|
||||
- name: Upload restheadspec integration coverage
|
||||
uses: actions/upload-artifact@v5
|
||||
continue-on-error: true
|
||||
|
||||
with:
|
||||
name: integration-coverage-restheadspec-report
|
||||
path: coverage-restheadspec-integration
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -25,3 +25,4 @@ go.work.sum
|
||||
.env
|
||||
bin/
|
||||
test.db
|
||||
/testserver
|
||||
|
||||
@@ -71,35 +71,18 @@
|
||||
},
|
||||
"gocritic": {
|
||||
"enabled-checks": [
|
||||
"appendAssign",
|
||||
"assignOp",
|
||||
"boolExprSimplify",
|
||||
"builtinShadow",
|
||||
"captLocal",
|
||||
"caseOrder",
|
||||
"defaultCaseOrder",
|
||||
"dupArg",
|
||||
"dupBranchBody",
|
||||
"dupCase",
|
||||
"dupSubExpr",
|
||||
"elseif",
|
||||
"emptyFallthrough",
|
||||
"equalFold",
|
||||
"flagName",
|
||||
"indexAlloc",
|
||||
"initClause",
|
||||
"methodExprCall",
|
||||
"nilValReturn",
|
||||
"rangeExprCopy",
|
||||
"rangeValCopy",
|
||||
"regexpMust",
|
||||
"singleCaseSwitch",
|
||||
"sloppyLen",
|
||||
"stringXbytes",
|
||||
"switchTrue",
|
||||
"typeAssertChain",
|
||||
"typeSwitchVar",
|
||||
"underef",
|
||||
"unlabelStmt",
|
||||
"unnamedResult",
|
||||
"unnecessaryBlock",
|
||||
|
||||
8
.vscode/settings.json
vendored
8
.vscode/settings.json
vendored
@@ -52,5 +52,9 @@
|
||||
"upgrade_dependency": true,
|
||||
"vendor": true
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"conventionalCommits.scopes": [
|
||||
"spectypes",
|
||||
"dbmanager"
|
||||
]
|
||||
}
|
||||
14
.vscode/tasks.json
vendored
14
.vscode/tasks.json
vendored
@@ -230,7 +230,17 @@
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test"
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "go: lint workspace (fix)",
|
||||
"command": "golangci-lint run --timeout=5m --fix",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
@@ -275,4 +285,4 @@
|
||||
"command": "sh ${workspaceFolder}/make_release.sh"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
86
LICENSE
86
LICENSE
@@ -1,21 +1,73 @@
|
||||
MIT License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
Copyright (c) 2025
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
1. Definitions.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
||||
|
||||
Copyright 2025 wdevs
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
59
Makefile
59
Makefile
@@ -13,10 +13,63 @@ test-integration:
|
||||
# Run all tests (unit + integration)
|
||||
test: test-unit test-integration
|
||||
|
||||
release-version: ## Create and push a release with specific version (use: make release-version VERSION=v1.2.3 or make release-version to auto-increment)
|
||||
@if [ -z "$(VERSION)" ]; then \
|
||||
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0"); \
|
||||
echo "No VERSION specified. Last version: $$latest_tag"; \
|
||||
version_num=$$(echo "$$latest_tag" | sed 's/^v//'); \
|
||||
major=$$(echo "$$version_num" | cut -d. -f1); \
|
||||
minor=$$(echo "$$version_num" | cut -d. -f2); \
|
||||
patch=$$(echo "$$version_num" | cut -d. -f3); \
|
||||
new_patch=$$((patch + 1)); \
|
||||
version="v$$major.$$minor.$$new_patch"; \
|
||||
echo "Auto-incrementing to: $$version"; \
|
||||
else \
|
||||
version="$(VERSION)"; \
|
||||
if ! echo "$$version" | grep -q "^v"; then \
|
||||
version="v$$version"; \
|
||||
fi; \
|
||||
fi; \
|
||||
echo "Creating release: $$version"; \
|
||||
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo ""); \
|
||||
if [ -z "$$latest_tag" ]; then \
|
||||
commit_logs=$$(git log --pretty=format:"- %s" --no-merges); \
|
||||
else \
|
||||
commit_logs=$$(git log "$${latest_tag}..HEAD" --pretty=format:"- %s" --no-merges); \
|
||||
fi; \
|
||||
if [ -z "$$commit_logs" ]; then \
|
||||
tag_message="Release $$version"; \
|
||||
else \
|
||||
tag_message="Release $$version\n\n$$commit_logs"; \
|
||||
fi; \
|
||||
git tag -a "$$version" -m "$$tag_message"; \
|
||||
git push origin "$$version"; \
|
||||
echo "Tag $$version created and pushed to remote repository."
|
||||
|
||||
|
||||
lint: ## Run linter
|
||||
@echo "Running linter..."
|
||||
@if command -v golangci-lint > /dev/null; then \
|
||||
golangci-lint run --config=.golangci.json; \
|
||||
else \
|
||||
echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
lintfix: ## Run linter
|
||||
@echo "Running linter..."
|
||||
@if command -v golangci-lint > /dev/null; then \
|
||||
golangci-lint run --config=.golangci.json --fix; \
|
||||
else \
|
||||
echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
|
||||
# Start PostgreSQL for integration tests
|
||||
docker-up:
|
||||
@echo "Starting PostgreSQL container..."
|
||||
@docker-compose up -d postgres-test
|
||||
@podman compose up -d postgres-test
|
||||
@echo "Waiting for PostgreSQL to be ready..."
|
||||
@sleep 5
|
||||
@echo "PostgreSQL is ready!"
|
||||
@@ -24,12 +77,12 @@ docker-up:
|
||||
# Stop PostgreSQL container
|
||||
docker-down:
|
||||
@echo "Stopping PostgreSQL container..."
|
||||
@docker-compose down
|
||||
@podman compose down
|
||||
|
||||
# Clean up Docker volumes and test data
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@docker-compose down -v
|
||||
@podman compose down -v
|
||||
@echo "Cleanup complete!"
|
||||
|
||||
# Run integration tests with Docker (full workflow)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
@@ -15,7 +17,6 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
gormlog "gorm.io/gorm/logger"
|
||||
)
|
||||
@@ -40,12 +41,14 @@ func main() {
|
||||
logger.Info("ResolveSpec test server starting")
|
||||
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
|
||||
|
||||
// Initialize database
|
||||
db, err := initDB(cfg)
|
||||
// Initialize database manager
|
||||
ctx := context.Background()
|
||||
dbMgr, db, err := initDB(ctx, cfg)
|
||||
if err != nil {
|
||||
logger.Error("Failed to initialize database: %+v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer dbMgr.Close()
|
||||
|
||||
// Create router
|
||||
r := mux.NewRouter()
|
||||
@@ -67,9 +70,36 @@ func main() {
|
||||
// Setup routes using new SetupMuxRoutes function (without authentication)
|
||||
resolvespec.SetupMuxRoutes(r, handler, nil)
|
||||
|
||||
// Create graceful server with configuration
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: cfg.Server.Addr,
|
||||
// Create server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Parse host and port from addr
|
||||
host := ""
|
||||
port := 8080
|
||||
if cfg.Server.Addr != "" {
|
||||
// Parse addr (format: ":8080" or "localhost:8080")
|
||||
if cfg.Server.Addr[0] == ':' {
|
||||
// Just port
|
||||
_, err := fmt.Sscanf(cfg.Server.Addr, ":%d", &port)
|
||||
if err != nil {
|
||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
// Host and port
|
||||
_, err := fmt.Sscanf(cfg.Server.Addr, "%[^:]:%d", &host, &port)
|
||||
if err != nil {
|
||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add server instance
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "api",
|
||||
Host: host,
|
||||
Port: port,
|
||||
Handler: r,
|
||||
ShutdownTimeout: cfg.Server.ShutdownTimeout,
|
||||
DrainTimeout: cfg.Server.DrainTimeout,
|
||||
@@ -77,16 +107,20 @@ func main() {
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to add server: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Start server with graceful shutdown
|
||||
logger.Info("Starting server on %s", cfg.Server.Addr)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("Server failed to start: %v", err)
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
logger.Error("Server failed: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func initDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
func initDB(ctx context.Context, cfg *config.Config) (dbmanager.Manager, *gorm.DB, error) {
|
||||
// Configure GORM logger based on config
|
||||
logLevel := gormlog.Info
|
||||
if !cfg.Logger.Dev {
|
||||
@@ -104,25 +138,41 @@ func initDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
},
|
||||
)
|
||||
|
||||
// Use database URL from config if available, otherwise use default SQLite
|
||||
dbURL := cfg.Database.URL
|
||||
if dbURL == "" {
|
||||
dbURL = "test.db"
|
||||
// Create database manager from config
|
||||
mgr, err := dbmanager.NewManager(dbmanager.FromConfig(cfg.DBManager))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create database manager: %w", err)
|
||||
}
|
||||
|
||||
// Create SQLite database
|
||||
db, err := gorm.Open(sqlite.Open(dbURL), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Connect all databases
|
||||
if err := mgr.Connect(ctx); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to connect databases: %w", err)
|
||||
}
|
||||
|
||||
// Get default connection
|
||||
conn, err := mgr.GetDefault()
|
||||
if err != nil {
|
||||
mgr.Close()
|
||||
return nil, nil, fmt.Errorf("failed to get default connection: %w", err)
|
||||
}
|
||||
|
||||
// Get GORM database
|
||||
gormDB, err := conn.GORM()
|
||||
if err != nil {
|
||||
mgr.Close()
|
||||
return nil, nil, fmt.Errorf("failed to get GORM database: %w", err)
|
||||
}
|
||||
|
||||
// Update GORM logger
|
||||
gormDB.Logger = newLogger
|
||||
|
||||
modelList := testmodels.GetTestModels()
|
||||
|
||||
// Auto migrate schemas
|
||||
err = db.AutoMigrate(modelList...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err := gormDB.AutoMigrate(modelList...); err != nil {
|
||||
mgr.Close()
|
||||
return nil, nil, fmt.Errorf("failed to auto migrate: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
return mgr, gormDB, nil
|
||||
}
|
||||
|
||||
24
config.yaml
24
config.yaml
@@ -37,5 +37,25 @@ cors:
|
||||
tracing:
|
||||
enabled: false
|
||||
|
||||
database:
|
||||
url: "" # Empty means use default SQLite (test.db)
|
||||
# Database Manager Configuration
|
||||
dbmanager:
|
||||
default_connection: "primary"
|
||||
max_open_conns: 25
|
||||
max_idle_conns: 5
|
||||
conn_max_lifetime: 30m
|
||||
conn_max_idle_time: 5m
|
||||
retry_attempts: 3
|
||||
retry_delay: 1s
|
||||
health_check_interval: 30s
|
||||
enable_auto_reconnect: true
|
||||
|
||||
connections:
|
||||
primary:
|
||||
name: "primary"
|
||||
type: "sqlite"
|
||||
filepath: "test.db"
|
||||
default_orm: "gorm"
|
||||
enable_logging: true
|
||||
enable_metrics: false
|
||||
connect_timeout: 10s
|
||||
query_timeout: 30s
|
||||
|
||||
99
go.mod
99
go.mod
@@ -5,92 +5,157 @@ go 1.24.0
|
||||
toolchain go1.24.6
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
|
||||
github.com/eclipse/paho.mqtt.golang v1.5.1
|
||||
github.com/getsentry/sentry-go v0.40.0
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.6.0
|
||||
github.com/klauspost/compress v1.18.0
|
||||
github.com/mattn/go-sqlite3 v1.14.32
|
||||
github.com/microsoft/go-mssqldb v1.9.5
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||
github.com/nats-io/nats.go v1.48.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.1
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go v0.40.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/uptrace/bun v1.2.15
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15
|
||||
github.com/uptrace/bun v1.2.16
|
||||
github.com/uptrace/bun/dialect/mssqldialect v1.2.16
|
||||
github.com/uptrace/bun/dialect/pgdialect v1.2.16
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16
|
||||
github.com/uptrace/bunrouter v1.0.23
|
||||
go.mongodb.org/mongo-driver v1.17.6
|
||||
go.opentelemetry.io/otel v1.38.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
|
||||
go.opentelemetry.io/otel/sdk v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/gorm v1.25.12
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/driver/sqlserver v1.6.3
|
||||
gorm.io/gorm v1.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.6.0 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/go-archive v0.1.0 // indirect
|
||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||
github.com/moby/sys/sequential v0.6.0 // indirect
|
||||
github.com/moby/sys/user v0.4.0 // indirect
|
||||
github.com/moby/sys/userns v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/montanaflynn/stats v0.7.1 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/nats-io/nkeys v0.4.11 // indirect
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rs/xid v1.4.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/spf13/viper v1.21.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.1.2 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.41.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/net v0.45.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/libc v1.67.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.0 // indirect
|
||||
modernc.org/sqlite v1.40.1 // indirect
|
||||
)
|
||||
|
||||
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
||||
|
||||
392
go.sum
392
go.sum
@@ -1,5 +1,37 @@
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1/go.mod h1:uE9zaUfEQT/nbQjVi2IblCG9iaLtZsuYZ8ne+PuQ02M=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1/go.mod h1:GpPjLhVR9dnUoJMyHWSPy71xY9/lcmpzIPZXmF0FCVY=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
||||
@@ -8,42 +40,102 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE=
|
||||
github.com/eclipse/paho.mqtt.golang v1.5.1/go.mod h1:1/yJCneuyOoCOzKSsOTUc0AJfpsItBGWvYpBLimhArU=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
||||
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -52,6 +144,14 @@ github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
||||
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
||||
github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
|
||||
github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo=
|
||||
github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
|
||||
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
|
||||
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -59,24 +159,78 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
|
||||
github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0=
|
||||
github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
|
||||
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
|
||||
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9 h1:y0g4vrSLAag7T07l2oCzOa/+nKVLoazKEWAArwqBNYI=
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9/go.mod h1:lZD3j35AVNqJL5cezlnSkuG05c0FCHSsfAKSPBOSbqc=
|
||||
github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8=
|
||||
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
||||
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
|
||||
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
|
||||
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
|
||||
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
|
||||
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
@@ -91,10 +245,20 @@ github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5i
|
||||
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
@@ -106,12 +270,24 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
@@ -121,28 +297,53 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
|
||||
github.com/uptrace/bun/dialect/mssqldialect v1.2.16 h1:rKv0cKPNBviXadB/+2Y/UedA/c1JnwGzUWZkdN5FdSQ=
|
||||
github.com/uptrace/bun/dialect/mssqldialect v1.2.16/go.mod h1:J5U7tGKWDsx2Q7MwDZF2417jCdpD6yD/ZMFJcCR80bk=
|
||||
github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
|
||||
github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16/go.mod h1:iKdJ06P3XS+pwKcONjSIK07bbhksH3lWsw3mpfr0+bY=
|
||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
|
||||
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss=
|
||||
go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
@@ -163,25 +364,128 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
|
||||
@@ -195,25 +499,37 @@ google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXn
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
||||
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
||||
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
|
||||
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/driver/sqlserver v1.6.3 h1:UR+nWCuphPnq7UxnL57PSrlYjuvs+sf1N59GgFX7uAI=
|
||||
gorm.io/driver/sqlserver v1.6.3/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U=
|
||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
||||
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
||||
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
|
||||
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
@@ -222,8 +538,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
20
pkg/cache/cache_manager.go
vendored
20
pkg/cache/cache_manager.go
vendored
@@ -57,11 +57,31 @@ func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time
|
||||
return c.provider.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// SetWithTags serializes and stores a value in the cache with the specified TTL and tags.
|
||||
func (c *Cache) SetWithTags(ctx context.Context, key string, value interface{}, ttl time.Duration, tags []string) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize: %w", err)
|
||||
}
|
||||
|
||||
return c.provider.SetWithTags(ctx, key, data, ttl, tags)
|
||||
}
|
||||
|
||||
// SetBytesWithTags stores raw bytes in the cache with the specified TTL and tags.
|
||||
func (c *Cache) SetBytesWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
return c.provider.SetWithTags(ctx, key, value, ttl, tags)
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (c *Cache) Delete(ctx context.Context, key string) error {
|
||||
return c.provider.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (c *Cache) DeleteByTag(ctx context.Context, tag string) error {
|
||||
return c.provider.DeleteByTag(ctx, tag)
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
return c.provider.DeleteByPattern(ctx, pattern)
|
||||
|
||||
8
pkg/cache/provider.go
vendored
8
pkg/cache/provider.go
vendored
@@ -15,9 +15,17 @@ type Provider interface {
|
||||
// If ttl is 0, the item never expires.
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
// Tags can be used to invalidate groups of related keys.
|
||||
// If ttl is 0, the item never expires.
|
||||
SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
DeleteByTag(ctx context.Context, tag string) error
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
// Pattern syntax depends on the provider implementation.
|
||||
DeleteByPattern(ctx context.Context, pattern string) error
|
||||
|
||||
140
pkg/cache/provider_memcache.go
vendored
140
pkg/cache/provider_memcache.go
vendored
@@ -2,6 +2,7 @@ package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -97,8 +98,115 @@ func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, tt
|
||||
return m.client.Set(item)
|
||||
}
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
// Note: Tag support in Memcache is limited and less efficient than Redis.
|
||||
func (m *MemcacheProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
if ttl == 0 {
|
||||
ttl = m.options.DefaultTTL
|
||||
}
|
||||
|
||||
expiration := int32(ttl.Seconds())
|
||||
|
||||
// Set the main value
|
||||
item := &memcache.Item{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Expiration: expiration,
|
||||
}
|
||||
if err := m.client.Set(item); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store tags for this key
|
||||
if len(tags) > 0 {
|
||||
tagsData, err := json.Marshal(tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||
}
|
||||
|
||||
tagsItem := &memcache.Item{
|
||||
Key: fmt.Sprintf("cache:tags:%s", key),
|
||||
Value: tagsData,
|
||||
Expiration: expiration,
|
||||
}
|
||||
if err := m.client.Set(tagsItem); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add key to each tag's key list
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
|
||||
// Get existing keys for this tag
|
||||
var keys []string
|
||||
if item, err := m.client.Get(tagKey); err == nil {
|
||||
_ = json.Unmarshal(item.Value, &keys)
|
||||
}
|
||||
|
||||
// Add current key if not already present
|
||||
found := false
|
||||
for _, k := range keys {
|
||||
if k == key {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Store updated key list
|
||||
keysData, err := json.Marshal(keys)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tagItem := &memcache.Item{
|
||||
Key: tagKey,
|
||||
Value: keysData,
|
||||
Expiration: expiration + 3600, // Give tag lists longer TTL
|
||||
}
|
||||
_ = m.client.Set(tagItem)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||
// Get tags for this key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
if item, err := m.client.Get(tagsKey); err == nil {
|
||||
var tags []string
|
||||
if err := json.Unmarshal(item.Value, &tags); err == nil {
|
||||
// Remove key from each tag's key list
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
if tagItem, err := m.client.Get(tagKey); err == nil {
|
||||
var keys []string
|
||||
if err := json.Unmarshal(tagItem.Value, &keys); err == nil {
|
||||
// Remove current key from the list
|
||||
newKeys := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
if k != key {
|
||||
newKeys = append(newKeys, k)
|
||||
}
|
||||
}
|
||||
// Update the tag's key list
|
||||
if keysData, err := json.Marshal(newKeys); err == nil {
|
||||
tagItem.Value = keysData
|
||||
_ = m.client.Set(tagItem)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Delete the tags key
|
||||
_ = m.client.Delete(tagsKey)
|
||||
}
|
||||
|
||||
// Delete the actual key
|
||||
err := m.client.Delete(key)
|
||||
if err == memcache.ErrCacheMiss {
|
||||
return nil
|
||||
@@ -106,6 +214,38 @@ func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (m *MemcacheProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
|
||||
// Get all keys associated with this tag
|
||||
item, err := m.client.Get(tagKey)
|
||||
if err == memcache.ErrCacheMiss {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
if err := json.Unmarshal(item.Value, &keys); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal tag keys: %w", err)
|
||||
}
|
||||
|
||||
// Delete all keys
|
||||
for _, key := range keys {
|
||||
_ = m.client.Delete(key)
|
||||
// Also delete the tags key for this cache key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
_ = m.client.Delete(tagsKey)
|
||||
}
|
||||
|
||||
// Delete the tag key itself
|
||||
_ = m.client.Delete(tagKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
// Note: Memcache does not support pattern-based deletion natively.
|
||||
// This is a no-op for memcache and returns an error.
|
||||
|
||||
118
pkg/cache/provider_memory.go
vendored
118
pkg/cache/provider_memory.go
vendored
@@ -15,6 +15,7 @@ type memoryItem struct {
|
||||
Expiration time.Time
|
||||
LastAccess time.Time
|
||||
HitCount int64
|
||||
Tags []string
|
||||
}
|
||||
|
||||
// isExpired checks if the item has expired.
|
||||
@@ -27,11 +28,12 @@ func (m *memoryItem) isExpired() bool {
|
||||
|
||||
// MemoryProvider is an in-memory implementation of the Provider interface.
|
||||
type MemoryProvider struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryItem
|
||||
options *Options
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryItem
|
||||
tagToKeys map[string]map[string]struct{} // tag -> set of keys
|
||||
options *Options
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
}
|
||||
|
||||
// NewMemoryProvider creates a new in-memory cache provider.
|
||||
@@ -44,8 +46,9 @@ func NewMemoryProvider(opts *Options) *MemoryProvider {
|
||||
}
|
||||
|
||||
return &MemoryProvider{
|
||||
items: make(map[string]*memoryItem),
|
||||
options: opts,
|
||||
items: make(map[string]*memoryItem),
|
||||
tagToKeys: make(map[string]map[string]struct{}),
|
||||
options: opts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,15 +117,116 @@ func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
func (m *MemoryProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if ttl == 0 {
|
||||
ttl = m.options.DefaultTTL
|
||||
}
|
||||
|
||||
var expiration time.Time
|
||||
if ttl > 0 {
|
||||
expiration = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
// Check max size and evict if necessary
|
||||
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
|
||||
if _, exists := m.items[key]; !exists {
|
||||
m.evictOne()
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old tag associations if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
for _, tag := range oldItem.Tags {
|
||||
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||
delete(keySet, key)
|
||||
if len(keySet) == 0 {
|
||||
delete(m.tagToKeys, tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store the item
|
||||
m.items[key] = &memoryItem{
|
||||
Value: value,
|
||||
Expiration: expiration,
|
||||
LastAccess: time.Now(),
|
||||
Tags: tags,
|
||||
}
|
||||
|
||||
// Add new tag associations
|
||||
for _, tag := range tags {
|
||||
if m.tagToKeys[tag] == nil {
|
||||
m.tagToKeys[tag] = make(map[string]struct{})
|
||||
}
|
||||
m.tagToKeys[tag][key] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Remove tag associations
|
||||
if item, exists := m.items[key]; exists {
|
||||
for _, tag := range item.Tags {
|
||||
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||
delete(keySet, key)
|
||||
if len(keySet) == 0 {
|
||||
delete(m.tagToKeys, tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete(m.items, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (m *MemoryProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Get all keys associated with this tag
|
||||
keySet, exists := m.tagToKeys[tag]
|
||||
if !exists {
|
||||
return nil // No keys with this tag
|
||||
}
|
||||
|
||||
// Delete all items with this tag
|
||||
for key := range keySet {
|
||||
if item, ok := m.items[key]; ok {
|
||||
// Remove this tag from the item's tag list
|
||||
newTags := make([]string, 0, len(item.Tags))
|
||||
for _, t := range item.Tags {
|
||||
if t != tag {
|
||||
newTags = append(newTags, t)
|
||||
}
|
||||
}
|
||||
|
||||
// If item has no more tags, delete it
|
||||
// Otherwise update its tags
|
||||
if len(newTags) == 0 {
|
||||
delete(m.items, key)
|
||||
} else {
|
||||
item.Tags = newTags
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the tag mapping
|
||||
delete(m.tagToKeys, tag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
m.mu.Lock()
|
||||
|
||||
86
pkg/cache/provider_redis.go
vendored
86
pkg/cache/provider_redis.go
vendored
@@ -103,9 +103,93 @@ func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl t
|
||||
return r.client.Set(ctx, key, value, ttl).Err()
|
||||
}
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
func (r *RedisProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
if ttl == 0 {
|
||||
ttl = r.options.DefaultTTL
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Set the value
|
||||
pipe.Set(ctx, key, value, ttl)
|
||||
|
||||
// Add key to each tag's set
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
pipe.SAdd(ctx, tagKey, key)
|
||||
// Set expiration on tag set (longer than cache items to ensure cleanup)
|
||||
if ttl > 0 {
|
||||
pipe.Expire(ctx, tagKey, ttl+time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
// Store tags for this key for later cleanup
|
||||
if len(tags) > 0 {
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
pipe.SAdd(ctx, tagsKey, tags)
|
||||
if ttl > 0 {
|
||||
pipe.Expire(ctx, tagsKey, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
||||
return r.client.Del(ctx, key).Err()
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Get tags for this key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
tags, err := r.client.SMembers(ctx, tagsKey).Result()
|
||||
if err == nil && len(tags) > 0 {
|
||||
// Remove key from each tag set
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
pipe.SRem(ctx, tagKey, key)
|
||||
}
|
||||
// Delete the tags key
|
||||
pipe.Del(ctx, tagsKey)
|
||||
}
|
||||
|
||||
// Delete the actual key
|
||||
pipe.Del(ctx, key)
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (r *RedisProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
|
||||
// Get all keys associated with this tag
|
||||
keys, err := r.client.SMembers(ctx, tagKey).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Delete all keys and their tag associations
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
// Also delete the tags key for this cache key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
pipe.Del(ctx, tagsKey)
|
||||
}
|
||||
|
||||
// Delete the tag set itself
|
||||
pipe.Del(ctx, tagKey)
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
|
||||
151
pkg/cache/query_cache_test.go
vendored
151
pkg/cache/query_cache_test.go
vendored
@@ -1,151 +0,0 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestBuildQueryCacheKey(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
{Column: "age", Operator: "gt", Value: 25},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "name", Direction: "asc"},
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||
|
||||
// Same parameters should generate same key
|
||||
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
|
||||
}
|
||||
|
||||
// Different parameters should generate different key
|
||||
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
|
||||
|
||||
if key1 == key3 {
|
||||
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildExtendedQueryCacheKey(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "name", Direction: "asc"},
|
||||
}
|
||||
expandOpts := []interface{}{
|
||||
map[string]interface{}{
|
||||
"relation": "posts",
|
||||
"where": "status = 'published'",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||
|
||||
// Same parameters should generate same key
|
||||
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("Expected same cache keys for identical parameters")
|
||||
}
|
||||
|
||||
// Different distinct value should generate different key
|
||||
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
|
||||
|
||||
if key1 == key3 {
|
||||
t.Errorf("Expected different cache keys for different distinct values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetQueryTotalCacheKey(t *testing.T) {
|
||||
hash := "abc123"
|
||||
key := GetQueryTotalCacheKey(hash)
|
||||
|
||||
expected := "query_total:abc123"
|
||||
if key != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedTotalIntegration(t *testing.T) {
|
||||
// Initialize cache with memory provider for testing
|
||||
UseMemory(&Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 100,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
filters := []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "created_at", Direction: "desc"},
|
||||
}
|
||||
|
||||
// Build cache key
|
||||
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
|
||||
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
|
||||
|
||||
// Store a total count in cache
|
||||
totalToCache := CachedTotal{Total: 42}
|
||||
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set cache: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve from cache
|
||||
var cachedTotal CachedTotal
|
||||
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get from cache: %v", err)
|
||||
}
|
||||
|
||||
if cachedTotal.Total != 42 {
|
||||
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
|
||||
}
|
||||
|
||||
// Test cache miss
|
||||
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
|
||||
var missedTotal CachedTotal
|
||||
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for cache miss, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashString(t *testing.T) {
|
||||
input1 := "test string"
|
||||
input2 := "test string"
|
||||
input3 := "different string"
|
||||
|
||||
hash1 := hashString(input1)
|
||||
hash2 := hashString(input2)
|
||||
hash3 := hashString(input3)
|
||||
|
||||
// Same input should produce same hash
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Expected same hash for identical inputs")
|
||||
}
|
||||
|
||||
// Different input should produce different hash
|
||||
if hash1 == hash3 {
|
||||
t.Errorf("Expected different hash for different inputs")
|
||||
}
|
||||
|
||||
// Hash should be hex encoded SHA256 (64 characters)
|
||||
if len(hash1) != 64 {
|
||||
t.Errorf("Expected hash length of 64, got %d", len(hash1))
|
||||
}
|
||||
}
|
||||
@@ -34,6 +34,63 @@ func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent)
|
||||
}
|
||||
}
|
||||
|
||||
// debugScanIntoStruct attempts to scan rows into a struct with detailed field-level logging
|
||||
// This helps identify which specific field is causing scanning issues
|
||||
func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct && v.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("dest must be pointer to struct or slice")
|
||||
}
|
||||
|
||||
// Log the type being scanned into
|
||||
typeName := v.Type().String()
|
||||
logger.Debug("Debug scan into type: %s (kind: %s)", typeName, v.Kind())
|
||||
|
||||
// Handle slice types - inspect the element type
|
||||
var structType reflect.Type
|
||||
if v.Kind() == reflect.Slice {
|
||||
elemType := v.Type().Elem()
|
||||
logger.Debug(" Slice element type: %s", elemType)
|
||||
|
||||
// If slice of pointers, get the underlying type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
structType = elemType.Elem()
|
||||
} else {
|
||||
structType = elemType
|
||||
}
|
||||
} else if v.Kind() == reflect.Struct {
|
||||
structType = v.Type()
|
||||
}
|
||||
|
||||
// If we have a struct type, log all its fields
|
||||
if structType != nil && structType.Kind() == reflect.Struct {
|
||||
logger.Debug(" Struct %s has %d fields:", structType.Name(), structType.NumField())
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
field := structType.Field(i)
|
||||
|
||||
// Log embedded fields specially
|
||||
if field.Anonymous {
|
||||
logger.Debug(" [%d] EMBEDDED: %s (type: %s, kind: %s, bun:%q)",
|
||||
i, field.Name, field.Type, field.Type.Kind(), field.Tag.Get("bun"))
|
||||
} else {
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag == "" {
|
||||
bunTag = "(no tag)"
|
||||
}
|
||||
logger.Debug(" [%d] %s (type: %s, kind: %s, bun:%q)",
|
||||
i, field.Name, field.Type, field.Type.Kind(), bunTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
@@ -52,6 +109,14 @@ func (b *BunAdapter) EnableQueryDebug() {
|
||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||
}
|
||||
|
||||
// EnableDetailedScanDebug enables verbose logging of scan operations
|
||||
// WARNING: This generates a LOT of log output. Use only for debugging specific issues.
|
||||
func (b *BunAdapter) EnableDetailedScanDebug() {
|
||||
logger.Info("Detailed scan debugging enabled - will log all field scanning operations")
|
||||
// This is a flag that can be checked in scan operations
|
||||
// Implementation would require modifying the scan logic
|
||||
}
|
||||
|
||||
// DisableQueryDebug removes all query hooks
|
||||
func (b *BunAdapter) DisableQueryDebug() {
|
||||
// Create a new DB without hooks
|
||||
@@ -131,6 +196,10 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
})
|
||||
}
|
||||
|
||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||
return b.db
|
||||
}
|
||||
|
||||
// BunSelectQuery implements SelectQuery for Bun
|
||||
type BunSelectQuery struct {
|
||||
query *bun.SelectQuery
|
||||
@@ -622,6 +691,11 @@ func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.OrderExpr(order, args...)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Limit(n int) common.SelectQuery {
|
||||
b.query = b.query.Limit(n)
|
||||
return b
|
||||
@@ -676,6 +750,31 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Enhanced panic recovery with model information
|
||||
model := b.query.GetModel()
|
||||
var modelInfo string
|
||||
if model != nil && model.Value() != nil {
|
||||
modelValue := model.Value()
|
||||
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
||||
|
||||
// Try to get the model's underlying struct type
|
||||
v := reflect.ValueOf(modelValue)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Slice {
|
||||
if v.Type().Elem().Kind() == reflect.Ptr {
|
||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
|
||||
} else {
|
||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
|
||||
}
|
||||
} else if v.Kind() == reflect.Struct {
|
||||
modelInfo += fmt.Sprintf(", Struct: %s", v.Type().Name())
|
||||
}
|
||||
}
|
||||
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
|
||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
@@ -683,6 +782,17 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
return fmt.Errorf("model is nil")
|
||||
}
|
||||
|
||||
// Optional: Enable detailed field-level debugging (set to true to debug)
|
||||
const enableDetailedDebug = true
|
||||
if enableDetailedDebug {
|
||||
model := b.query.GetModel()
|
||||
if model != nil && model.Value() != nil {
|
||||
if err := debugScanIntoStruct(nil, model.Value()); err != nil {
|
||||
logger.Warn("Debug scan inspection failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
@@ -1107,3 +1217,7 @@ func (b *BunTxAdapter) RollbackTx(ctx context.Context) error {
|
||||
func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||
return fn(b) // Already in transaction
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) GetUnderlyingDB() interface{} {
|
||||
return b.tx
|
||||
}
|
||||
|
||||
@@ -102,6 +102,10 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
})
|
||||
}
|
||||
|
||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
}
|
||||
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
@@ -382,6 +386,12 @@ func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
// GORM's Order can handle expressions directly
|
||||
g.db = g.db.Order(gorm.Expr(order, args...))
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Limit(n int) common.SelectQuery {
|
||||
g.db = g.db.Limit(n)
|
||||
return g
|
||||
|
||||
1370
pkg/common/adapters/database/pgsql.go
Normal file
1370
pkg/common/adapters/database/pgsql.go
Normal file
File diff suppressed because it is too large
Load Diff
176
pkg/common/adapters/database/pgsql_example.go
Normal file
176
pkg/common/adapters/database/pgsql_example.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Example demonstrates how to use the PgSQL adapter
|
||||
func ExamplePgSQLAdapter() error {
|
||||
// Connect to PostgreSQL database
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create the PgSQL adapter
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
|
||||
// Enable query debugging (optional)
|
||||
adapter.EnableQueryDebug()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Simple SELECT query
|
||||
var results []map[string]interface{}
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("age > ?", 18).
|
||||
Order("created_at DESC").
|
||||
Limit(10).
|
||||
Scan(ctx, &results)
|
||||
if err != nil {
|
||||
return fmt.Errorf("select failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 2: INSERT query
|
||||
result, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John Doe").
|
||||
Value("email", "john@example.com").
|
||||
Value("age", 25).
|
||||
Returning("id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Rows affected: %d\n", result.RowsAffected())
|
||||
|
||||
// Example 3: UPDATE query
|
||||
result, err = adapter.NewUpdate().
|
||||
Table("users").
|
||||
Set("name", "Jane Doe").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Rows updated: %d\n", result.RowsAffected())
|
||||
|
||||
// Example 4: DELETE query
|
||||
result, err = adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("age < ?", 18).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Rows deleted: %d\n", result.RowsAffected())
|
||||
|
||||
// Example 5: Using transactions
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Insert a new user
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Transaction User").
|
||||
Value("email", "tx@example.com").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update another user
|
||||
_, err = tx.NewUpdate().
|
||||
Table("users").
|
||||
Set("verified", true).
|
||||
Where("email = ?", "tx@example.com").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Both operations succeed or both rollback
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("transaction failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 6: JOIN query
|
||||
err = adapter.NewSelect().
|
||||
Table("users u").
|
||||
Column("u.id", "u.name", "p.title as post_title").
|
||||
LeftJoin("posts p ON p.user_id = u.id").
|
||||
Where("u.active = ?", true).
|
||||
Scan(ctx, &results)
|
||||
if err != nil {
|
||||
return fmt.Errorf("join query failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 7: Aggregation query
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("active = ?", true).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Active users: %d\n", count)
|
||||
|
||||
// Example 8: Raw SQL execution
|
||||
_, err = adapter.Exec(ctx, "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
|
||||
if err != nil {
|
||||
return fmt.Errorf("raw exec failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 9: Raw SQL query
|
||||
var users []map[string]interface{}
|
||||
err = adapter.Query(ctx, &users, "SELECT * FROM users WHERE age > $1 LIMIT $2", 18, 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("raw query failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// User is an example model
|
||||
type User struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
// TableName implements common.TableNameProvider
|
||||
func (u User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// ExampleWithModel demonstrates using models with the PgSQL adapter
|
||||
func ExampleWithModel() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Use model with adapter
|
||||
user := User{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&user).
|
||||
Where("id = ?", 1).
|
||||
Scan(ctx, &user)
|
||||
|
||||
return err
|
||||
}
|
||||
526
pkg/common/adapters/database/pgsql_integration_test.go
Normal file
526
pkg/common/adapters/database/pgsql_integration_test.go
Normal file
@@ -0,0 +1,526 @@
|
||||
// +build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
// Integration test models
|
||||
type IntegrationUser struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Email string `db:"email"`
|
||||
Age int `db:"age"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
Posts []*IntegrationPost `bun:"rel:has-many,join:id=user_id"`
|
||||
}
|
||||
|
||||
func (u IntegrationUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
type IntegrationPost struct {
|
||||
ID int `db:"id"`
|
||||
Title string `db:"title"`
|
||||
Content string `db:"content"`
|
||||
UserID int `db:"user_id"`
|
||||
Published bool `db:"published"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
User *IntegrationUser `bun:"rel:belongs-to,join:user_id=id"`
|
||||
Comments []*IntegrationComment `bun:"rel:has-many,join:id=post_id"`
|
||||
}
|
||||
|
||||
func (p IntegrationPost) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
type IntegrationComment struct {
|
||||
ID int `db:"id"`
|
||||
Content string `db:"content"`
|
||||
PostID int `db:"post_id"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
Post *IntegrationPost `bun:"rel:belongs-to,join:post_id=id"`
|
||||
}
|
||||
|
||||
func (c IntegrationComment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
|
||||
// setupTestDB creates a PostgreSQL container and returns the connection
|
||||
func setupTestDB(t *testing.T) (*sql.DB, func()) {
|
||||
ctx := context.Background()
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: "postgres:15-alpine",
|
||||
ExposedPorts: []string{"5432/tcp"},
|
||||
Env: map[string]string{
|
||||
"POSTGRES_USER": "testuser",
|
||||
"POSTGRES_PASSWORD": "testpass",
|
||||
"POSTGRES_DB": "testdb",
|
||||
},
|
||||
WaitingFor: wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(60 * time.Second),
|
||||
}
|
||||
|
||||
postgres, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
host, err := postgres.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
port, err := postgres.MappedPort(ctx, "5432")
|
||||
require.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("postgres://testuser:testpass@%s:%s/testdb?sslmode=disable",
|
||||
host, port.Port())
|
||||
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for database to be ready
|
||||
err = db.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create schema
|
||||
createSchema(t, db)
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
postgres.Terminate(ctx)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
// createSchema creates test tables
|
||||
func createSchema(t *testing.T, db *sql.DB) {
|
||||
schema := `
|
||||
DROP TABLE IF EXISTS comments CASCADE;
|
||||
DROP TABLE IF EXISTS posts CASCADE;
|
||||
DROP TABLE IF EXISTS users CASCADE;
|
||||
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
age INT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE posts (
|
||||
id SERIAL PRIMARY KEY,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
published BOOLEAN DEFAULT false,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE comments (
|
||||
id SERIAL PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
post_id INT NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`
|
||||
|
||||
_, err := db.Exec(schema)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestIntegration_BasicCRUD tests basic CRUD operations
|
||||
func TestIntegration_BasicCRUD(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// CREATE
|
||||
result, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John Doe").
|
||||
Value("email", "john@example.com").
|
||||
Value("age", 25).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// READ
|
||||
var users []IntegrationUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", "john@example.com").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, "John Doe", users[0].Name)
|
||||
assert.Equal(t, 25, users[0].Age)
|
||||
|
||||
userID := users[0].ID
|
||||
|
||||
// UPDATE
|
||||
result, err = adapter.NewUpdate().
|
||||
Table("users").
|
||||
Set("age", 26).
|
||||
Where("id = ?", userID).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify update
|
||||
var updatedUser IntegrationUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("id = ?", userID).
|
||||
Scan(ctx, &updatedUser)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 26, updatedUser.Age)
|
||||
|
||||
// DELETE
|
||||
result, err = adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("id = ?", userID).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify delete
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("id = ?", userID).
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
// TestIntegration_ScanModel tests ScanModel functionality
|
||||
func TestIntegration_ScanModel(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test data
|
||||
_, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Jane Smith").
|
||||
Value("email", "jane@example.com").
|
||||
Value("age", 30).
|
||||
Exec(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test single struct scan
|
||||
user := &IntegrationUser{}
|
||||
err = adapter.NewSelect().
|
||||
Model(user).
|
||||
Table("users").
|
||||
Where("email = ?", "jane@example.com").
|
||||
ScanModel(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Jane Smith", user.Name)
|
||||
assert.Equal(t, 30, user.Age)
|
||||
|
||||
// Test slice scan
|
||||
users := []*IntegrationUser{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&users).
|
||||
Table("users").
|
||||
ScanModel(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
}
|
||||
|
||||
// TestIntegration_Transaction tests transaction handling
|
||||
func TestIntegration_Transaction(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Successful transaction
|
||||
err := adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Alice").
|
||||
Value("email", "alice@example.com").
|
||||
Value("age", 28).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Bob").
|
||||
Value("email", "bob@example.com").
|
||||
Value("age", 32).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both records exist
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
// Failed transaction (should rollback)
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Charlie").
|
||||
Value("email", "charlie@example.com").
|
||||
Value("age", 35).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Intentional error - duplicate email
|
||||
_, err = tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "David").
|
||||
Value("email", "alice@example.com"). // Duplicate
|
||||
Value("age", 40).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
|
||||
// Verify rollback - count should still be 2
|
||||
count, err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
// TestIntegration_Preload tests basic preload functionality
|
||||
func TestIntegration_Preload(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID := createTestUser(t, adapter, ctx, "John Doe", "john@example.com", 25)
|
||||
createTestPost(t, adapter, ctx, userID, "First Post", "Content 1", true)
|
||||
createTestPost(t, adapter, ctx, userID, "Second Post", "Content 2", false)
|
||||
|
||||
// Test Preload
|
||||
var users []*IntegrationUser
|
||||
err := adapter.NewSelect().
|
||||
Model(&IntegrationUser{}).
|
||||
Table("users").
|
||||
Preload("Posts").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.NotNil(t, users[0].Posts)
|
||||
assert.Len(t, users[0].Posts, 2)
|
||||
}
|
||||
|
||||
// TestIntegration_PreloadRelation tests smart PreloadRelation
|
||||
func TestIntegration_PreloadRelation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID := createTestUser(t, adapter, ctx, "Jane Smith", "jane@example.com", 30)
|
||||
postID := createTestPost(t, adapter, ctx, userID, "Test Post", "Test Content", true)
|
||||
createTestComment(t, adapter, ctx, postID, "Great post!")
|
||||
createTestComment(t, adapter, ctx, postID, "Thanks for sharing!")
|
||||
|
||||
// Test PreloadRelation with belongs-to (should use JOIN)
|
||||
var posts []*IntegrationPost
|
||||
err := adapter.NewSelect().
|
||||
Model(&IntegrationPost{}).
|
||||
Table("posts").
|
||||
PreloadRelation("User").
|
||||
Scan(ctx, &posts)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, posts, 1)
|
||||
// Note: JOIN preloading needs proper column selection to work
|
||||
// For now, we test that it doesn't error
|
||||
|
||||
// Test PreloadRelation with has-many (should use subquery)
|
||||
posts = []*IntegrationPost{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&IntegrationPost{}).
|
||||
Table("posts").
|
||||
PreloadRelation("Comments").
|
||||
Scan(ctx, &posts)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, posts, 1)
|
||||
if posts[0].Comments != nil {
|
||||
assert.Len(t, posts[0].Comments, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_JoinRelation tests explicit JoinRelation
|
||||
func TestIntegration_JoinRelation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID := createTestUser(t, adapter, ctx, "Bob Wilson", "bob@example.com", 35)
|
||||
createTestPost(t, adapter, ctx, userID, "Join Test", "Content", true)
|
||||
|
||||
// Test JoinRelation
|
||||
var posts []*IntegrationPost
|
||||
err := adapter.NewSelect().
|
||||
Model(&IntegrationPost{}).
|
||||
Table("posts").
|
||||
JoinRelation("User").
|
||||
Scan(ctx, &posts)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, posts, 1)
|
||||
}
|
||||
|
||||
// TestIntegration_ComplexQuery tests complex queries
|
||||
func TestIntegration_ComplexQuery(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID1 := createTestUser(t, adapter, ctx, "Alice", "alice@example.com", 25)
|
||||
userID2 := createTestUser(t, adapter, ctx, "Bob", "bob@example.com", 30)
|
||||
userID3 := createTestUser(t, adapter, ctx, "Charlie", "charlie@example.com", 35)
|
||||
|
||||
createTestPost(t, adapter, ctx, userID1, "Post 1", "Content", true)
|
||||
createTestPost(t, adapter, ctx, userID2, "Post 2", "Content", true)
|
||||
createTestPost(t, adapter, ctx, userID3, "Post 3", "Content", false)
|
||||
|
||||
// Complex query with joins, where, order, limit
|
||||
var results []map[string]interface{}
|
||||
err := adapter.NewSelect().
|
||||
Table("posts p").
|
||||
Column("p.title", "u.name as author_name", "u.age as author_age").
|
||||
LeftJoin("users u ON u.id = p.user_id").
|
||||
Where("p.published = ?", true).
|
||||
WhereOr("u.age > ?", 25).
|
||||
Order("u.age DESC").
|
||||
Limit(2).
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 2)
|
||||
}
|
||||
|
||||
// TestIntegration_Aggregation tests aggregation queries
|
||||
func TestIntegration_Aggregation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
createTestUser(t, adapter, ctx, "User 1", "user1@example.com", 20)
|
||||
createTestUser(t, adapter, ctx, "User 2", "user2@example.com", 25)
|
||||
createTestUser(t, adapter, ctx, "User 3", "user3@example.com", 30)
|
||||
|
||||
// Test Count
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("age >= ?", 25).
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
// Test Exists
|
||||
exists, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", "user1@example.com").
|
||||
Exists(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Test Group By with aggregation
|
||||
var results []map[string]interface{}
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Column("age", "COUNT(*) as count").
|
||||
Group("age").
|
||||
Having("COUNT(*) > ?", 0).
|
||||
Order("age ASC").
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 3)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestUser(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, name, email string, age int) int {
|
||||
var userID int
|
||||
err := adapter.Query(ctx, &userID,
|
||||
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3) RETURNING id",
|
||||
name, email, age)
|
||||
require.NoError(t, err)
|
||||
return userID
|
||||
}
|
||||
|
||||
func createTestPost(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, userID int, title, content string, published bool) int {
|
||||
var postID int
|
||||
err := adapter.Query(ctx, &postID,
|
||||
"INSERT INTO posts (title, content, user_id, published) VALUES ($1, $2, $3, $4) RETURNING id",
|
||||
title, content, userID, published)
|
||||
require.NoError(t, err)
|
||||
return postID
|
||||
}
|
||||
|
||||
func createTestComment(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, postID int, content string) int {
|
||||
var commentID int
|
||||
err := adapter.Query(ctx, &commentID,
|
||||
"INSERT INTO comments (content, post_id) VALUES ($1, $2) RETURNING id",
|
||||
content, postID)
|
||||
require.NoError(t, err)
|
||||
return commentID
|
||||
}
|
||||
275
pkg/common/adapters/database/pgsql_preload_example.go
Normal file
275
pkg/common/adapters/database/pgsql_preload_example.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Example models for demonstrating preload functionality
|
||||
|
||||
// Author model - has many Posts
|
||||
type Author struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Email string `db:"email"`
|
||||
Posts []*Post `bun:"rel:has-many,join:id=author_id"`
|
||||
}
|
||||
|
||||
func (a Author) TableName() string {
|
||||
return "authors"
|
||||
}
|
||||
|
||||
// Post model - belongs to Author, has many Comments
|
||||
type Post struct {
|
||||
ID int `db:"id"`
|
||||
Title string `db:"title"`
|
||||
Content string `db:"content"`
|
||||
AuthorID int `db:"author_id"`
|
||||
Author *Author `bun:"rel:belongs-to,join:author_id=id"`
|
||||
Comments []*Comment `bun:"rel:has-many,join:id=post_id"`
|
||||
}
|
||||
|
||||
func (p Post) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
// Comment model - belongs to Post
|
||||
type Comment struct {
|
||||
ID int `db:"id"`
|
||||
Content string `db:"content"`
|
||||
PostID int `db:"post_id"`
|
||||
Post *Post `bun:"rel:belongs-to,join:post_id=id"`
|
||||
}
|
||||
|
||||
func (c Comment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
|
||||
// ExamplePreload demonstrates the Preload functionality
|
||||
func ExamplePreload() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Simple Preload (uses subquery for has-many)
|
||||
var authors []*Author
|
||||
err = adapter.NewSelect().
|
||||
Model(&Author{}).
|
||||
Table("authors").
|
||||
Preload("Posts"). // Load all posts for each author
|
||||
Scan(ctx, &authors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now authors[i].Posts will be populated with their posts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExamplePreloadRelation demonstrates smart PreloadRelation with auto-detection
|
||||
func ExamplePreloadRelation() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: PreloadRelation auto-detects has-many (uses subquery)
|
||||
var authors []*Author
|
||||
err = adapter.NewSelect().
|
||||
Model(&Author{}).
|
||||
Table("authors").
|
||||
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("published = ?", true).Order("created_at DESC")
|
||||
}).
|
||||
Where("active = ?", true).
|
||||
Scan(ctx, &authors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 2: PreloadRelation auto-detects belongs-to (uses JOIN)
|
||||
var posts []*Post
|
||||
err = adapter.NewSelect().
|
||||
Model(&Post{}).
|
||||
Table("posts").
|
||||
PreloadRelation("Author"). // Will use JOIN because it's belongs-to
|
||||
Scan(ctx, &posts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 3: Nested preloads
|
||||
err = adapter.NewSelect().
|
||||
Model(&Author{}).
|
||||
Table("authors").
|
||||
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
|
||||
// First load posts, then preload comments for each post
|
||||
return q.Limit(10)
|
||||
}).
|
||||
Scan(ctx, &authors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Manually load nested relationships (two-level preloading)
|
||||
for _, author := range authors {
|
||||
if author.Posts != nil {
|
||||
for _, post := range author.Posts {
|
||||
var comments []*Comment
|
||||
err := adapter.NewSelect().
|
||||
Table("comments").
|
||||
Where("post_id = ?", post.ID).
|
||||
Scan(ctx, &comments)
|
||||
if err == nil {
|
||||
post.Comments = comments
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleJoinRelation demonstrates explicit JOIN loading
|
||||
func ExampleJoinRelation() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Force JOIN for belongs-to relationship
|
||||
var posts []*Post
|
||||
err = adapter.NewSelect().
|
||||
Model(&Post{}).
|
||||
Table("posts").
|
||||
JoinRelation("Author", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("active = ?", true)
|
||||
}).
|
||||
Scan(ctx, &posts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 2: Multiple JOINs
|
||||
err = adapter.NewSelect().
|
||||
Model(&Post{}).
|
||||
Table("posts p").
|
||||
Column("p.*", "a.name as author_name", "a.email as author_email").
|
||||
LeftJoin("authors a ON a.id = p.author_id").
|
||||
Where("p.published = ?", true).
|
||||
Scan(ctx, &posts)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ExampleScanModel demonstrates ScanModel with struct destinations
|
||||
func ExampleScanModel() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Scan single struct
|
||||
author := Author{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&author).
|
||||
Table("authors").
|
||||
Where("id = ?", 1).
|
||||
ScanModel(ctx) // ScanModel automatically uses the model set with Model()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 2: Scan slice of structs
|
||||
authors := []*Author{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&authors).
|
||||
Table("authors").
|
||||
Where("active = ?", true).
|
||||
Limit(10).
|
||||
ScanModel(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ExampleCompleteWorkflow demonstrates a complete workflow with preloading
|
||||
func ExampleCompleteWorkflow() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
adapter.EnableQueryDebug() // Enable query logging
|
||||
ctx := context.Background()
|
||||
|
||||
// Step 1: Create an author
|
||||
author := &Author{
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
}
|
||||
|
||||
result, err := adapter.NewInsert().
|
||||
Table("authors").
|
||||
Value("name", author.Name).
|
||||
Value("email", author.Email).
|
||||
Returning("id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = result
|
||||
|
||||
// Step 2: Load author with all their posts
|
||||
var loadedAuthor Author
|
||||
err = adapter.NewSelect().
|
||||
Model(&loadedAuthor).
|
||||
Table("authors").
|
||||
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Order("created_at DESC").Limit(5)
|
||||
}).
|
||||
Where("id = ?", 1).
|
||||
ScanModel(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 3: Update author name
|
||||
_, err = adapter.NewUpdate().
|
||||
Table("authors").
|
||||
Set("name", "Jane Doe").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
629
pkg/common/adapters/database/pgsql_test.go
Normal file
629
pkg/common/adapters/database/pgsql_test.go
Normal file
@@ -0,0 +1,629 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Test models
|
||||
type TestUser struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Email string `db:"email"`
|
||||
Age int `db:"age"`
|
||||
}
|
||||
|
||||
func (u TestUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
type TestPost struct {
|
||||
ID int `db:"id"`
|
||||
Title string `db:"title"`
|
||||
Content string `db:"content"`
|
||||
UserID int `db:"user_id"`
|
||||
User *TestUser `bun:"rel:belongs-to,join:user_id=id"`
|
||||
Comments []TestComment `bun:"rel:has-many,join:id=post_id"`
|
||||
}
|
||||
|
||||
func (p TestPost) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
type TestComment struct {
|
||||
ID int `db:"id"`
|
||||
Content string `db:"content"`
|
||||
PostID int `db:"post_id"`
|
||||
}
|
||||
|
||||
func (c TestComment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
|
||||
// TestNewPgSQLAdapter tests adapter creation
|
||||
func TestNewPgSQLAdapter(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Equal(t, db, adapter.db)
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_BuildSQL tests SQL query building
|
||||
func TestPgSQLSelectQuery_BuildSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*PgSQLSelectQuery)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple select",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
},
|
||||
expected: "SELECT * FROM users",
|
||||
},
|
||||
{
|
||||
name: "select with columns",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.columns = []string{"id", "name", "email"}
|
||||
},
|
||||
expected: "SELECT id, name, email FROM users",
|
||||
},
|
||||
{
|
||||
name: "select with where",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.whereClauses = []string{"age > $1"}
|
||||
q.args = []interface{}{18}
|
||||
},
|
||||
expected: "SELECT * FROM users WHERE (age > $1)",
|
||||
},
|
||||
{
|
||||
name: "select with order and limit",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.orderBy = []string{"created_at DESC"}
|
||||
q.limit = 10
|
||||
q.offset = 5
|
||||
},
|
||||
expected: "SELECT * FROM users ORDER BY created_at DESC LIMIT 10 OFFSET 5",
|
||||
},
|
||||
{
|
||||
name: "select with join",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.joins = []string{"LEFT JOIN posts ON posts.user_id = users.id"}
|
||||
},
|
||||
expected: "SELECT * FROM users LEFT JOIN posts ON posts.user_id = users.id",
|
||||
},
|
||||
{
|
||||
name: "select with group and having",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.groupBy = []string{"country"}
|
||||
q.havingClauses = []string{"COUNT(*) > $1"}
|
||||
q.args = []interface{}{5}
|
||||
},
|
||||
expected: "SELECT * FROM users GROUP BY country HAVING COUNT(*) > $1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := &PgSQLSelectQuery{
|
||||
columns: []string{"*"},
|
||||
}
|
||||
tt.setup(q)
|
||||
sql := q.buildSQL()
|
||||
assert.Equal(t, tt.expected, sql)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_ReplacePlaceholders tests placeholder replacement
|
||||
func TestPgSQLSelectQuery_ReplacePlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
argCount int
|
||||
paramCounter int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single placeholder",
|
||||
query: "age > ?",
|
||||
argCount: 1,
|
||||
paramCounter: 0,
|
||||
expected: "age > $1",
|
||||
},
|
||||
{
|
||||
name: "multiple placeholders",
|
||||
query: "age > ? AND status = ?",
|
||||
argCount: 2,
|
||||
paramCounter: 0,
|
||||
expected: "age > $1 AND status = $2",
|
||||
},
|
||||
{
|
||||
name: "with existing counter",
|
||||
query: "name = ?",
|
||||
argCount: 1,
|
||||
paramCounter: 5,
|
||||
expected: "name = $6",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := &PgSQLSelectQuery{paramCounter: tt.paramCounter}
|
||||
result := q.replacePlaceholders(tt.query, tt.argCount)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Chaining tests method chaining
|
||||
func TestPgSQLSelectQuery_Chaining(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
query := adapter.NewSelect().
|
||||
Table("users").
|
||||
Column("id", "name").
|
||||
Where("age > ?", 18).
|
||||
Order("name ASC").
|
||||
Limit(10).
|
||||
Offset(5)
|
||||
|
||||
pgQuery := query.(*PgSQLSelectQuery)
|
||||
assert.Equal(t, "users", pgQuery.tableName)
|
||||
assert.Equal(t, []string{"id", "name"}, pgQuery.columns)
|
||||
assert.Len(t, pgQuery.whereClauses, 1)
|
||||
assert.Equal(t, []string{"name ASC"}, pgQuery.orderBy)
|
||||
assert.Equal(t, 10, pgQuery.limit)
|
||||
assert.Equal(t, 5, pgQuery.offset)
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Model tests model setting
|
||||
func TestPgSQLSelectQuery_Model(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
user := &TestUser{}
|
||||
query := adapter.NewSelect().Model(user)
|
||||
|
||||
pgQuery := query.(*PgSQLSelectQuery)
|
||||
assert.Equal(t, "users", pgQuery.tableName)
|
||||
assert.Equal(t, user, pgQuery.model)
|
||||
}
|
||||
|
||||
// TestScanRowsToStructSlice tests scanning rows into struct slice
|
||||
func TestScanRowsToStructSlice(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25).
|
||||
AddRow(2, "Jane Smith", "jane@example.com", 30)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var users []TestUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 2)
|
||||
assert.Equal(t, "John Doe", users[0].Name)
|
||||
assert.Equal(t, "jane@example.com", users[1].Email)
|
||||
assert.Equal(t, 30, users[1].Age)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestScanRowsToStructSlicePointers tests scanning rows into pointer slice
|
||||
func TestScanRowsToStructSlicePointers(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var users []*TestUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.NotNil(t, users[0])
|
||||
assert.Equal(t, "John Doe", users[0].Name)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestScanRowsToSingleStruct tests scanning a single row
|
||||
func TestScanRowsToSingleStruct(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var user TestUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
Scan(ctx, &user)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, user.ID)
|
||||
assert.Equal(t, "John Doe", user.Name)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestScanRowsToMapSlice tests scanning into map slice
|
||||
func TestScanRowsToMapSlice(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email"}).
|
||||
AddRow(1, "John Doe", "john@example.com").
|
||||
AddRow(2, "Jane Smith", "jane@example.com")
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var results []map[string]interface{}
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
assert.Equal(t, int64(1), results[0]["id"])
|
||||
assert.Equal(t, "John Doe", results[0]["name"])
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLInsertQuery_Exec tests insert query execution
|
||||
func TestPgSQLInsertQuery_Exec(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectExec("INSERT INTO users").
|
||||
WithArgs("John Doe", "john@example.com", 25).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John Doe").
|
||||
Value("email", "john@example.com").
|
||||
Value("age", 25).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLUpdateQuery_Exec tests update query execution
|
||||
func TestPgSQLUpdateQuery_Exec(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Note: Args order is SET values first, then WHERE values
|
||||
mock.ExpectExec("UPDATE users SET name = \\$1 WHERE id = \\$2").
|
||||
WithArgs("Jane Doe", 1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := adapter.NewUpdate().
|
||||
Table("users").
|
||||
Set("name", "Jane Doe").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLDeleteQuery_Exec tests delete query execution
|
||||
func TestPgSQLDeleteQuery_Exec(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectExec("DELETE FROM users WHERE id = \\$1").
|
||||
WithArgs(1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Count tests count query
|
||||
func TestPgSQLSelectQuery_Count(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"count"}).AddRow(42)
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 42, count)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Exists tests exists query
|
||||
func TestPgSQLSelectQuery_Exists(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"count"}).AddRow(1)
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
exists, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", "john@example.com").
|
||||
Exists(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLAdapter_Transaction tests transaction handling
|
||||
func TestPgSQLAdapter_Transaction(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT INTO users").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John").
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLAdapter_TransactionRollback tests transaction rollback
|
||||
func TestPgSQLAdapter_TransactionRollback(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT INTO users").WillReturnError(sql.ErrConnDone)
|
||||
mock.ExpectRollback()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John").
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestBuildFieldMap tests field mapping construction
|
||||
func TestBuildFieldMap(t *testing.T) {
|
||||
userType := reflect.TypeOf(TestUser{})
|
||||
fieldMap := buildFieldMap(userType, nil)
|
||||
|
||||
assert.NotEmpty(t, fieldMap)
|
||||
|
||||
// Check that fields are mapped
|
||||
assert.Contains(t, fieldMap, "id")
|
||||
assert.Contains(t, fieldMap, "name")
|
||||
assert.Contains(t, fieldMap, "email")
|
||||
assert.Contains(t, fieldMap, "age")
|
||||
|
||||
// Check field info
|
||||
idInfo := fieldMap["id"]
|
||||
assert.Equal(t, "ID", idInfo.Name)
|
||||
}
|
||||
|
||||
// TestGetRelationMetadata tests relationship metadata extraction
|
||||
func TestGetRelationMetadata(t *testing.T) {
|
||||
q := &PgSQLSelectQuery{
|
||||
model: &TestPost{},
|
||||
}
|
||||
|
||||
// Test belongs-to relationship
|
||||
meta := q.getRelationMetadata("User")
|
||||
assert.NotNil(t, meta)
|
||||
assert.Equal(t, "User", meta.fieldName)
|
||||
|
||||
// Test has-many relationship
|
||||
meta = q.getRelationMetadata("Comments")
|
||||
assert.NotNil(t, meta)
|
||||
assert.Equal(t, "Comments", meta.fieldName)
|
||||
}
|
||||
|
||||
// TestPreloadConfiguration tests preload configuration
|
||||
func TestPreloadConfiguration(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
|
||||
// Test Preload
|
||||
query := adapter.NewSelect().
|
||||
Model(&TestPost{}).
|
||||
Table("posts").
|
||||
Preload("User")
|
||||
|
||||
pgQuery := query.(*PgSQLSelectQuery)
|
||||
assert.Len(t, pgQuery.preloads, 1)
|
||||
assert.Equal(t, "User", pgQuery.preloads[0].relation)
|
||||
assert.False(t, pgQuery.preloads[0].useJoin)
|
||||
|
||||
// Test PreloadRelation
|
||||
query = adapter.NewSelect().
|
||||
Model(&TestPost{}).
|
||||
Table("posts").
|
||||
PreloadRelation("Comments")
|
||||
|
||||
pgQuery = query.(*PgSQLSelectQuery)
|
||||
assert.Len(t, pgQuery.preloads, 1)
|
||||
assert.Equal(t, "Comments", pgQuery.preloads[0].relation)
|
||||
|
||||
// Test JoinRelation
|
||||
query = adapter.NewSelect().
|
||||
Model(&TestPost{}).
|
||||
Table("posts").
|
||||
JoinRelation("User")
|
||||
|
||||
pgQuery = query.(*PgSQLSelectQuery)
|
||||
assert.Len(t, pgQuery.preloads, 1)
|
||||
assert.Equal(t, "User", pgQuery.preloads[0].relation)
|
||||
assert.True(t, pgQuery.preloads[0].useJoin)
|
||||
}
|
||||
|
||||
// TestScanModel tests ScanModel functionality
|
||||
func TestScanModel(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &TestUser{}
|
||||
err = adapter.NewSelect().
|
||||
Model(user).
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
ScanModel(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, user.ID)
|
||||
assert.Equal(t, "John Doe", user.Name)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestRawSQL tests raw SQL execution
|
||||
func TestRawSQL(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Test Exec
|
||||
mock.ExpectExec("CREATE TABLE test").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = adapter.Exec(ctx, "CREATE TABLE test (id INT)")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test Query
|
||||
rows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Test")
|
||||
mock.ExpectQuery("SELECT (.+) FROM test").WillReturnRows(rows)
|
||||
|
||||
var results []map[string]interface{}
|
||||
err = adapter.Query(ctx, &results, "SELECT * FROM test WHERE id = $1", 1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
132
pkg/common/adapters/database/test_helpers.go
Normal file
132
pkg/common/adapters/database/test_helpers.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHelper provides utilities for database testing
|
||||
type TestHelper struct {
|
||||
DB *sql.DB
|
||||
Adapter *PgSQLAdapter
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
// NewTestHelper creates a new test helper
|
||||
func NewTestHelper(t *testing.T, db *sql.DB) *TestHelper {
|
||||
return &TestHelper{
|
||||
DB: db,
|
||||
Adapter: NewPgSQLAdapter(db),
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupTables truncates all test tables
|
||||
func (h *TestHelper) CleanupTables() {
|
||||
ctx := context.Background()
|
||||
tables := []string{"comments", "posts", "users"}
|
||||
|
||||
for _, table := range tables {
|
||||
_, err := h.DB.ExecContext(ctx, "TRUNCATE TABLE "+table+" CASCADE")
|
||||
require.NoError(h.t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InsertUser inserts a test user and returns the ID
|
||||
func (h *TestHelper) InsertUser(name, email string, age int) int {
|
||||
ctx := context.Background()
|
||||
result, err := h.Adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", name).
|
||||
Value("email", email).
|
||||
Value("age", age).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
id, _ := result.LastInsertId()
|
||||
return int(id)
|
||||
}
|
||||
|
||||
// InsertPost inserts a test post and returns the ID
|
||||
func (h *TestHelper) InsertPost(userID int, title, content string, published bool) int {
|
||||
ctx := context.Background()
|
||||
result, err := h.Adapter.NewInsert().
|
||||
Table("posts").
|
||||
Value("user_id", userID).
|
||||
Value("title", title).
|
||||
Value("content", content).
|
||||
Value("published", published).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
id, _ := result.LastInsertId()
|
||||
return int(id)
|
||||
}
|
||||
|
||||
// InsertComment inserts a test comment and returns the ID
|
||||
func (h *TestHelper) InsertComment(postID int, content string) int {
|
||||
ctx := context.Background()
|
||||
result, err := h.Adapter.NewInsert().
|
||||
Table("comments").
|
||||
Value("post_id", postID).
|
||||
Value("content", content).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
id, _ := result.LastInsertId()
|
||||
return int(id)
|
||||
}
|
||||
|
||||
// AssertUserExists checks if a user exists by email
|
||||
func (h *TestHelper) AssertUserExists(email string) {
|
||||
ctx := context.Background()
|
||||
exists, err := h.Adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", email).
|
||||
Exists(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
require.True(h.t, exists, "User with email %s should exist", email)
|
||||
}
|
||||
|
||||
// AssertUserCount asserts the number of users
|
||||
func (h *TestHelper) AssertUserCount(expected int) {
|
||||
ctx := context.Background()
|
||||
count, err := h.Adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
require.Equal(h.t, expected, count)
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by email
|
||||
func (h *TestHelper) GetUserByEmail(email string) map[string]interface{} {
|
||||
ctx := context.Background()
|
||||
var results []map[string]interface{}
|
||||
err := h.Adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", email).
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, results, 1, "Expected exactly one user with email %s", email)
|
||||
return results[0]
|
||||
}
|
||||
|
||||
// BeginTestTransaction starts a transaction for testing
|
||||
func (h *TestHelper) BeginTestTransaction() (*PgSQLTxAdapter, func()) {
|
||||
ctx := context.Background()
|
||||
tx, err := h.DB.BeginTx(ctx, nil)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
adapter := &PgSQLTxAdapter{tx: tx}
|
||||
cleanup := func() {
|
||||
tx.Rollback()
|
||||
}
|
||||
|
||||
return adapter, cleanup
|
||||
}
|
||||
@@ -1,7 +1,16 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/uptrace/bun/dialect/mssqldialect"
|
||||
"github.com/uptrace/bun/dialect/pgdialect"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
@@ -14,3 +23,39 @@ func parseTableName(fullTableName string) (schema, table string) {
|
||||
}
|
||||
return "", fullTableName
|
||||
}
|
||||
|
||||
// GetPostgresDialect returns a Bun PostgreSQL dialect
|
||||
func GetPostgresDialect() *pgdialect.Dialect {
|
||||
return pgdialect.New()
|
||||
}
|
||||
|
||||
// GetSQLiteDialect returns a Bun SQLite dialect
|
||||
func GetSQLiteDialect() *sqlitedialect.Dialect {
|
||||
return sqlitedialect.New()
|
||||
}
|
||||
|
||||
// GetMSSQLDialect returns a Bun MSSQL dialect
|
||||
func GetMSSQLDialect() *mssqldialect.Dialect {
|
||||
return mssqldialect.New()
|
||||
}
|
||||
|
||||
// GetPostgresDialector returns a GORM PostgreSQL dialector
|
||||
func GetPostgresDialector(db *sql.DB) gorm.Dialector {
|
||||
return postgres.New(postgres.Config{
|
||||
Conn: db,
|
||||
})
|
||||
}
|
||||
|
||||
// GetSQLiteDialector returns a GORM SQLite dialector
|
||||
func GetSQLiteDialector(db *sql.DB) gorm.Dialector {
|
||||
return sqlite.Dialector{
|
||||
Conn: db,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMSSQLDialector returns a GORM MSSQL dialector
|
||||
func GetMSSQLDialector(db *sql.DB) gorm.Dialector {
|
||||
return sqlserver.New(sqlserver.Config{
|
||||
Conn: db,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,6 +24,12 @@ type Database interface {
|
||||
CommitTx(ctx context.Context) error
|
||||
RollbackTx(ctx context.Context) error
|
||||
RunInTransaction(ctx context.Context, fn func(Database) error) error
|
||||
|
||||
// GetUnderlyingDB returns the underlying database connection
|
||||
// For GORM, this returns *gorm.DB
|
||||
// For Bun, this returns *bun.DB
|
||||
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
|
||||
GetUnderlyingDB() interface{}
|
||||
}
|
||||
|
||||
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
||||
@@ -40,6 +46,7 @@ type SelectQuery interface {
|
||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
Order(order string) SelectQuery
|
||||
OrderExpr(order string, args ...interface{}) SelectQuery
|
||||
Limit(n int) SelectQuery
|
||||
Offset(n int) SelectQuery
|
||||
Group(group string) SelectQuery
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -78,23 +80,69 @@ func IsTrivialCondition(cond string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
||||
// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
|
||||
// Returns an error if any dangerous keywords are found
|
||||
func validateWhereClauseSecurity(where string) error {
|
||||
if where == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
lowerWhere := strings.ToLower(where)
|
||||
|
||||
// List of dangerous SQL keywords that should never appear in WHERE clauses
|
||||
dangerousKeywords := []string{
|
||||
"delete ", "delete\t", "delete\n", "delete;",
|
||||
"update ", "update\t", "update\n", "update;",
|
||||
"truncate ", "truncate\t", "truncate\n", "truncate;",
|
||||
"drop ", "drop\t", "drop\n", "drop;",
|
||||
"alter ", "alter\t", "alter\n", "alter;",
|
||||
"create ", "create\t", "create\n", "create;",
|
||||
"insert ", "insert\t", "insert\n", "insert;",
|
||||
"grant ", "grant\t", "grant\n", "grant;",
|
||||
"revoke ", "revoke\t", "revoke\n", "revoke;",
|
||||
"exec ", "exec\t", "exec\n", "exec;",
|
||||
"execute ", "execute\t", "execute\n", "execute;",
|
||||
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
|
||||
}
|
||||
|
||||
for _, keyword := range dangerousKeywords {
|
||||
if strings.Contains(lowerWhere, keyword) {
|
||||
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
|
||||
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause string to sanitize
|
||||
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
||||
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
|
||||
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
|
||||
//
|
||||
// Returns:
|
||||
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
||||
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
|
||||
// - An empty string if all conditions were trivial or the input was empty
|
||||
func SanitizeWhereClause(where string, tableName string) string {
|
||||
//
|
||||
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
||||
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
||||
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
||||
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
||||
if where == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Validate that the WHERE clause doesn't contain dangerous SQL statements
|
||||
if err := validateWhereClauseSecurity(where); err != nil {
|
||||
logger.Debug("Security validation failed for WHERE clause: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip outer parentheses and re-trim
|
||||
where = stripOuterParentheses(where)
|
||||
|
||||
@@ -104,6 +152,22 @@ func SanitizeWhereClause(where string, tableName string) string {
|
||||
validColumns = getValidColumnsForTable(tableName)
|
||||
}
|
||||
|
||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||
allowedPrefixes := make(map[string]bool)
|
||||
if tableName != "" {
|
||||
allowedPrefixes[tableName] = true
|
||||
}
|
||||
|
||||
// Add preload relation names as allowed prefixes
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
for pi := range options[0].Preload {
|
||||
if options[0].Preload[pi].Relation != "" {
|
||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split by AND to handle multiple conditions
|
||||
conditions := splitByAND(where)
|
||||
|
||||
@@ -124,26 +188,29 @@ func SanitizeWhereClause(where string, tableName string) string {
|
||||
continue
|
||||
}
|
||||
|
||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||
// attempt to add it
|
||||
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||
// Extract the column name and prefix it
|
||||
columnName := ExtractColumnName(condToCheck)
|
||||
if columnName != "" {
|
||||
// Only prefix if this is a valid column in the model
|
||||
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||
// If tableName is provided and the condition HAS a table prefix, check if it's correct
|
||||
if tableName != "" && hasTablePrefix(condToCheck) {
|
||||
// Extract the current prefix and column name
|
||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||
|
||||
if currentPrefix != "" && columnName != "" {
|
||||
// Check if the prefix is allowed (main table or preload relation)
|
||||
if !allowedPrefixes[currentPrefix] {
|
||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace in the original condition (without stripped parens)
|
||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||
// Replace the incorrect prefix with the correct main table name
|
||||
oldRef := currentPrefix + "." + columnName
|
||||
newRef := tableName + "." + columnName
|
||||
cond = strings.Replace(cond, oldRef, newRef, 1)
|
||||
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
|
||||
} else {
|
||||
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Note: We no longer add prefixes to unqualified columns here.
|
||||
// Use AddTablePrefixToColumns() separately if you need to add prefixes.
|
||||
|
||||
validConditions = append(validConditions, cond)
|
||||
}
|
||||
@@ -167,51 +234,106 @@ func stripOuterParentheses(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
for {
|
||||
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||
stripped, wasStripped := stripOneMatchingOuterParen(s)
|
||||
if !wasStripped {
|
||||
return s
|
||||
}
|
||||
s = stripped
|
||||
}
|
||||
}
|
||||
|
||||
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||
depth := 0
|
||||
matched := false
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
if depth == 0 && i == len(s)-1 {
|
||||
matched = true
|
||||
} else if depth == 0 {
|
||||
// Found a closing paren before the end, so outer parens don't match
|
||||
return s
|
||||
// stripOneOuterParentheses removes only one level of matching outer parentheses from a string
|
||||
// Unlike stripOuterParentheses, this only strips once, preserving nested parentheses
|
||||
func stripOneOuterParentheses(s string) string {
|
||||
stripped, _ := stripOneMatchingOuterParen(strings.TrimSpace(s))
|
||||
return stripped
|
||||
}
|
||||
|
||||
// stripOneMatchingOuterParen is a helper that strips one matching pair of outer parentheses
|
||||
// Returns the stripped string and a boolean indicating if stripping occurred
|
||||
func stripOneMatchingOuterParen(s string) (string, bool) {
|
||||
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||
return s, false
|
||||
}
|
||||
|
||||
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||
depth := 0
|
||||
matched := false
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
if depth == 0 && i == len(s)-1 {
|
||||
matched = true
|
||||
} else if depth == 0 {
|
||||
// Found a closing paren before the end, so outer parens don't match
|
||||
return s, false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
return s, false
|
||||
}
|
||||
|
||||
// Strip the outer parentheses
|
||||
return strings.TrimSpace(s[1 : len(s)-1]), true
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
||||
func splitByAND(where string) []string {
|
||||
conditions := []string{}
|
||||
currentCondition := strings.Builder{}
|
||||
depth := 0 // Track parenthesis depth
|
||||
i := 0
|
||||
|
||||
for i < len(where) {
|
||||
ch := where[i]
|
||||
|
||||
// Track parenthesis depth
|
||||
if ch == '(' {
|
||||
depth++
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
} else if ch == ')' {
|
||||
depth--
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Only look for AND operators at depth 0 (not inside parentheses)
|
||||
if depth == 0 {
|
||||
// Check if we're at an AND operator (case-insensitive)
|
||||
// We need at least " AND " (5 chars) or " and " (5 chars)
|
||||
if i+5 <= len(where) {
|
||||
substring := where[i : i+5]
|
||||
lowerSubstring := strings.ToLower(substring)
|
||||
|
||||
if lowerSubstring == " and " {
|
||||
// Found an AND operator at the top level
|
||||
// Add the current condition to the list
|
||||
conditions = append(conditions, currentCondition.String())
|
||||
currentCondition.Reset()
|
||||
// Skip past the AND operator
|
||||
i += 5
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
return s
|
||||
}
|
||||
|
||||
// Strip the outer parentheses and continue
|
||||
s = strings.TrimSpace(s[1 : len(s)-1])
|
||||
}
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||
func splitByAND(where string) []string {
|
||||
// First try uppercase AND
|
||||
conditions := strings.Split(where, " AND ")
|
||||
|
||||
// If we didn't split on uppercase, try lowercase
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " and ")
|
||||
// Not an AND operator or we're inside parentheses, just add the character
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
// If we still didn't split, try mixed case
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " And ")
|
||||
// Add the last condition
|
||||
if currentCondition.Len() > 0 {
|
||||
conditions = append(conditions, currentCondition.String())
|
||||
}
|
||||
|
||||
return conditions
|
||||
@@ -288,6 +410,227 @@ func getValidColumnsForTable(tableName string) map[string]bool {
|
||||
return columnMap
|
||||
}
|
||||
|
||||
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
|
||||
// For example: "users.status = 'active'" returns ("users", "status")
|
||||
// Returns empty strings if no table prefix is found
|
||||
// This function is parenthesis-aware and will only look for operators outside of subqueries
|
||||
func extractTableAndColumn(cond string) (table string, column string) {
|
||||
// Common SQL operators to find the column reference
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||
|
||||
var columnRef string
|
||||
|
||||
// Find the column reference (left side of the operator)
|
||||
// We need to find the first operator that appears OUTSIDE of parentheses
|
||||
minIdx := -1
|
||||
|
||||
for _, op := range operators {
|
||||
idx := findOperatorOutsideParentheses(cond, op)
|
||||
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||
minIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if minIdx > 0 {
|
||||
columnRef = strings.TrimSpace(cond[:minIdx])
|
||||
}
|
||||
|
||||
// If no operator found, the whole condition might be the column reference
|
||||
if columnRef == "" {
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef = parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
if columnRef == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// Check if there's a function call (contains opening parenthesis)
|
||||
openParenIdx := strings.Index(columnRef, "(")
|
||||
|
||||
if openParenIdx >= 0 {
|
||||
// There's a function call - find the FIRST dot after the opening paren
|
||||
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
|
||||
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
|
||||
if dotIdx > 0 {
|
||||
dotIdx += openParenIdx // Adjust to absolute position
|
||||
|
||||
// Extract table name (between paren and dot)
|
||||
// Find the last opening paren before this dot
|
||||
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
|
||||
table = columnRef[lastOpenParen+1 : dotIdx]
|
||||
|
||||
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
|
||||
columnStart := dotIdx + 1
|
||||
columnEnd := len(columnRef)
|
||||
|
||||
for i := columnStart; i < len(columnRef); i++ {
|
||||
ch := columnRef[i]
|
||||
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
|
||||
columnEnd = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
column = columnRef[columnStart:columnEnd]
|
||||
|
||||
// Remove quotes from table and column if present
|
||||
table = strings.Trim(table, "`\"'")
|
||||
column = strings.Trim(column, "`\"'")
|
||||
|
||||
return table, column
|
||||
}
|
||||
}
|
||||
|
||||
// No function call - check if it contains a dot (qualified reference)
|
||||
// Use LastIndex to handle schema.table.column properly
|
||||
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
||||
table = columnRef[:dotIdx]
|
||||
column = columnRef[dotIdx+1:]
|
||||
|
||||
// Remove quotes from table and column if present
|
||||
table = strings.Trim(table, "`\"'")
|
||||
column = strings.Trim(column, "`\"'")
|
||||
|
||||
return table, column
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Unused: extractUnqualifiedColumnName extracts the column name from an unqualified condition
|
||||
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
|
||||
// "status = 'active'" returns "status"
|
||||
// nolint:unused
|
||||
func extractUnqualifiedColumnName(cond string) string {
|
||||
// Common SQL operators
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||
|
||||
// Find the column reference (left side of the operator)
|
||||
minIdx := -1
|
||||
for _, op := range operators {
|
||||
idx := strings.Index(cond, op)
|
||||
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||
minIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
var columnRef string
|
||||
if minIdx > 0 {
|
||||
columnRef = strings.TrimSpace(cond[:minIdx])
|
||||
} else {
|
||||
// No operator found, might be a single column reference
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef = parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
if columnRef == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// Return empty if it contains a dot (already qualified) or function call
|
||||
if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return columnRef
|
||||
}
|
||||
|
||||
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
|
||||
// Uses word boundaries to avoid partial matches
|
||||
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
|
||||
// returns "table.rid_item is null"
|
||||
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
|
||||
// Use word boundary matching with Go's supported regex syntax
|
||||
// \b matches word boundaries
|
||||
escapedOld := regexp.QuoteMeta(oldRef)
|
||||
pattern := `\b` + escapedOld + `\b`
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
// If regex fails, fall back to simple string replacement
|
||||
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
|
||||
return strings.Replace(cond, oldRef, newRef, 1)
|
||||
}
|
||||
|
||||
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
|
||||
result := cond
|
||||
matches := re.FindAllStringIndex(cond, -1)
|
||||
|
||||
// Process matches in reverse order to maintain correct indices
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
match := matches[i]
|
||||
start := match[0]
|
||||
|
||||
// Check if preceded by a dot (already qualified)
|
||||
if start > 0 && cond[start-1] == '.' {
|
||||
continue
|
||||
}
|
||||
|
||||
// Replace this occurrence
|
||||
result = result[:start] + newRef + result[match[1]:]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
||||
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
||||
func findOperatorOutsideParentheses(s string, operator string) int {
|
||||
depth := 0
|
||||
inSingleQuote := false
|
||||
inDoubleQuote := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
ch := s[i]
|
||||
|
||||
// Track quote state (operators inside quotes should be ignored)
|
||||
if ch == '\'' && !inDoubleQuote {
|
||||
inSingleQuote = !inSingleQuote
|
||||
continue
|
||||
}
|
||||
if ch == '"' && !inSingleQuote {
|
||||
inDoubleQuote = !inDoubleQuote
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if we're inside quotes
|
||||
if inSingleQuote || inDoubleQuote {
|
||||
continue
|
||||
}
|
||||
|
||||
// Track parenthesis depth
|
||||
switch ch {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
}
|
||||
|
||||
// Only look for the operator when we're outside parentheses (depth == 0)
|
||||
if depth == 0 {
|
||||
// Check if the operator starts at this position
|
||||
if i+len(operator) <= len(s) {
|
||||
if s[i:i+len(operator)] == operator {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// isValidColumn checks if a column name exists in the valid columns map
|
||||
// Handles case-insensitive comparison
|
||||
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
@@ -296,3 +639,173 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
}
|
||||
return validColumns[strings.ToLower(columnName)]
|
||||
}
|
||||
|
||||
// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause.
|
||||
// This function only prefixes simple column references and skips:
|
||||
// - Columns already having a table prefix (containing a dot)
|
||||
// - Columns inside function calls or expressions (inside parentheses)
|
||||
// - Columns inside subqueries
|
||||
// - Columns that don't exist in the table (validation via model registry)
|
||||
//
|
||||
// Examples:
|
||||
// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table)
|
||||
// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function)
|
||||
// - "users.status = 'active'" -> unchanged (already has prefix)
|
||||
// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK)
|
||||
// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table)
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause to process
|
||||
// - tableName: The table name to use as prefix
|
||||
//
|
||||
// Returns:
|
||||
// - The WHERE clause with table prefixes added to appropriate and valid columns
|
||||
func AddTablePrefixToColumns(where string, tableName string) string {
|
||||
if where == "" || tableName == "" {
|
||||
return where
|
||||
}
|
||||
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Get valid columns from the model registry for validation
|
||||
validColumns := getValidColumnsForTable(tableName)
|
||||
|
||||
// Split by AND to handle multiple conditions (parenthesis-aware)
|
||||
conditions := splitByAND(where)
|
||||
prefixedConditions := make([]string, 0, len(conditions))
|
||||
|
||||
for _, cond := range conditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process this condition to add table prefix if appropriate
|
||||
processedCond := addPrefixToSingleCondition(cond, tableName, validColumns)
|
||||
prefixedConditions = append(prefixedConditions, processedCond)
|
||||
}
|
||||
|
||||
if len(prefixedConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.Join(prefixedConditions, " AND ")
|
||||
}
|
||||
|
||||
// addPrefixToSingleCondition adds table prefix to a single condition if appropriate
|
||||
// Returns the condition unchanged if:
|
||||
// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.)
|
||||
// - The column reference is inside a function call
|
||||
// - The column already has a table prefix
|
||||
// - No valid column reference is found
|
||||
// - The column doesn't exist in the table (when validColumns is provided)
|
||||
func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string {
|
||||
// Strip one level of outer grouping parentheses to get to the actual condition
|
||||
strippedCond := stripOneOuterParentheses(cond)
|
||||
|
||||
// Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.)
|
||||
if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) {
|
||||
logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond)
|
||||
return cond
|
||||
}
|
||||
|
||||
// After stripping outer parentheses, check if there are multiple AND-separated conditions
|
||||
// at the top level. If so, split and process each separately to avoid incorrectly
|
||||
// treating "true AND status" as a single column name.
|
||||
subConditions := splitByAND(strippedCond)
|
||||
if len(subConditions) > 1 {
|
||||
// Multiple conditions found - process each separately
|
||||
logger.Debug("Found %d sub-conditions after stripping parentheses, processing separately", len(subConditions))
|
||||
processedConditions := make([]string, 0, len(subConditions))
|
||||
for _, subCond := range subConditions {
|
||||
// Recursively process each sub-condition
|
||||
processed := addPrefixToSingleCondition(subCond, tableName, validColumns)
|
||||
processedConditions = append(processedConditions, processed)
|
||||
}
|
||||
result := strings.Join(processedConditions, " AND ")
|
||||
// Preserve original outer parentheses if they existed
|
||||
if cond != strippedCond {
|
||||
result = "(" + result + ")"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// If we stripped parentheses and still have more parentheses, recursively process
|
||||
if cond != strippedCond && strings.HasPrefix(strippedCond, "(") && strings.HasSuffix(strippedCond, ")") {
|
||||
// Recursively handle nested parentheses
|
||||
processed := addPrefixToSingleCondition(strippedCond, tableName, validColumns)
|
||||
return "(" + processed + ")"
|
||||
}
|
||||
|
||||
// Extract the left side of the comparison (before the operator)
|
||||
columnRef := extractLeftSideOfComparison(strippedCond)
|
||||
if columnRef == "" {
|
||||
return cond
|
||||
}
|
||||
|
||||
// Skip if it already has a prefix (contains a dot)
|
||||
if strings.Contains(columnRef, ".") {
|
||||
logger.Debug("Skipping column '%s' - already has table prefix", columnRef)
|
||||
return cond
|
||||
}
|
||||
|
||||
// Skip if it's a function call or expression (contains parentheses)
|
||||
if strings.Contains(columnRef, "(") {
|
||||
logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef)
|
||||
return cond
|
||||
}
|
||||
|
||||
// Validate that the column exists in the table (if we have column info)
|
||||
if !isValidColumn(columnRef, validColumns) {
|
||||
logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName)
|
||||
return cond
|
||||
}
|
||||
|
||||
// It's a simple unqualified column reference that exists in the table - add the table prefix
|
||||
newRef := tableName + "." + columnRef
|
||||
result := qualifyColumnInCondition(cond, columnRef, newRef)
|
||||
logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition.
|
||||
// This is used to identify the column reference that may need a table prefix.
|
||||
//
|
||||
// Examples:
|
||||
// - "status = 'active'" returns "status"
|
||||
// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')"
|
||||
// - "priority > 5" returns "priority"
|
||||
//
|
||||
// Returns empty string if no operator is found.
|
||||
func extractLeftSideOfComparison(cond string) string {
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||
|
||||
// Find the first operator outside of parentheses and quotes
|
||||
minIdx := -1
|
||||
for _, op := range operators {
|
||||
idx := findOperatorOutsideParentheses(cond, op)
|
||||
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||
minIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if minIdx > 0 {
|
||||
leftSide := strings.TrimSpace(cond[:minIdx])
|
||||
// Remove any surrounding quotes
|
||||
leftSide = strings.Trim(leftSide, "`\"'")
|
||||
return leftSide
|
||||
}
|
||||
|
||||
// No operator found - might be a boolean column
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef := strings.Trim(parts[0], "`\"'")
|
||||
// Make sure it's not a SQL keyword
|
||||
if !IsSQLKeyword(strings.ToLower(columnRef)) {
|
||||
return columnRef
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
@@ -32,25 +33,37 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "valid condition with parentheses",
|
||||
name: "valid condition with parentheses - prefix added to prevent ambiguity",
|
||||
where: "(status = 'active')",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "mixed trivial and valid conditions",
|
||||
name: "mixed trivial and valid conditions - prefix added",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "condition already with table prefix",
|
||||
name: "condition with correct table prefix - unchanged",
|
||||
where: "users.status = 'active'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid conditions",
|
||||
name: "condition with incorrect table prefix - fixed",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple conditions with incorrect prefix - fixed",
|
||||
where: "wrong_table.status = 'active' AND wrong_table.age > 18",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "multiple valid conditions without prefix - prefixes added",
|
||||
where: "status = 'active' AND age > 18",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
@@ -67,11 +80,68 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mixed correct and incorrect prefixes",
|
||||
where: "users.status = 'active' AND wrong_table.age > 18",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "mixed case AND operators",
|
||||
where: "status = 'active' AND age > 18 and name = 'John'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
|
||||
},
|
||||
{
|
||||
name: "subquery with ORDER BY and LIMIT - allowed",
|
||||
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
tableName: "users",
|
||||
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
},
|
||||
{
|
||||
name: "dangerous DELETE keyword - blocked",
|
||||
where: "status = 'active'; DELETE FROM users",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "dangerous UPDATE keyword - blocked",
|
||||
where: "1=1; UPDATE users SET admin = true",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "dangerous TRUNCATE keyword - blocked",
|
||||
where: "status = 'active' OR TRUNCATE TABLE users",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "dangerous DROP keyword - blocked",
|
||||
where: "status = 'active'; DROP TABLE users",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "subquery with table alias should not be modified",
|
||||
where: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
|
||||
tableName: "apiprovider",
|
||||
expected: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
|
||||
},
|
||||
{
|
||||
name: "complex subquery with AND and multiple operators",
|
||||
where: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
|
||||
tableName: "apiprovider",
|
||||
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
// Then sanitize the where clause
|
||||
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
@@ -120,6 +190,11 @@ func TestStripOuterParentheses(t *testing.T) {
|
||||
input: " ( true ) ",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "complex sub query",
|
||||
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
|
||||
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -159,6 +234,224 @@ func TestIsTrivialCondition(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTableAndColumn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedTable string
|
||||
expectedCol string
|
||||
}{
|
||||
{
|
||||
name: "qualified column with equals",
|
||||
input: "users.status = 'active'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "qualified column with greater than",
|
||||
input: "users.age > 18",
|
||||
expectedTable: "users",
|
||||
expectedCol: "age",
|
||||
},
|
||||
{
|
||||
name: "qualified column with LIKE",
|
||||
input: "users.name LIKE '%john%'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "name",
|
||||
},
|
||||
{
|
||||
name: "qualified column with IN",
|
||||
input: "users.status IN ('active', 'pending')",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "unqualified column",
|
||||
input: "status = 'active'",
|
||||
expectedTable: "",
|
||||
expectedCol: "",
|
||||
},
|
||||
{
|
||||
name: "qualified with backticks",
|
||||
input: "`users`.`status` = 'active'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "schema.table.column reference",
|
||||
input: "public.users.status = 'active'",
|
||||
expectedTable: "public.users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectedTable: "",
|
||||
expectedCol: "",
|
||||
},
|
||||
{
|
||||
name: "function call with table.column - ifblnk",
|
||||
input: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "function call with table.column - coalesce",
|
||||
input: "coalesce(users.age, 0) = 25",
|
||||
expectedTable: "users",
|
||||
expectedCol: "age",
|
||||
},
|
||||
{
|
||||
name: "nested function calls",
|
||||
input: "upper(trim(users.name)) = 'JOHN'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "name",
|
||||
},
|
||||
{
|
||||
name: "function with multiple args and table.column",
|
||||
input: "substring(users.email, 1, 5) = 'admin'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "email",
|
||||
},
|
||||
{
|
||||
name: "cast function with table.column",
|
||||
input: "cast(orders.total as decimal) > 100",
|
||||
expectedTable: "orders",
|
||||
expectedCol: "total",
|
||||
},
|
||||
{
|
||||
name: "complex nested functions",
|
||||
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "function with multiple table.column refs (extracts first)",
|
||||
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "created_at",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
table, col := extractTableAndColumn(tt.input)
|
||||
if table != tt.expectedTable || col != tt.expectedCol {
|
||||
t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)",
|
||||
tt.input, table, col, tt.expectedTable, tt.expectedCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
options *RequestOptions
|
||||
expected string
|
||||
addPrefix bool
|
||||
}{
|
||||
{
|
||||
name: "preload relation prefix is preserved",
|
||||
where: "Department.name = 'Engineering'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
},
|
||||
},
|
||||
expected: "Department.name = 'Engineering'",
|
||||
},
|
||||
{
|
||||
name: "multiple preload relations - all preserved",
|
||||
where: "Department.name = 'Engineering' AND Manager.status = 'active'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
{Relation: "Manager"},
|
||||
},
|
||||
},
|
||||
expected: "Department.name = 'Engineering' AND Manager.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "mix of main table and preload relation",
|
||||
where: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
},
|
||||
},
|
||||
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||
},
|
||||
{
|
||||
name: "incorrect prefix fixed when not a preload relation",
|
||||
where: "wrong_table.status = 'active' AND Department.name = 'Engineering'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
},
|
||||
},
|
||||
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||
},
|
||||
|
||||
{
|
||||
name: "Function Call with correct table prefix - unchanged",
|
||||
where: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||
tableName: "users",
|
||||
options: nil,
|
||||
expected: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||
},
|
||||
{
|
||||
name: "no options provided - works as before",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "users",
|
||||
options: nil,
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "empty preload list - works as before",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{Preload: []PreloadOption{}},
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
|
||||
{
|
||||
name: "complex where clause with subquery and preload",
|
||||
where: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (rid_parentmastertaskitem is null)`,
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (mastertaskitem.rid_parentmastertaskitem is null)`,
|
||||
addPrefix: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result string
|
||||
prefixedWhere := tt.where
|
||||
if tt.addPrefix {
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedWhere = AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
}
|
||||
// Then sanitize the where clause
|
||||
if tt.options != nil {
|
||||
result = SanitizeWhereClause(prefixedWhere, tt.tableName, tt.options)
|
||||
} else {
|
||||
result = SanitizeWhereClause(prefixedWhere, tt.tableName)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test model for model-aware sanitization tests
|
||||
type MasterTask struct {
|
||||
ID int `bun:"id,pk"`
|
||||
@@ -167,6 +460,131 @@ type MasterTask struct {
|
||||
UserID int `bun:"user_id"`
|
||||
}
|
||||
|
||||
func TestSplitByAND(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "uppercase AND",
|
||||
input: "status = 'active' AND age > 18",
|
||||
expected: []string{"status = 'active'", "age > 18"},
|
||||
},
|
||||
{
|
||||
name: "lowercase and",
|
||||
input: "status = 'active' and age > 18",
|
||||
expected: []string{"status = 'active'", "age > 18"},
|
||||
},
|
||||
{
|
||||
name: "mixed case AND",
|
||||
input: "status = 'active' AND age > 18 and name = 'John'",
|
||||
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
|
||||
},
|
||||
{
|
||||
name: "single condition",
|
||||
input: "status = 'active'",
|
||||
expected: []string{"status = 'active'"},
|
||||
},
|
||||
{
|
||||
name: "multiple uppercase AND",
|
||||
input: "a = 1 AND b = 2 AND c = 3",
|
||||
expected: []string{"a = 1", "b = 2", "c = 3"},
|
||||
},
|
||||
{
|
||||
name: "multiple case subquery",
|
||||
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
|
||||
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := splitByAND(tt.input)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
|
||||
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWhereClauseSecurity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "safe WHERE clause",
|
||||
input: "status = 'active' AND age > 18",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "safe subquery",
|
||||
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "DELETE keyword",
|
||||
input: "status = 'active'; DELETE FROM users",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "UPDATE keyword",
|
||||
input: "1=1; UPDATE users SET admin = true",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "TRUNCATE keyword",
|
||||
input: "status = 'active' OR TRUNCATE TABLE users",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "DROP keyword",
|
||||
input: "status = 'active'; DROP TABLE users",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "INSERT keyword",
|
||||
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "ALTER keyword",
|
||||
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "CREATE keyword",
|
||||
input: "1=1; CREATE TABLE malicious (id INT)",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty clause",
|
||||
input: "",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateWhereClauseSecurity(tt.input)
|
||||
if tt.expectError && err == nil {
|
||||
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
// Register the test model
|
||||
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||
@@ -182,34 +600,52 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid column gets prefixed",
|
||||
name: "valid column without prefix - no prefix added",
|
||||
where: "status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid columns without prefix - no prefix added",
|
||||
where: "status = 'active' AND user_id = 123",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active' AND user_id = 123",
|
||||
},
|
||||
{
|
||||
name: "incorrect table prefix on valid column - fixed",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid columns get prefixed",
|
||||
where: "status = 'active' AND user_id = 123",
|
||||
name: "incorrect prefix on invalid column - not fixed",
|
||||
where: "wrong_table.invalid_column = 'value'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||
},
|
||||
{
|
||||
name: "invalid column does not get prefixed",
|
||||
where: "invalid_column = 'value'",
|
||||
tableName: "mastertask",
|
||||
expected: "invalid_column = 'value'",
|
||||
expected: "wrong_table.invalid_column = 'value'",
|
||||
},
|
||||
{
|
||||
name: "mix of valid and trivial conditions",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "parentheses with valid column - no prefix added",
|
||||
where: "(status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "correct prefix - unchanged",
|
||||
where: "mastertask.status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "parentheses with valid column",
|
||||
where: "(status = 'active')",
|
||||
name: "multiple conditions with mixed prefixes",
|
||||
where: "mastertask.status = 'active' AND wrong_table.user_id = 123",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -222,3 +658,76 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Parentheses with true AND condition - should not prefix true",
|
||||
where: "(true AND status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "(true AND mastertask.status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "Parentheses with multiple conditions including true",
|
||||
where: "(true AND status = 'active' AND id > 5)",
|
||||
tableName: "mastertask",
|
||||
expected: "(true AND mastertask.status = 'active' AND mastertask.id > 5)",
|
||||
},
|
||||
{
|
||||
name: "Nested parentheses with true",
|
||||
where: "((true AND status = 'active'))",
|
||||
tableName: "mastertask",
|
||||
expected: "((true AND mastertask.status = 'active'))",
|
||||
},
|
||||
{
|
||||
name: "Mixed: false AND valid conditions",
|
||||
where: "(false AND name = 'test')",
|
||||
tableName: "mastertask",
|
||||
expected: "(false AND mastertask.name = 'test')",
|
||||
},
|
||||
{
|
||||
name: "Mixed: null AND valid conditions",
|
||||
where: "(null AND status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "(null AND mastertask.status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "Multiple true conditions in parentheses",
|
||||
where: "(true AND true AND status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "(true AND true AND mastertask.status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "Simple true without parens - should not prefix",
|
||||
where: "true",
|
||||
tableName: "mastertask",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "Simple condition without parens - should prefix",
|
||||
where: "status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "Unregistered table with true - should not prefix true",
|
||||
where: "(true AND status = 'active')",
|
||||
tableName: "unregistered_table",
|
||||
expected: "(true AND unregistered_table.status = 'active')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,6 +237,13 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
for _, sort := range options.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
}
|
||||
@@ -262,6 +269,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
}
|
||||
filteredPreload.Filters = validPreloadFilters
|
||||
|
||||
// Filter preload sort columns
|
||||
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
||||
for _, sort := range preload.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validPreloadSorts = append(validPreloadSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validPreloadSorts = append(validPreloadSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression in preload '%s' removed: '%s'", preload.Relation, sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' sort '%s' removed", preload.Relation, sort.Column)
|
||||
}
|
||||
}
|
||||
filteredPreload.Sort = validPreloadSorts
|
||||
|
||||
validPreloads = append(validPreloads, filteredPreload)
|
||||
}
|
||||
filtered.Preload = validPreloads
|
||||
@@ -269,6 +294,56 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
return filtered
|
||||
}
|
||||
|
||||
// IsSafeSortExpression validates that a sort expression (enclosed in brackets) is safe
|
||||
// and doesn't contain SQL injection attempts or dangerous commands
|
||||
func IsSafeSortExpression(expr string) bool {
|
||||
if expr == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Expression must be enclosed in brackets
|
||||
expr = strings.TrimSpace(expr)
|
||||
if !strings.HasPrefix(expr, "(") || !strings.HasSuffix(expr, ")") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Remove outer brackets for content validation
|
||||
expr = expr[1 : len(expr)-1]
|
||||
expr = strings.TrimSpace(expr)
|
||||
|
||||
// Convert to lowercase for checking dangerous keywords
|
||||
exprLower := strings.ToLower(expr)
|
||||
|
||||
// Check for dangerous SQL commands that should never be in a sort expression
|
||||
dangerousKeywords := []string{
|
||||
"drop ", "delete ", "insert ", "update ", "alter ", "create ",
|
||||
"truncate ", "exec ", "execute ", "grant ", "revoke ",
|
||||
"into ", "values ", "set ", "shutdown", "xp_",
|
||||
}
|
||||
|
||||
for _, keyword := range dangerousKeywords {
|
||||
if strings.Contains(exprLower, keyword) {
|
||||
logger.Warn("Dangerous SQL keyword '%s' detected in sort expression: %s", keyword, expr)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check for SQL comment attempts
|
||||
if strings.Contains(expr, "--") || strings.Contains(expr, "/*") || strings.Contains(expr, "*/") {
|
||||
logger.Warn("SQL comment detected in sort expression: %s", expr)
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for semicolon (command separator)
|
||||
if strings.Contains(expr, ";") {
|
||||
logger.Warn("Command separator (;) detected in sort expression: %s", expr)
|
||||
return false
|
||||
}
|
||||
|
||||
// Expression appears safe
|
||||
return true
|
||||
}
|
||||
|
||||
// GetValidColumns returns a list of all valid column names for debugging purposes
|
||||
func (v *ColumnValidator) GetValidColumns() []string {
|
||||
columns := make([]string, 0, len(v.validColumns))
|
||||
|
||||
@@ -361,3 +361,83 @@ func TestFilterRequestOptions(t *testing.T) {
|
||||
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeSortExpression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
shouldPass bool
|
||||
}{
|
||||
// Safe expressions
|
||||
{"Valid subquery", "(SELECT MAX(price) FROM products)", true},
|
||||
{"Valid CASE expression", "(CASE WHEN status = 'active' THEN 1 ELSE 0 END)", true},
|
||||
{"Valid aggregate", "(COUNT(*) OVER (PARTITION BY category))", true},
|
||||
{"Valid function", "(COALESCE(discount, 0))", true},
|
||||
|
||||
// Dangerous expressions - SQL injection attempts
|
||||
{"DROP TABLE attempt", "(id); DROP TABLE users; --", false},
|
||||
{"DELETE attempt", "(id WHERE 1=1); DELETE FROM users; --", false},
|
||||
{"INSERT attempt", "(id); INSERT INTO admin VALUES ('hacker'); --", false},
|
||||
{"UPDATE attempt", "(id); UPDATE users SET role='admin'; --", false},
|
||||
{"EXEC attempt", "(id); EXEC sp_executesql 'DROP TABLE users'; --", false},
|
||||
{"XP_ stored proc", "(id); xp_cmdshell 'dir'; --", false},
|
||||
|
||||
// Comment injection
|
||||
{"SQL comment dash", "(id) -- malicious comment", false},
|
||||
{"SQL comment block start", "(id) /* comment", false},
|
||||
{"SQL comment block end", "(id) comment */", false},
|
||||
|
||||
// Semicolon attempts
|
||||
{"Semicolon separator", "(id); SELECT * FROM passwords", false},
|
||||
|
||||
// Empty/invalid
|
||||
{"Empty string", "", false},
|
||||
{"Just brackets", "()", true}, // Empty but technically valid structure
|
||||
{"No brackets", "id", false}, // Must have brackets for expressions
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsSafeSortExpression(tt.expression)
|
||||
if result != tt.shouldPass {
|
||||
t.Errorf("IsSafeSortExpression(%q) = %v, want %v", tt.expression, result, tt.shouldPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
options := RequestOptions{
|
||||
Sort: []SortOption{
|
||||
{Column: "id", Direction: "ASC"}, // Valid column
|
||||
{Column: "(SELECT MAX(age) FROM users)", Direction: "DESC"}, // Safe expression
|
||||
{Column: "name", Direction: "ASC"}, // Valid column
|
||||
{Column: "(id); DROP TABLE users; --", Direction: "DESC"}, // Dangerous expression
|
||||
{Column: "invalid_col", Direction: "ASC"}, // Invalid column
|
||||
{Column: "(CASE WHEN age > 18 THEN 1 ELSE 0 END)", Direction: "ASC"}, // Safe expression
|
||||
},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
// Should keep: id, safe expression, name, another safe expression
|
||||
// Should remove: dangerous expression, invalid column
|
||||
expectedCount := 4
|
||||
if len(filtered.Sort) != expectedCount {
|
||||
t.Errorf("Expected %d sort options, got %d", expectedCount, len(filtered.Sort))
|
||||
}
|
||||
|
||||
// Verify the kept options
|
||||
if filtered.Sort[0].Column != "id" {
|
||||
t.Errorf("Expected first sort to be 'id', got '%s'", filtered.Sort[0].Column)
|
||||
}
|
||||
if filtered.Sort[1].Column != "(SELECT MAX(age) FROM users)" {
|
||||
t.Errorf("Expected second sort to be safe expression, got '%s'", filtered.Sort[1].Column)
|
||||
}
|
||||
if filtered.Sort[2].Column != "name" {
|
||||
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,13 +4,15 @@ import "time"
|
||||
|
||||
// Config represents the complete application configuration
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Tracing TracingConfig `mapstructure:"tracing"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
Logger LoggerConfig `mapstructure:"logger"`
|
||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Tracing TracingConfig `mapstructure:"tracing"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
Logger LoggerConfig `mapstructure:"logger"`
|
||||
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
|
||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
||||
DBManager DBManagerConfig `mapstructure:"dbmanager"`
|
||||
}
|
||||
|
||||
// ServerConfig holds server-related configuration
|
||||
@@ -74,7 +76,63 @@ type CORSConfig struct {
|
||||
MaxAge int `mapstructure:"max_age"`
|
||||
}
|
||||
|
||||
// DatabaseConfig holds database configuration (primarily for testing)
|
||||
type DatabaseConfig struct {
|
||||
URL string `mapstructure:"url"`
|
||||
// ErrorTrackingConfig holds error tracking configuration
|
||||
type ErrorTrackingConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Provider string `mapstructure:"provider"` // sentry, noop
|
||||
DSN string `mapstructure:"dsn"` // Sentry DSN
|
||||
Environment string `mapstructure:"environment"` // e.g., production, staging, development
|
||||
Release string `mapstructure:"release"` // Application version/release
|
||||
Debug bool `mapstructure:"debug"` // Enable debug mode
|
||||
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
|
||||
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
|
||||
}
|
||||
|
||||
// EventBrokerConfig contains configuration for the event broker
|
||||
type EventBrokerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Provider string `mapstructure:"provider"` // memory, redis, nats, database
|
||||
Mode string `mapstructure:"mode"` // sync, async
|
||||
WorkerCount int `mapstructure:"worker_count"`
|
||||
BufferSize int `mapstructure:"buffer_size"`
|
||||
InstanceID string `mapstructure:"instance_id"`
|
||||
Redis EventBrokerRedisConfig `mapstructure:"redis"`
|
||||
NATS EventBrokerNATSConfig `mapstructure:"nats"`
|
||||
Database EventBrokerDatabaseConfig `mapstructure:"database"`
|
||||
RetryPolicy EventBrokerRetryPolicyConfig `mapstructure:"retry_policy"`
|
||||
}
|
||||
|
||||
// EventBrokerRedisConfig contains Redis-specific configuration
|
||||
type EventBrokerRedisConfig struct {
|
||||
StreamName string `mapstructure:"stream_name"`
|
||||
ConsumerGroup string `mapstructure:"consumer_group"`
|
||||
MaxLen int64 `mapstructure:"max_len"`
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
}
|
||||
|
||||
// EventBrokerNATSConfig contains NATS-specific configuration
|
||||
type EventBrokerNATSConfig struct {
|
||||
URL string `mapstructure:"url"`
|
||||
StreamName string `mapstructure:"stream_name"`
|
||||
Subjects []string `mapstructure:"subjects"`
|
||||
Storage string `mapstructure:"storage"` // file, memory
|
||||
MaxAge time.Duration `mapstructure:"max_age"`
|
||||
}
|
||||
|
||||
// EventBrokerDatabaseConfig contains database provider configuration
|
||||
type EventBrokerDatabaseConfig struct {
|
||||
TableName string `mapstructure:"table_name"`
|
||||
Channel string `mapstructure:"channel"` // PostgreSQL NOTIFY channel name
|
||||
PollInterval time.Duration `mapstructure:"poll_interval"`
|
||||
}
|
||||
|
||||
// EventBrokerRetryPolicyConfig contains retry policy configuration
|
||||
type EventBrokerRetryPolicyConfig struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
InitialDelay time.Duration `mapstructure:"initial_delay"`
|
||||
MaxDelay time.Duration `mapstructure:"max_delay"`
|
||||
BackoffFactor float64 `mapstructure:"backoff_factor"`
|
||||
}
|
||||
|
||||
107
pkg/config/dbmanager.go
Normal file
107
pkg/config/dbmanager.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DBManagerConfig contains configuration for the database connection manager
|
||||
type DBManagerConfig struct {
|
||||
// DefaultConnection is the name of the default connection to use
|
||||
DefaultConnection string `mapstructure:"default_connection"`
|
||||
|
||||
// Connections is a map of connection name to connection configuration
|
||||
Connections map[string]DBConnectionConfig `mapstructure:"connections"`
|
||||
|
||||
// Global connection pool defaults
|
||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
|
||||
ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"`
|
||||
|
||||
// Retry policy
|
||||
RetryAttempts int `mapstructure:"retry_attempts"`
|
||||
RetryDelay time.Duration `mapstructure:"retry_delay"`
|
||||
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
|
||||
|
||||
// Health checks
|
||||
HealthCheckInterval time.Duration `mapstructure:"health_check_interval"`
|
||||
EnableAutoReconnect bool `mapstructure:"enable_auto_reconnect"`
|
||||
}
|
||||
|
||||
// DBConnectionConfig defines configuration for a single database connection
|
||||
type DBConnectionConfig struct {
|
||||
// Name is the unique name of this connection
|
||||
Name string `mapstructure:"name"`
|
||||
|
||||
// Type is the database type (postgres, sqlite, mssql, mongodb)
|
||||
Type string `mapstructure:"type"`
|
||||
|
||||
// DSN is the complete Data Source Name / connection string
|
||||
// If provided, this takes precedence over individual connection parameters
|
||||
DSN string `mapstructure:"dsn"`
|
||||
|
||||
// Connection parameters (used if DSN is not provided)
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
Database string `mapstructure:"database"`
|
||||
|
||||
// PostgreSQL/MSSQL specific
|
||||
SSLMode string `mapstructure:"sslmode"` // disable, require, verify-ca, verify-full
|
||||
Schema string `mapstructure:"schema"` // Default schema
|
||||
|
||||
// SQLite specific
|
||||
FilePath string `mapstructure:"filepath"`
|
||||
|
||||
// MongoDB specific
|
||||
AuthSource string `mapstructure:"auth_source"`
|
||||
ReplicaSet string `mapstructure:"replica_set"`
|
||||
ReadPreference string `mapstructure:"read_preference"` // primary, secondary, etc.
|
||||
|
||||
// Connection pool settings (overrides global defaults)
|
||||
MaxOpenConns *int `mapstructure:"max_open_conns"`
|
||||
MaxIdleConns *int `mapstructure:"max_idle_conns"`
|
||||
ConnMaxLifetime *time.Duration `mapstructure:"conn_max_lifetime"`
|
||||
ConnMaxIdleTime *time.Duration `mapstructure:"conn_max_idle_time"`
|
||||
|
||||
// Timeouts
|
||||
ConnectTimeout time.Duration `mapstructure:"connect_timeout"`
|
||||
QueryTimeout time.Duration `mapstructure:"query_timeout"`
|
||||
|
||||
// Features
|
||||
EnableTracing bool `mapstructure:"enable_tracing"`
|
||||
EnableMetrics bool `mapstructure:"enable_metrics"`
|
||||
EnableLogging bool `mapstructure:"enable_logging"`
|
||||
|
||||
// DefaultORM specifies which ORM to use for the Database() method
|
||||
// Options: "bun", "gorm", "native"
|
||||
DefaultORM string `mapstructure:"default_orm"`
|
||||
|
||||
// Tags for organization and filtering
|
||||
Tags map[string]string `mapstructure:"tags"`
|
||||
}
|
||||
|
||||
// ToManagerConfig converts config.DBManagerConfig to dbmanager.ManagerConfig
|
||||
// This is used to avoid circular dependencies
|
||||
func (c *DBManagerConfig) ToManagerConfig() interface{} {
|
||||
// This will be implemented in the dbmanager package
|
||||
// to convert from config types to dbmanager types
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate validates the DBManager configuration
|
||||
func (c *DBManagerConfig) Validate() error {
|
||||
if len(c.Connections) == 0 {
|
||||
return fmt.Errorf("at least one connection must be configured")
|
||||
}
|
||||
|
||||
if c.DefaultConnection != "" {
|
||||
if _, ok := c.Connections[c.DefaultConnection]; !ok {
|
||||
return fmt.Errorf("default connection '%s' not found in connections", c.DefaultConnection)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -165,4 +165,39 @@ func setDefaults(v *viper.Viper) {
|
||||
|
||||
// Database defaults
|
||||
v.SetDefault("database.url", "")
|
||||
|
||||
// Event Broker defaults
|
||||
v.SetDefault("event_broker.enabled", false)
|
||||
v.SetDefault("event_broker.provider", "memory")
|
||||
v.SetDefault("event_broker.mode", "async")
|
||||
v.SetDefault("event_broker.worker_count", 10)
|
||||
v.SetDefault("event_broker.buffer_size", 1000)
|
||||
v.SetDefault("event_broker.instance_id", "")
|
||||
|
||||
// Event Broker - Redis defaults
|
||||
v.SetDefault("event_broker.redis.stream_name", "resolvespec:events")
|
||||
v.SetDefault("event_broker.redis.consumer_group", "resolvespec-workers")
|
||||
v.SetDefault("event_broker.redis.max_len", 10000)
|
||||
v.SetDefault("event_broker.redis.host", "localhost")
|
||||
v.SetDefault("event_broker.redis.port", 6379)
|
||||
v.SetDefault("event_broker.redis.password", "")
|
||||
v.SetDefault("event_broker.redis.db", 0)
|
||||
|
||||
// Event Broker - NATS defaults
|
||||
v.SetDefault("event_broker.nats.url", "nats://localhost:4222")
|
||||
v.SetDefault("event_broker.nats.stream_name", "RESOLVESPEC_EVENTS")
|
||||
v.SetDefault("event_broker.nats.subjects", []string{"events.>"})
|
||||
v.SetDefault("event_broker.nats.storage", "file")
|
||||
v.SetDefault("event_broker.nats.max_age", "24h")
|
||||
|
||||
// Event Broker - Database defaults
|
||||
v.SetDefault("event_broker.database.table_name", "events")
|
||||
v.SetDefault("event_broker.database.channel", "resolvespec_events")
|
||||
v.SetDefault("event_broker.database.poll_interval", "1s")
|
||||
|
||||
// Event Broker - Retry Policy defaults
|
||||
v.SetDefault("event_broker.retry_policy.max_retries", 3)
|
||||
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
|
||||
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
|
||||
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
|
||||
}
|
||||
|
||||
531
pkg/dbmanager/README.md
Normal file
531
pkg/dbmanager/README.md
Normal file
@@ -0,0 +1,531 @@
|
||||
# Database Connection Manager (dbmanager)
|
||||
|
||||
A comprehensive database connection manager for Go that provides centralized management of multiple named database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Named Connections**: Manage multiple database connections with names like `primary`, `analytics`, `cache-db`
|
||||
- **Multi-Database Support**: PostgreSQL, SQLite, Microsoft SQL Server, and MongoDB
|
||||
- **Multi-ORM Access**: Each SQL connection provides access through:
|
||||
- **Bun ORM** - Modern, lightweight ORM
|
||||
- **GORM** - Popular Go ORM
|
||||
- **Native** - Standard library `*sql.DB`
|
||||
- All three share the same underlying connection pool
|
||||
- **Configuration-Driven**: YAML configuration with Viper integration
|
||||
- **Production-Ready Features**:
|
||||
- Automatic health checks and reconnection
|
||||
- Prometheus metrics
|
||||
- Connection pooling with configurable limits
|
||||
- Retry logic with exponential backoff
|
||||
- Graceful shutdown
|
||||
- OpenTelemetry tracing support
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/bitechdev/ResolveSpec/pkg/dbmanager
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Configuration
|
||||
|
||||
Create a configuration file (e.g., `config.yaml`):
|
||||
|
||||
```yaml
|
||||
dbmanager:
|
||||
default_connection: "primary"
|
||||
|
||||
# Global connection pool defaults
|
||||
max_open_conns: 25
|
||||
max_idle_conns: 5
|
||||
conn_max_lifetime: 30m
|
||||
conn_max_idle_time: 5m
|
||||
|
||||
# Retry configuration
|
||||
retry_attempts: 3
|
||||
retry_delay: 1s
|
||||
retry_max_delay: 10s
|
||||
|
||||
# Health checks
|
||||
health_check_interval: 30s
|
||||
enable_auto_reconnect: true
|
||||
|
||||
connections:
|
||||
# Primary PostgreSQL connection
|
||||
primary:
|
||||
type: postgres
|
||||
host: localhost
|
||||
port: 5432
|
||||
user: myuser
|
||||
password: mypassword
|
||||
database: myapp
|
||||
sslmode: disable
|
||||
default_orm: bun
|
||||
enable_metrics: true
|
||||
enable_tracing: true
|
||||
enable_logging: true
|
||||
|
||||
# Read replica for analytics
|
||||
analytics:
|
||||
type: postgres
|
||||
dsn: "postgres://readonly:pass@analytics:5432/analytics"
|
||||
default_orm: bun
|
||||
enable_metrics: true
|
||||
|
||||
# SQLite cache
|
||||
cache-db:
|
||||
type: sqlite
|
||||
filepath: /var/lib/app/cache.db
|
||||
max_open_conns: 1
|
||||
|
||||
# MongoDB for documents
|
||||
documents:
|
||||
type: mongodb
|
||||
host: localhost
|
||||
port: 27017
|
||||
database: documents
|
||||
user: mongouser
|
||||
password: mongopass
|
||||
auth_source: admin
|
||||
enable_metrics: true
|
||||
```
|
||||
|
||||
### 2. Initialize Manager
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
cfgMgr := config.NewManager()
|
||||
if err := cfgMgr.Load(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
cfg, _ := cfgMgr.GetConfig()
|
||||
|
||||
// Create database manager
|
||||
mgr, err := dbmanager.NewManager(cfg.DBManager)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer mgr.Close()
|
||||
|
||||
// Connect all databases
|
||||
ctx := context.Background()
|
||||
if err := mgr.Connect(ctx); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Your application code here...
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Use Database Connections
|
||||
|
||||
#### Get Default Database
|
||||
|
||||
```go
|
||||
// Get the default database (as configured common.Database interface)
|
||||
db, err := mgr.GetDefaultDatabase()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Use it with any query
|
||||
var users []User
|
||||
err = db.NewSelect().
|
||||
Model(&users).
|
||||
Where("active = ?", true).
|
||||
Scan(ctx, &users)
|
||||
```
|
||||
|
||||
#### Get Named Connection with Specific ORM
|
||||
|
||||
```go
|
||||
// Get primary connection
|
||||
primary, err := mgr.Get("primary")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Use with Bun
|
||||
bunDB, err := primary.Bun()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
err = bunDB.NewSelect().Model(&users).Scan(ctx)
|
||||
|
||||
// Use with GORM (same underlying connection!)
|
||||
gormDB, err := primary.GORM()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
gormDB.Where("active = ?", true).Find(&users)
|
||||
|
||||
// Use native *sql.DB
|
||||
nativeDB, err := primary.Native()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
|
||||
```
|
||||
|
||||
#### Use MongoDB
|
||||
|
||||
```go
|
||||
// Get MongoDB connection
|
||||
docs, err := mgr.Get("documents")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
mongoClient, err := docs.MongoDB()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
collection := mongoClient.Database("documents").Collection("articles")
|
||||
// Use MongoDB driver...
|
||||
```
|
||||
|
||||
#### Change Default Database
|
||||
|
||||
```go
|
||||
// Switch to analytics database as default
|
||||
err := mgr.SetDefaultDatabase("analytics")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Now GetDefaultDatabase() returns the analytics connection
|
||||
db, _ := mgr.GetDefaultDatabase()
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Manager Configuration
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `default_connection` | string | "" | Name of the default connection |
|
||||
| `connections` | map | {} | Map of connection name to ConnectionConfig |
|
||||
| `max_open_conns` | int | 25 | Global default for max open connections |
|
||||
| `max_idle_conns` | int | 5 | Global default for max idle connections |
|
||||
| `conn_max_lifetime` | duration | 30m | Global default for connection max lifetime |
|
||||
| `conn_max_idle_time` | duration | 5m | Global default for connection max idle time |
|
||||
| `retry_attempts` | int | 3 | Number of connection retry attempts |
|
||||
| `retry_delay` | duration | 1s | Initial retry delay |
|
||||
| `retry_max_delay` | duration | 10s | Maximum retry delay |
|
||||
| `health_check_interval` | duration | 30s | Interval between health checks |
|
||||
| `enable_auto_reconnect` | bool | true | Auto-reconnect on health check failure |
|
||||
|
||||
### Connection Configuration
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `name` | string | Unique connection name |
|
||||
| `type` | string | Database type: `postgres`, `sqlite`, `mssql`, `mongodb` |
|
||||
| `dsn` | string | Complete connection string (overrides individual params) |
|
||||
| `host` | string | Database host |
|
||||
| `port` | int | Database port |
|
||||
| `user` | string | Username |
|
||||
| `password` | string | Password |
|
||||
| `database` | string | Database name |
|
||||
| `sslmode` | string | SSL mode (postgres/mssql): `disable`, `require`, etc. |
|
||||
| `schema` | string | Default schema (postgres/mssql) |
|
||||
| `filepath` | string | File path (sqlite only) |
|
||||
| `auth_source` | string | Auth source (mongodb) |
|
||||
| `replica_set` | string | Replica set name (mongodb) |
|
||||
| `read_preference` | string | Read preference (mongodb): `primary`, `secondary`, etc. |
|
||||
| `max_open_conns` | int | Override global max open connections |
|
||||
| `max_idle_conns` | int | Override global max idle connections |
|
||||
| `conn_max_lifetime` | duration | Override global connection max lifetime |
|
||||
| `conn_max_idle_time` | duration | Override global connection max idle time |
|
||||
| `connect_timeout` | duration | Connection timeout (default: 10s) |
|
||||
| `query_timeout` | duration | Query timeout (default: 30s) |
|
||||
| `enable_tracing` | bool | Enable OpenTelemetry tracing |
|
||||
| `enable_metrics` | bool | Enable Prometheus metrics |
|
||||
| `enable_logging` | bool | Enable connection logging |
|
||||
| `default_orm` | string | Default ORM for Database(): `bun`, `gorm`, `native` |
|
||||
| `tags` | map[string]string | Custom tags for filtering/organization |
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Health Checks
|
||||
|
||||
```go
|
||||
// Manual health check
|
||||
if err := mgr.HealthCheck(ctx); err != nil {
|
||||
log.Printf("Health check failed: %v", err)
|
||||
}
|
||||
|
||||
// Per-connection health check
|
||||
primary, _ := mgr.Get("primary")
|
||||
if err := primary.HealthCheck(ctx); err != nil {
|
||||
log.Printf("Primary connection unhealthy: %v", err)
|
||||
|
||||
// Manual reconnect
|
||||
if err := primary.Reconnect(ctx); err != nil {
|
||||
log.Printf("Reconnection failed: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Connection Statistics
|
||||
|
||||
```go
|
||||
// Get overall statistics
|
||||
stats := mgr.Stats()
|
||||
fmt.Printf("Total connections: %d\n", stats.TotalConnections)
|
||||
fmt.Printf("Healthy: %d, Unhealthy: %d\n", stats.HealthyCount, stats.UnhealthyCount)
|
||||
|
||||
// Per-connection stats
|
||||
for name, connStats := range stats.ConnectionStats {
|
||||
fmt.Printf("%s: %d open, %d in use, %d idle\n",
|
||||
name,
|
||||
connStats.OpenConnections,
|
||||
connStats.InUse,
|
||||
connStats.Idle)
|
||||
}
|
||||
|
||||
// Individual connection stats
|
||||
primary, _ := mgr.Get("primary")
|
||||
stats := primary.Stats()
|
||||
fmt.Printf("Wait count: %d, Wait duration: %v\n",
|
||||
stats.WaitCount,
|
||||
stats.WaitDuration)
|
||||
```
|
||||
|
||||
### Prometheus Metrics
|
||||
|
||||
The package automatically exports Prometheus metrics:
|
||||
|
||||
- `dbmanager_connections_total` - Total configured connections by type
|
||||
- `dbmanager_connection_status` - Connection health status (1=healthy, 0=unhealthy)
|
||||
- `dbmanager_connection_pool_size` - Connection pool statistics by state
|
||||
- `dbmanager_connection_wait_count` - Times connections waited for availability
|
||||
- `dbmanager_connection_wait_duration_seconds` - Total wait duration
|
||||
- `dbmanager_health_check_duration_seconds` - Health check execution time
|
||||
- `dbmanager_reconnect_attempts_total` - Reconnection attempts and results
|
||||
- `dbmanager_connection_lifetime_closed_total` - Connections closed due to max lifetime
|
||||
- `dbmanager_connection_idle_closed_total` - Connections closed due to max idle time
|
||||
|
||||
Metrics are automatically updated during health checks. To manually publish metrics:
|
||||
|
||||
```go
|
||||
if mgr, ok := mgr.(*connectionManager); ok {
|
||||
mgr.PublishMetrics()
|
||||
}
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Single Connection Pool, Multiple ORMs
|
||||
|
||||
A key design principle is that Bun, GORM, and Native all wrap the **same underlying `*sql.DB`** connection pool:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ SQL Connection │
|
||||
├─────────────────────────────────────┤
|
||||
│ ┌─────────┐ ┌──────┐ ┌────────┐ │
|
||||
│ │ Bun │ │ GORM │ │ Native │ │
|
||||
│ └────┬────┘ └───┬──┘ └───┬────┘ │
|
||||
│ │ │ │ │
|
||||
│ └───────────┴─────────┘ │
|
||||
│ *sql.DB │
|
||||
│ (single pool) │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- No connection duplication
|
||||
- Consistent pool limits across all ORMs
|
||||
- Unified connection statistics
|
||||
- Lower resource usage
|
||||
|
||||
### Provider Pattern
|
||||
|
||||
Each database type has a dedicated provider:
|
||||
|
||||
- **PostgresProvider** - Uses `pgx` driver
|
||||
- **SQLiteProvider** - Uses `glebarez/sqlite` (pure Go)
|
||||
- **MSSQLProvider** - Uses `go-mssqldb`
|
||||
- **MongoProvider** - Uses official `mongo-driver`
|
||||
|
||||
Providers handle:
|
||||
- Connection establishment with retry logic
|
||||
- Health checking
|
||||
- Connection statistics
|
||||
- Connection cleanup
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Named Connections**: Be explicit about which database you're accessing
|
||||
```go
|
||||
primary, _ := mgr.Get("primary") // Good
|
||||
db, _ := mgr.GetDefaultDatabase() // Risky if default changes
|
||||
```
|
||||
|
||||
2. **Configure Connection Pools**: Tune based on your workload
|
||||
```yaml
|
||||
connections:
|
||||
primary:
|
||||
max_open_conns: 100 # High traffic API
|
||||
max_idle_conns: 25
|
||||
analytics:
|
||||
max_open_conns: 10 # Background analytics
|
||||
max_idle_conns: 2
|
||||
```
|
||||
|
||||
3. **Enable Health Checks**: Catch connection issues early
|
||||
```yaml
|
||||
health_check_interval: 30s
|
||||
enable_auto_reconnect: true
|
||||
```
|
||||
|
||||
4. **Use Appropriate ORM**: Choose based on your needs
|
||||
- **Bun**: Modern, fast, type-safe - recommended for new code
|
||||
- **GORM**: Mature, feature-rich - good for existing GORM code
|
||||
- **Native**: Maximum control - use for performance-critical queries
|
||||
|
||||
5. **Monitor Metrics**: Watch connection pool utilization
|
||||
- If `wait_count` is high, increase `max_open_conns`
|
||||
- If `idle` is always high, decrease `max_idle_conns`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Failures
|
||||
|
||||
If connections fail to establish:
|
||||
|
||||
1. Check configuration:
|
||||
```bash
|
||||
# Test connection manually
|
||||
psql -h localhost -U myuser -d myapp
|
||||
```
|
||||
|
||||
2. Enable logging:
|
||||
```yaml
|
||||
connections:
|
||||
primary:
|
||||
enable_logging: true
|
||||
```
|
||||
|
||||
3. Check retry attempts:
|
||||
```yaml
|
||||
retry_attempts: 5 # Increase retries
|
||||
retry_max_delay: 30s
|
||||
```
|
||||
|
||||
### Pool Exhaustion
|
||||
|
||||
If you see "too many connections" errors:
|
||||
|
||||
1. Increase pool size:
|
||||
```yaml
|
||||
max_open_conns: 50 # Increase from default 25
|
||||
```
|
||||
|
||||
2. Reduce connection lifetime:
|
||||
```yaml
|
||||
conn_max_lifetime: 15m # Recycle faster
|
||||
```
|
||||
|
||||
3. Monitor wait stats:
|
||||
```go
|
||||
stats := primary.Stats()
|
||||
if stats.WaitCount > 1000 {
|
||||
log.Warn("High connection wait count")
|
||||
}
|
||||
```
|
||||
|
||||
### MongoDB vs SQL Confusion
|
||||
|
||||
MongoDB connections don't support SQL ORMs:
|
||||
|
||||
```go
|
||||
docs, _ := mgr.Get("documents")
|
||||
|
||||
// ✓ Correct
|
||||
mongoClient, _ := docs.MongoDB()
|
||||
|
||||
// ✗ Error: ErrNotSQLDatabase
|
||||
bunDB, err := docs.Bun() // Won't work!
|
||||
```
|
||||
|
||||
SQL connections don't support MongoDB:
|
||||
|
||||
```go
|
||||
primary, _ := mgr.Get("primary")
|
||||
|
||||
// ✓ Correct
|
||||
bunDB, _ := primary.Bun()
|
||||
|
||||
// ✗ Error: ErrNotMongoDB
|
||||
mongoClient, err := primary.MongoDB() // Won't work!
|
||||
```
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Raw `database/sql`
|
||||
|
||||
Before:
|
||||
```go
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query("SELECT * FROM users")
|
||||
```
|
||||
|
||||
After:
|
||||
```go
|
||||
mgr, _ := dbmanager.NewManager(cfg.DBManager)
|
||||
mgr.Connect(ctx)
|
||||
defer mgr.Close()
|
||||
|
||||
primary, _ := mgr.Get("primary")
|
||||
nativeDB, _ := primary.Native()
|
||||
|
||||
rows, err := nativeDB.Query("SELECT * FROM users")
|
||||
```
|
||||
|
||||
### From Direct Bun/GORM
|
||||
|
||||
Before:
|
||||
```go
|
||||
sqldb, _ := sql.Open("pgx", dsn)
|
||||
bunDB := bun.NewDB(sqldb, pgdialect.New())
|
||||
|
||||
var users []User
|
||||
bunDB.NewSelect().Model(&users).Scan(ctx)
|
||||
```
|
||||
|
||||
After:
|
||||
```go
|
||||
mgr, _ := dbmanager.NewManager(cfg.DBManager)
|
||||
mgr.Connect(ctx)
|
||||
|
||||
primary, _ := mgr.Get("primary")
|
||||
bunDB, _ := primary.Bun()
|
||||
|
||||
var users []User
|
||||
bunDB.NewSelect().Model(&users).Scan(ctx)
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Same as the parent project.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please submit issues and pull requests to the main repository.
|
||||
448
pkg/dbmanager/config.go
Normal file
448
pkg/dbmanager/config.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
)
|
||||
|
||||
// DatabaseType represents the type of database
|
||||
type DatabaseType string
|
||||
|
||||
const (
|
||||
// DatabaseTypePostgreSQL represents PostgreSQL database
|
||||
DatabaseTypePostgreSQL DatabaseType = "postgres"
|
||||
|
||||
// DatabaseTypeSQLite represents SQLite database
|
||||
DatabaseTypeSQLite DatabaseType = "sqlite"
|
||||
|
||||
// DatabaseTypeMSSQL represents Microsoft SQL Server database
|
||||
DatabaseTypeMSSQL DatabaseType = "mssql"
|
||||
|
||||
// DatabaseTypeMongoDB represents MongoDB database
|
||||
DatabaseTypeMongoDB DatabaseType = "mongodb"
|
||||
)
|
||||
|
||||
// ORMType represents the ORM to use for database operations
|
||||
type ORMType string
|
||||
|
||||
const (
|
||||
// ORMTypeBun represents Bun ORM
|
||||
ORMTypeBun ORMType = "bun"
|
||||
|
||||
// ORMTypeGORM represents GORM
|
||||
ORMTypeGORM ORMType = "gorm"
|
||||
|
||||
// ORMTypeNative represents native database/sql
|
||||
ORMTypeNative ORMType = "native"
|
||||
)
|
||||
|
||||
// ManagerConfig contains configuration for the database connection manager
|
||||
type ManagerConfig struct {
|
||||
// DefaultConnection is the name of the default connection to use
|
||||
DefaultConnection string `mapstructure:"default_connection"`
|
||||
|
||||
// Connections is a map of connection name to connection configuration
|
||||
Connections map[string]ConnectionConfig `mapstructure:"connections"`
|
||||
|
||||
// Global connection pool defaults
|
||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
|
||||
ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"`
|
||||
|
||||
// Retry policy
|
||||
RetryAttempts int `mapstructure:"retry_attempts"`
|
||||
RetryDelay time.Duration `mapstructure:"retry_delay"`
|
||||
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
|
||||
|
||||
// Health checks
|
||||
HealthCheckInterval time.Duration `mapstructure:"health_check_interval"`
|
||||
EnableAutoReconnect bool `mapstructure:"enable_auto_reconnect"`
|
||||
}
|
||||
|
||||
// ConnectionConfig defines configuration for a single database connection
|
||||
type ConnectionConfig struct {
|
||||
// Name is the unique name of this connection
|
||||
Name string `mapstructure:"name"`
|
||||
|
||||
// Type is the database type (postgres, sqlite, mssql, mongodb)
|
||||
Type DatabaseType `mapstructure:"type"`
|
||||
|
||||
// DSN is the complete Data Source Name / connection string
|
||||
// If provided, this takes precedence over individual connection parameters
|
||||
DSN string `mapstructure:"dsn"`
|
||||
|
||||
// Connection parameters (used if DSN is not provided)
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
Database string `mapstructure:"database"`
|
||||
|
||||
// PostgreSQL/MSSQL specific
|
||||
SSLMode string `mapstructure:"sslmode"` // disable, require, verify-ca, verify-full
|
||||
Schema string `mapstructure:"schema"` // Default schema
|
||||
|
||||
// SQLite specific
|
||||
FilePath string `mapstructure:"filepath"`
|
||||
|
||||
// MongoDB specific
|
||||
AuthSource string `mapstructure:"auth_source"`
|
||||
ReplicaSet string `mapstructure:"replica_set"`
|
||||
ReadPreference string `mapstructure:"read_preference"` // primary, secondary, etc.
|
||||
|
||||
// Connection pool settings (overrides global defaults)
|
||||
MaxOpenConns *int `mapstructure:"max_open_conns"`
|
||||
MaxIdleConns *int `mapstructure:"max_idle_conns"`
|
||||
ConnMaxLifetime *time.Duration `mapstructure:"conn_max_lifetime"`
|
||||
ConnMaxIdleTime *time.Duration `mapstructure:"conn_max_idle_time"`
|
||||
|
||||
// Timeouts
|
||||
ConnectTimeout time.Duration `mapstructure:"connect_timeout"`
|
||||
QueryTimeout time.Duration `mapstructure:"query_timeout"`
|
||||
|
||||
// Features
|
||||
EnableTracing bool `mapstructure:"enable_tracing"`
|
||||
EnableMetrics bool `mapstructure:"enable_metrics"`
|
||||
EnableLogging bool `mapstructure:"enable_logging"`
|
||||
|
||||
// DefaultORM specifies which ORM to use for the Database() method
|
||||
// Options: "bun", "gorm", "native"
|
||||
DefaultORM string `mapstructure:"default_orm"`
|
||||
|
||||
// Tags for organization and filtering
|
||||
Tags map[string]string `mapstructure:"tags"`
|
||||
}
|
||||
|
||||
// DefaultManagerConfig returns a ManagerConfig with sensible defaults
|
||||
func DefaultManagerConfig() ManagerConfig {
|
||||
return ManagerConfig{
|
||||
DefaultConnection: "",
|
||||
Connections: make(map[string]ConnectionConfig),
|
||||
MaxOpenConns: 25,
|
||||
MaxIdleConns: 5,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
RetryAttempts: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
RetryMaxDelay: 10 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
EnableAutoReconnect: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyDefaults applies default values to the manager configuration
|
||||
func (c *ManagerConfig) ApplyDefaults() {
|
||||
defaults := DefaultManagerConfig()
|
||||
|
||||
if c.MaxOpenConns == 0 {
|
||||
c.MaxOpenConns = defaults.MaxOpenConns
|
||||
}
|
||||
if c.MaxIdleConns == 0 {
|
||||
c.MaxIdleConns = defaults.MaxIdleConns
|
||||
}
|
||||
if c.ConnMaxLifetime == 0 {
|
||||
c.ConnMaxLifetime = defaults.ConnMaxLifetime
|
||||
}
|
||||
if c.ConnMaxIdleTime == 0 {
|
||||
c.ConnMaxIdleTime = defaults.ConnMaxIdleTime
|
||||
}
|
||||
if c.RetryAttempts == 0 {
|
||||
c.RetryAttempts = defaults.RetryAttempts
|
||||
}
|
||||
if c.RetryDelay == 0 {
|
||||
c.RetryDelay = defaults.RetryDelay
|
||||
}
|
||||
if c.RetryMaxDelay == 0 {
|
||||
c.RetryMaxDelay = defaults.RetryMaxDelay
|
||||
}
|
||||
if c.HealthCheckInterval == 0 {
|
||||
c.HealthCheckInterval = defaults.HealthCheckInterval
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the manager configuration
|
||||
func (c *ManagerConfig) Validate() error {
|
||||
if len(c.Connections) == 0 {
|
||||
return NewConfigurationError("connections", fmt.Errorf("at least one connection must be configured"))
|
||||
}
|
||||
|
||||
if c.DefaultConnection != "" {
|
||||
if _, ok := c.Connections[c.DefaultConnection]; !ok {
|
||||
return NewConfigurationError("default_connection", fmt.Errorf("default connection '%s' not found in connections", c.DefaultConnection))
|
||||
}
|
||||
}
|
||||
|
||||
// Validate each connection
|
||||
for name := range c.Connections {
|
||||
conn := c.Connections[name]
|
||||
if err := conn.Validate(); err != nil {
|
||||
return fmt.Errorf("connection '%s': %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyDefaults applies default values and global settings to the connection configuration
|
||||
func (cc *ConnectionConfig) ApplyDefaults(global *ManagerConfig) {
|
||||
// Set name if not already set
|
||||
if cc.Name == "" {
|
||||
cc.Name = "unnamed"
|
||||
}
|
||||
|
||||
// Apply global pool settings if not overridden
|
||||
if cc.MaxOpenConns == nil && global != nil {
|
||||
maxOpen := global.MaxOpenConns
|
||||
cc.MaxOpenConns = &maxOpen
|
||||
}
|
||||
if cc.MaxIdleConns == nil && global != nil {
|
||||
maxIdle := global.MaxIdleConns
|
||||
cc.MaxIdleConns = &maxIdle
|
||||
}
|
||||
if cc.ConnMaxLifetime == nil && global != nil {
|
||||
lifetime := global.ConnMaxLifetime
|
||||
cc.ConnMaxLifetime = &lifetime
|
||||
}
|
||||
if cc.ConnMaxIdleTime == nil && global != nil {
|
||||
idleTime := global.ConnMaxIdleTime
|
||||
cc.ConnMaxIdleTime = &idleTime
|
||||
}
|
||||
|
||||
// Default timeouts
|
||||
if cc.ConnectTimeout == 0 {
|
||||
cc.ConnectTimeout = 10 * time.Second
|
||||
}
|
||||
if cc.QueryTimeout == 0 {
|
||||
cc.QueryTimeout = 30 * time.Second
|
||||
}
|
||||
|
||||
// Default ORM
|
||||
if cc.DefaultORM == "" {
|
||||
cc.DefaultORM = string(ORMTypeBun)
|
||||
}
|
||||
|
||||
// Default PostgreSQL port
|
||||
if cc.Type == DatabaseTypePostgreSQL && cc.Port == 0 && cc.DSN == "" {
|
||||
cc.Port = 5432
|
||||
}
|
||||
|
||||
// Default MSSQL port
|
||||
if cc.Type == DatabaseTypeMSSQL && cc.Port == 0 && cc.DSN == "" {
|
||||
cc.Port = 1433
|
||||
}
|
||||
|
||||
// Default MongoDB port
|
||||
if cc.Type == DatabaseTypeMongoDB && cc.Port == 0 && cc.DSN == "" {
|
||||
cc.Port = 27017
|
||||
}
|
||||
|
||||
// Default MongoDB auth source
|
||||
if cc.Type == DatabaseTypeMongoDB && cc.AuthSource == "" {
|
||||
cc.AuthSource = "admin"
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the connection configuration
|
||||
func (cc *ConnectionConfig) Validate() error {
|
||||
// Validate database type
|
||||
switch cc.Type {
|
||||
case DatabaseTypePostgreSQL, DatabaseTypeSQLite, DatabaseTypeMSSQL, DatabaseTypeMongoDB:
|
||||
// Valid types
|
||||
default:
|
||||
return NewConfigurationError("type", fmt.Errorf("unsupported database type: %s", cc.Type))
|
||||
}
|
||||
|
||||
// Validate that either DSN or connection parameters are provided
|
||||
if cc.DSN == "" {
|
||||
switch cc.Type {
|
||||
case DatabaseTypePostgreSQL, DatabaseTypeMSSQL, DatabaseTypeMongoDB:
|
||||
if cc.Host == "" {
|
||||
return NewConfigurationError("host", fmt.Errorf("host is required when DSN is not provided"))
|
||||
}
|
||||
if cc.Database == "" {
|
||||
return NewConfigurationError("database", fmt.Errorf("database is required when DSN is not provided"))
|
||||
}
|
||||
case DatabaseTypeSQLite:
|
||||
if cc.FilePath == "" {
|
||||
return NewConfigurationError("filepath", fmt.Errorf("filepath is required for SQLite when DSN is not provided"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate ORM type
|
||||
if cc.DefaultORM != "" {
|
||||
switch ORMType(cc.DefaultORM) {
|
||||
case ORMTypeBun, ORMTypeGORM, ORMTypeNative:
|
||||
// Valid ORM types
|
||||
default:
|
||||
return NewConfigurationError("default_orm", fmt.Errorf("unsupported ORM type: %s", cc.DefaultORM))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildDSN builds a connection string from individual parameters
|
||||
func (cc *ConnectionConfig) BuildDSN() (string, error) {
|
||||
// If DSN is already provided, use it
|
||||
if cc.DSN != "" {
|
||||
return cc.DSN, nil
|
||||
}
|
||||
|
||||
switch cc.Type {
|
||||
case DatabaseTypePostgreSQL:
|
||||
return cc.buildPostgresDSN(), nil
|
||||
case DatabaseTypeSQLite:
|
||||
return cc.buildSQLiteDSN(), nil
|
||||
case DatabaseTypeMSSQL:
|
||||
return cc.buildMSSQLDSN(), nil
|
||||
case DatabaseTypeMongoDB:
|
||||
return cc.buildMongoDSN(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("cannot build DSN for database type: %s", cc.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildPostgresDSN() string {
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
|
||||
cc.Host, cc.Port, cc.User, cc.Password, cc.Database)
|
||||
|
||||
if cc.SSLMode != "" {
|
||||
dsn += fmt.Sprintf(" sslmode=%s", cc.SSLMode)
|
||||
} else {
|
||||
dsn += " sslmode=disable"
|
||||
}
|
||||
|
||||
if cc.Schema != "" {
|
||||
dsn += fmt.Sprintf(" search_path=%s", cc.Schema)
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildSQLiteDSN() string {
|
||||
if cc.FilePath != "" {
|
||||
return cc.FilePath
|
||||
}
|
||||
return ":memory:"
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildMSSQLDSN() string {
|
||||
// Format: sqlserver://username:password@host:port?database=dbname
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s",
|
||||
cc.User, cc.Password, cc.Host, cc.Port, cc.Database)
|
||||
|
||||
if cc.Schema != "" {
|
||||
dsn += fmt.Sprintf("&schema=%s", cc.Schema)
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildMongoDSN() string {
|
||||
// Format: mongodb://username:password@host:port/database?authSource=admin
|
||||
var dsn string
|
||||
|
||||
if cc.User != "" && cc.Password != "" {
|
||||
dsn = fmt.Sprintf("mongodb://%s:%s@%s:%d/%s",
|
||||
cc.User, cc.Password, cc.Host, cc.Port, cc.Database)
|
||||
} else {
|
||||
dsn = fmt.Sprintf("mongodb://%s:%d/%s", cc.Host, cc.Port, cc.Database)
|
||||
}
|
||||
|
||||
params := ""
|
||||
if cc.AuthSource != "" {
|
||||
params += fmt.Sprintf("authSource=%s", cc.AuthSource)
|
||||
}
|
||||
if cc.ReplicaSet != "" {
|
||||
if params != "" {
|
||||
params += "&"
|
||||
}
|
||||
params += fmt.Sprintf("replicaSet=%s", cc.ReplicaSet)
|
||||
}
|
||||
if cc.ReadPreference != "" {
|
||||
if params != "" {
|
||||
params += "&"
|
||||
}
|
||||
params += fmt.Sprintf("readPreference=%s", cc.ReadPreference)
|
||||
}
|
||||
|
||||
if params != "" {
|
||||
dsn += "?" + params
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
// FromConfig converts config.DBManagerConfig to internal ManagerConfig
|
||||
func FromConfig(cfg config.DBManagerConfig) ManagerConfig {
|
||||
mgr := ManagerConfig{
|
||||
DefaultConnection: cfg.DefaultConnection,
|
||||
Connections: make(map[string]ConnectionConfig),
|
||||
MaxOpenConns: cfg.MaxOpenConns,
|
||||
MaxIdleConns: cfg.MaxIdleConns,
|
||||
ConnMaxLifetime: cfg.ConnMaxLifetime,
|
||||
ConnMaxIdleTime: cfg.ConnMaxIdleTime,
|
||||
RetryAttempts: cfg.RetryAttempts,
|
||||
RetryDelay: cfg.RetryDelay,
|
||||
RetryMaxDelay: cfg.RetryMaxDelay,
|
||||
HealthCheckInterval: cfg.HealthCheckInterval,
|
||||
EnableAutoReconnect: cfg.EnableAutoReconnect,
|
||||
}
|
||||
|
||||
// Convert connections
|
||||
for name := range cfg.Connections {
|
||||
connCfg := cfg.Connections[name]
|
||||
mgr.Connections[name] = ConnectionConfig{
|
||||
Name: connCfg.Name,
|
||||
Type: DatabaseType(connCfg.Type),
|
||||
DSN: connCfg.DSN,
|
||||
Host: connCfg.Host,
|
||||
Port: connCfg.Port,
|
||||
User: connCfg.User,
|
||||
Password: connCfg.Password,
|
||||
Database: connCfg.Database,
|
||||
SSLMode: connCfg.SSLMode,
|
||||
Schema: connCfg.Schema,
|
||||
FilePath: connCfg.FilePath,
|
||||
AuthSource: connCfg.AuthSource,
|
||||
ReplicaSet: connCfg.ReplicaSet,
|
||||
ReadPreference: connCfg.ReadPreference,
|
||||
MaxOpenConns: connCfg.MaxOpenConns,
|
||||
MaxIdleConns: connCfg.MaxIdleConns,
|
||||
ConnMaxLifetime: connCfg.ConnMaxLifetime,
|
||||
ConnMaxIdleTime: connCfg.ConnMaxIdleTime,
|
||||
ConnectTimeout: connCfg.ConnectTimeout,
|
||||
QueryTimeout: connCfg.QueryTimeout,
|
||||
EnableTracing: connCfg.EnableTracing,
|
||||
EnableMetrics: connCfg.EnableMetrics,
|
||||
EnableLogging: connCfg.EnableLogging,
|
||||
DefaultORM: connCfg.DefaultORM,
|
||||
Tags: connCfg.Tags,
|
||||
}
|
||||
}
|
||||
|
||||
return mgr
|
||||
}
|
||||
|
||||
// Getter methods to implement providers.ConnectionConfig interface
|
||||
func (cc *ConnectionConfig) GetName() string { return cc.Name }
|
||||
func (cc *ConnectionConfig) GetType() string { return string(cc.Type) }
|
||||
func (cc *ConnectionConfig) GetHost() string { return cc.Host }
|
||||
func (cc *ConnectionConfig) GetPort() int { return cc.Port }
|
||||
func (cc *ConnectionConfig) GetUser() string { return cc.User }
|
||||
func (cc *ConnectionConfig) GetPassword() string { return cc.Password }
|
||||
func (cc *ConnectionConfig) GetDatabase() string { return cc.Database }
|
||||
func (cc *ConnectionConfig) GetFilePath() string { return cc.FilePath }
|
||||
func (cc *ConnectionConfig) GetConnectTimeout() time.Duration { return cc.ConnectTimeout }
|
||||
func (cc *ConnectionConfig) GetEnableLogging() bool { return cc.EnableLogging }
|
||||
func (cc *ConnectionConfig) GetMaxOpenConns() *int { return cc.MaxOpenConns }
|
||||
func (cc *ConnectionConfig) GetMaxIdleConns() *int { return cc.MaxIdleConns }
|
||||
func (cc *ConnectionConfig) GetConnMaxLifetime() *time.Duration { return cc.ConnMaxLifetime }
|
||||
func (cc *ConnectionConfig) GetConnMaxIdleTime() *time.Duration { return cc.ConnMaxIdleTime }
|
||||
func (cc *ConnectionConfig) GetQueryTimeout() time.Duration { return cc.QueryTimeout }
|
||||
func (cc *ConnectionConfig) GetEnableMetrics() bool { return cc.EnableMetrics }
|
||||
func (cc *ConnectionConfig) GetReadPreference() string { return cc.ReadPreference }
|
||||
607
pkg/dbmanager/connection.go
Normal file
607
pkg/dbmanager/connection.go
Normal file
@@ -0,0 +1,607 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/schema"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
)
|
||||
|
||||
// Connection represents a single named database connection
|
||||
type Connection interface {
|
||||
// Metadata
|
||||
Name() string
|
||||
Type() DatabaseType
|
||||
|
||||
// ORM Access (SQL databases only)
|
||||
Bun() (*bun.DB, error)
|
||||
GORM() (*gorm.DB, error)
|
||||
Native() (*sql.DB, error)
|
||||
|
||||
// Common Database interface (for SQL databases)
|
||||
Database() (common.Database, error)
|
||||
|
||||
// MongoDB Access (MongoDB only)
|
||||
MongoDB() (*mongo.Client, error)
|
||||
|
||||
// Lifecycle
|
||||
Connect(ctx context.Context) error
|
||||
Close() error
|
||||
HealthCheck(ctx context.Context) error
|
||||
Reconnect(ctx context.Context) error
|
||||
|
||||
// Stats
|
||||
Stats() *ConnectionStats
|
||||
}
|
||||
|
||||
// ConnectionStats contains statistics about a database connection
|
||||
type ConnectionStats struct {
|
||||
Name string
|
||||
Type DatabaseType
|
||||
Connected bool
|
||||
LastHealthCheck time.Time
|
||||
HealthCheckStatus string
|
||||
|
||||
// SQL connection pool stats
|
||||
OpenConnections int
|
||||
InUse int
|
||||
Idle int
|
||||
WaitCount int64
|
||||
WaitDuration time.Duration
|
||||
MaxIdleClosed int64
|
||||
MaxLifetimeClosed int64
|
||||
}
|
||||
|
||||
// sqlConnection implements Connection for SQL databases (PostgreSQL, SQLite, MSSQL)
|
||||
type sqlConnection struct {
|
||||
name string
|
||||
dbType DatabaseType
|
||||
config ConnectionConfig
|
||||
provider Provider
|
||||
|
||||
// Lazy-initialized ORM instances (all wrap the same sql.DB)
|
||||
nativeDB *sql.DB
|
||||
bunDB *bun.DB
|
||||
gormDB *gorm.DB
|
||||
|
||||
// Adapters for common.Database interface
|
||||
bunAdapter *database.BunAdapter
|
||||
gormAdapter *database.GormAdapter
|
||||
nativeAdapter common.Database
|
||||
|
||||
// State
|
||||
connected bool
|
||||
mu sync.RWMutex
|
||||
|
||||
// Health check
|
||||
lastHealthCheck time.Time
|
||||
healthCheckStatus string
|
||||
}
|
||||
|
||||
// newSQLConnection creates a new SQL connection
|
||||
func newSQLConnection(name string, dbType DatabaseType, config ConnectionConfig, provider Provider) *sqlConnection {
|
||||
return &sqlConnection{
|
||||
name: name,
|
||||
dbType: dbType,
|
||||
config: config,
|
||||
provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the connection name
|
||||
func (c *sqlConnection) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns the database type
|
||||
func (c *sqlConnection) Type() DatabaseType {
|
||||
return c.dbType
|
||||
}
|
||||
|
||||
// Connect establishes the database connection
|
||||
func (c *sqlConnection) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.connected {
|
||||
return ErrAlreadyConnected
|
||||
}
|
||||
|
||||
if err := c.provider.Connect(ctx, &c.config); err != nil {
|
||||
return NewConnectionError(c.name, "connect", err)
|
||||
}
|
||||
|
||||
c.connected = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection and all ORM instances
|
||||
func (c *sqlConnection) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close Bun if initialized
|
||||
if c.bunDB != nil {
|
||||
if err := c.bunDB.Close(); err != nil {
|
||||
return NewConnectionError(c.name, "close bun", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GORM doesn't have a separate close - it uses the underlying sql.DB
|
||||
|
||||
// Close the provider (which closes the underlying sql.DB)
|
||||
if err := c.provider.Close(); err != nil {
|
||||
return NewConnectionError(c.name, "close", err)
|
||||
}
|
||||
|
||||
c.connected = false
|
||||
c.nativeDB = nil
|
||||
c.bunDB = nil
|
||||
c.gormDB = nil
|
||||
c.bunAdapter = nil
|
||||
c.gormAdapter = nil
|
||||
c.nativeAdapter = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck verifies the connection is alive
|
||||
func (c *sqlConnection) HealthCheck(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.lastHealthCheck = time.Now()
|
||||
|
||||
if !c.connected {
|
||||
c.healthCheckStatus = "disconnected"
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if err := c.provider.HealthCheck(ctx); err != nil {
|
||||
c.healthCheckStatus = "unhealthy: " + err.Error()
|
||||
return NewConnectionError(c.name, "health check", err)
|
||||
}
|
||||
|
||||
c.healthCheckStatus = "healthy"
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reconnect closes and re-establishes the connection
|
||||
func (c *sqlConnection) Reconnect(ctx context.Context) error {
|
||||
if err := c.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Connect(ctx)
|
||||
}
|
||||
|
||||
// Native returns the native *sql.DB connection
|
||||
func (c *sqlConnection) Native() (*sql.DB, error) {
|
||||
c.mu.RLock()
|
||||
if c.nativeDB != nil {
|
||||
defer c.mu.RUnlock()
|
||||
return c.nativeDB, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if c.nativeDB != nil {
|
||||
return c.nativeDB, nil
|
||||
}
|
||||
|
||||
if !c.connected {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
|
||||
// Get native connection from provider
|
||||
db, err := c.provider.GetNative()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "get native", err)
|
||||
}
|
||||
|
||||
c.nativeDB = db
|
||||
return c.nativeDB, nil
|
||||
}
|
||||
|
||||
// Bun returns a Bun ORM instance wrapping the native connection
|
||||
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||
c.mu.RLock()
|
||||
if c.bunDB != nil {
|
||||
defer c.mu.RUnlock()
|
||||
return c.bunDB, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if c.bunDB != nil {
|
||||
return c.bunDB, nil
|
||||
}
|
||||
|
||||
// Get native connection first
|
||||
native, err := c.provider.GetNative()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "get bun", err)
|
||||
}
|
||||
|
||||
// Create Bun DB wrapping the same sql.DB
|
||||
dialect := c.getBunDialect()
|
||||
c.bunDB = bun.NewDB(native, dialect)
|
||||
|
||||
return c.bunDB, nil
|
||||
}
|
||||
|
||||
// GORM returns a GORM instance wrapping the native connection
|
||||
func (c *sqlConnection) GORM() (*gorm.DB, error) {
|
||||
c.mu.RLock()
|
||||
if c.gormDB != nil {
|
||||
defer c.mu.RUnlock()
|
||||
return c.gormDB, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if c.gormDB != nil {
|
||||
return c.gormDB, nil
|
||||
}
|
||||
|
||||
// Get native connection first
|
||||
native, err := c.provider.GetNative()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "get gorm", err)
|
||||
}
|
||||
|
||||
// Create GORM DB wrapping the same sql.DB
|
||||
dialector := c.getGORMDialector(native)
|
||||
db, err := gorm.Open(dialector, &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(c.name, "initialize gorm", err)
|
||||
}
|
||||
|
||||
c.gormDB = db
|
||||
return c.gormDB, nil
|
||||
}
|
||||
|
||||
// Database returns the common.Database interface using the configured default ORM
|
||||
func (c *sqlConnection) Database() (common.Database, error) {
|
||||
c.mu.RLock()
|
||||
defaultORM := c.config.DefaultORM
|
||||
c.mu.RUnlock()
|
||||
|
||||
switch ORMType(defaultORM) {
|
||||
case ORMTypeBun:
|
||||
return c.getBunAdapter()
|
||||
case ORMTypeGORM:
|
||||
return c.getGORMAdapter()
|
||||
case ORMTypeNative:
|
||||
return c.getNativeAdapter()
|
||||
default:
|
||||
// Default to Bun
|
||||
return c.getBunAdapter()
|
||||
}
|
||||
}
|
||||
|
||||
// MongoDB returns an error for SQL connections
|
||||
func (c *sqlConnection) MongoDB() (*mongo.Client, error) {
|
||||
return nil, ErrNotMongoDB
|
||||
}
|
||||
|
||||
// Stats returns connection statistics
|
||||
func (c *sqlConnection) Stats() *ConnectionStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
stats := &ConnectionStats{
|
||||
Name: c.name,
|
||||
Type: c.dbType,
|
||||
Connected: c.connected,
|
||||
LastHealthCheck: c.lastHealthCheck,
|
||||
HealthCheckStatus: c.healthCheckStatus,
|
||||
}
|
||||
|
||||
// Get SQL stats if connected
|
||||
if c.connected && c.provider != nil {
|
||||
if providerStats := c.provider.Stats(); providerStats != nil {
|
||||
stats.OpenConnections = providerStats.OpenConnections
|
||||
stats.InUse = providerStats.InUse
|
||||
stats.Idle = providerStats.Idle
|
||||
stats.WaitCount = providerStats.WaitCount
|
||||
stats.WaitDuration = providerStats.WaitDuration
|
||||
stats.MaxIdleClosed = providerStats.MaxIdleClosed
|
||||
stats.MaxLifetimeClosed = providerStats.MaxLifetimeClosed
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// getBunAdapter returns or creates the Bun adapter
|
||||
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
c.mu.RLock()
|
||||
if c.bunAdapter != nil {
|
||||
defer c.mu.RUnlock()
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.bunAdapter != nil {
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
bunDB, err := c.Bun()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.bunAdapter = database.NewBunAdapter(bunDB)
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
// getGORMAdapter returns or creates the GORM adapter
|
||||
func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
||||
c.mu.RLock()
|
||||
if c.gormAdapter != nil {
|
||||
defer c.mu.RUnlock()
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.gormAdapter != nil {
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
gormDB, err := c.GORM()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.gormAdapter = database.NewGormAdapter(gormDB)
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
// getNativeAdapter returns or creates the native adapter
|
||||
func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
||||
c.mu.RLock()
|
||||
if c.nativeAdapter != nil {
|
||||
defer c.mu.RUnlock()
|
||||
return c.nativeAdapter, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.nativeAdapter != nil {
|
||||
return c.nativeAdapter, nil
|
||||
}
|
||||
|
||||
native, err := c.Native()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a native adapter based on database type
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(native)
|
||||
case DatabaseTypeSQLite:
|
||||
// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(native)
|
||||
case DatabaseTypeMSSQL:
|
||||
// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(native)
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabase
|
||||
}
|
||||
|
||||
return c.nativeAdapter, nil
|
||||
}
|
||||
|
||||
// getBunDialect returns the appropriate Bun dialect for the database type
|
||||
func (c *sqlConnection) getBunDialect() schema.Dialect {
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
return database.GetPostgresDialect()
|
||||
case DatabaseTypeSQLite:
|
||||
return database.GetSQLiteDialect()
|
||||
case DatabaseTypeMSSQL:
|
||||
return database.GetMSSQLDialect()
|
||||
default:
|
||||
// Default to PostgreSQL
|
||||
return database.GetPostgresDialect()
|
||||
}
|
||||
}
|
||||
|
||||
// getGORMDialector returns the appropriate GORM dialector for the database type
|
||||
func (c *sqlConnection) getGORMDialector(db *sql.DB) gorm.Dialector {
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
return database.GetPostgresDialector(db)
|
||||
case DatabaseTypeSQLite:
|
||||
return database.GetSQLiteDialector(db)
|
||||
case DatabaseTypeMSSQL:
|
||||
return database.GetMSSQLDialector(db)
|
||||
default:
|
||||
// Default to PostgreSQL
|
||||
return database.GetPostgresDialector(db)
|
||||
}
|
||||
}
|
||||
|
||||
// mongoConnection implements Connection for MongoDB
|
||||
type mongoConnection struct {
|
||||
name string
|
||||
config ConnectionConfig
|
||||
provider Provider
|
||||
|
||||
// MongoDB client
|
||||
client *mongo.Client
|
||||
|
||||
// State
|
||||
connected bool
|
||||
mu sync.RWMutex
|
||||
|
||||
// Health check
|
||||
lastHealthCheck time.Time
|
||||
healthCheckStatus string
|
||||
}
|
||||
|
||||
// newMongoConnection creates a new MongoDB connection
|
||||
func newMongoConnection(name string, config ConnectionConfig, provider Provider) *mongoConnection {
|
||||
return &mongoConnection{
|
||||
name: name,
|
||||
config: config,
|
||||
provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the connection name
|
||||
func (c *mongoConnection) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns the database type (MongoDB)
|
||||
func (c *mongoConnection) Type() DatabaseType {
|
||||
return DatabaseTypeMongoDB
|
||||
}
|
||||
|
||||
// Connect establishes the MongoDB connection
|
||||
func (c *mongoConnection) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.connected {
|
||||
return ErrAlreadyConnected
|
||||
}
|
||||
|
||||
if err := c.provider.Connect(ctx, &c.config); err != nil {
|
||||
return NewConnectionError(c.name, "connect", err)
|
||||
}
|
||||
|
||||
// Get the mongo client
|
||||
client, err := c.provider.GetMongo()
|
||||
if err != nil {
|
||||
return NewConnectionError(c.name, "get mongo client", err)
|
||||
}
|
||||
|
||||
c.client = client
|
||||
c.connected = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the MongoDB connection
|
||||
func (c *mongoConnection) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.provider.Close(); err != nil {
|
||||
return NewConnectionError(c.name, "close", err)
|
||||
}
|
||||
|
||||
c.connected = false
|
||||
c.client = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck verifies the MongoDB connection is alive
|
||||
func (c *mongoConnection) HealthCheck(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.lastHealthCheck = time.Now()
|
||||
|
||||
if !c.connected {
|
||||
c.healthCheckStatus = "disconnected"
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if err := c.provider.HealthCheck(ctx); err != nil {
|
||||
c.healthCheckStatus = "unhealthy: " + err.Error()
|
||||
return NewConnectionError(c.name, "health check", err)
|
||||
}
|
||||
|
||||
c.healthCheckStatus = "healthy"
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reconnect closes and re-establishes the MongoDB connection
|
||||
func (c *mongoConnection) Reconnect(ctx context.Context) error {
|
||||
if err := c.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Connect(ctx)
|
||||
}
|
||||
|
||||
// MongoDB returns the MongoDB client
|
||||
func (c *mongoConnection) MongoDB() (*mongo.Client, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if !c.connected || c.client == nil {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
|
||||
return c.client, nil
|
||||
}
|
||||
|
||||
// Bun returns an error for MongoDB connections
|
||||
func (c *mongoConnection) Bun() (*bun.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// GORM returns an error for MongoDB connections
|
||||
func (c *mongoConnection) GORM() (*gorm.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// Native returns an error for MongoDB connections
|
||||
func (c *mongoConnection) Native() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// Database returns an error for MongoDB connections
|
||||
func (c *mongoConnection) Database() (common.Database, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// Stats returns connection statistics for MongoDB
|
||||
func (c *mongoConnection) Stats() *ConnectionStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return &ConnectionStats{
|
||||
Name: c.name,
|
||||
Type: DatabaseTypeMongoDB,
|
||||
Connected: c.connected,
|
||||
LastHealthCheck: c.lastHealthCheck,
|
||||
HealthCheckStatus: c.healthCheckStatus,
|
||||
}
|
||||
}
|
||||
82
pkg/dbmanager/errors.go
Normal file
82
pkg/dbmanager/errors.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrConnectionNotFound is returned when a connection with the given name doesn't exist
|
||||
ErrConnectionNotFound = errors.New("connection not found")
|
||||
|
||||
// ErrInvalidConfiguration is returned when the configuration is invalid
|
||||
ErrInvalidConfiguration = errors.New("invalid configuration")
|
||||
|
||||
// ErrConnectionClosed is returned when attempting to use a closed connection
|
||||
ErrConnectionClosed = errors.New("connection is closed")
|
||||
|
||||
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||
ErrNotSQLDatabase = errors.New("not a SQL database")
|
||||
|
||||
// ErrNotMongoDB is returned when attempting MongoDB operations on a non-MongoDB connection
|
||||
ErrNotMongoDB = errors.New("not a MongoDB connection")
|
||||
|
||||
// ErrUnsupportedDatabase is returned when the database type is not supported
|
||||
ErrUnsupportedDatabase = errors.New("unsupported database type")
|
||||
|
||||
// ErrNoDefaultConnection is returned when no default connection is configured
|
||||
ErrNoDefaultConnection = errors.New("no default connection configured")
|
||||
|
||||
// ErrAlreadyConnected is returned when attempting to connect an already connected connection
|
||||
ErrAlreadyConnected = errors.New("already connected")
|
||||
)
|
||||
|
||||
// ConnectionError wraps errors that occur during connection operations
|
||||
type ConnectionError struct {
|
||||
Name string
|
||||
Operation string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ConnectionError) Error() string {
|
||||
return fmt.Sprintf("connection '%s' %s: %v", e.Name, e.Operation, e.Err)
|
||||
}
|
||||
|
||||
func (e *ConnectionError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// NewConnectionError creates a new ConnectionError
|
||||
func NewConnectionError(name, operation string, err error) *ConnectionError {
|
||||
return &ConnectionError{
|
||||
Name: name,
|
||||
Operation: operation,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigurationError wraps configuration-related errors
|
||||
type ConfigurationError struct {
|
||||
Field string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ConfigurationError) Error() string {
|
||||
if e.Field != "" {
|
||||
return fmt.Sprintf("configuration error in field '%s': %v", e.Field, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("configuration error: %v", e.Err)
|
||||
}
|
||||
|
||||
func (e *ConfigurationError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// NewConfigurationError creates a new ConfigurationError
|
||||
func NewConfigurationError(field string, err error) *ConfigurationError {
|
||||
return &ConfigurationError{
|
||||
Field: field,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
51
pkg/dbmanager/factory.go
Normal file
51
pkg/dbmanager/factory.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
)
|
||||
|
||||
// createConnection creates a database connection based on the configuration
|
||||
func createConnection(cfg ConnectionConfig) (Connection, error) {
|
||||
// Validate configuration
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid connection configuration: %w", err)
|
||||
}
|
||||
|
||||
// Create provider based on database type
|
||||
provider, err := createProvider(cfg.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create connection wrapper based on database type
|
||||
switch cfg.Type {
|
||||
case DatabaseTypePostgreSQL, DatabaseTypeSQLite, DatabaseTypeMSSQL:
|
||||
return newSQLConnection(cfg.Name, cfg.Type, cfg, provider), nil
|
||||
case DatabaseTypeMongoDB:
|
||||
return newMongoConnection(cfg.Name, cfg, provider), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s", ErrUnsupportedDatabase, cfg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// createProvider creates a database provider based on the database type
|
||||
func createProvider(dbType DatabaseType) (Provider, error) {
|
||||
switch dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
return providers.NewPostgresProvider(), nil
|
||||
case DatabaseTypeSQLite:
|
||||
return providers.NewSQLiteProvider(), nil
|
||||
case DatabaseTypeMSSQL:
|
||||
return providers.NewMSSQLProvider(), nil
|
||||
case DatabaseTypeMongoDB:
|
||||
return providers.NewMongoProvider(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s", ErrUnsupportedDatabase, dbType)
|
||||
}
|
||||
}
|
||||
|
||||
// Provider is an alias to the providers.Provider interface
|
||||
// This allows dbmanager package consumers to use Provider without importing providers
|
||||
type Provider = providers.Provider
|
||||
379
pkg/dbmanager/manager.go
Normal file
379
pkg/dbmanager/manager.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Manager manages multiple named database connections
|
||||
type Manager interface {
|
||||
// Connection retrieval
|
||||
Get(name string) (Connection, error)
|
||||
GetDefault() (Connection, error)
|
||||
GetAll() map[string]Connection
|
||||
|
||||
// Default database management
|
||||
GetDefaultDatabase() (common.Database, error)
|
||||
SetDefaultDatabase(name string) error
|
||||
|
||||
// Lifecycle
|
||||
Connect(ctx context.Context) error
|
||||
Close() error
|
||||
HealthCheck(ctx context.Context) error
|
||||
|
||||
// Stats
|
||||
Stats() *ManagerStats
|
||||
}
|
||||
|
||||
// ManagerStats contains statistics about the connection manager
|
||||
type ManagerStats struct {
|
||||
TotalConnections int
|
||||
HealthyCount int
|
||||
UnhealthyCount int
|
||||
ConnectionStats map[string]*ConnectionStats
|
||||
}
|
||||
|
||||
// connectionManager implements Manager
|
||||
type connectionManager struct {
|
||||
connections map[string]Connection
|
||||
config ManagerConfig
|
||||
mu sync.RWMutex
|
||||
|
||||
// Background health check
|
||||
healthTicker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
var (
|
||||
// singleton instance of the manager
|
||||
instance Manager
|
||||
// instanceMu protects the singleton instance
|
||||
instanceMu sync.RWMutex
|
||||
)
|
||||
|
||||
// SetupManager initializes the singleton database manager with the provided configuration.
|
||||
// This function must be called before GetInstance().
|
||||
// Returns an error if the manager is already initialized or if configuration is invalid.
|
||||
func SetupManager(cfg ManagerConfig) error {
|
||||
instanceMu.Lock()
|
||||
defer instanceMu.Unlock()
|
||||
|
||||
if instance != nil {
|
||||
return fmt.Errorf("manager already initialized")
|
||||
}
|
||||
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create manager: %w", err)
|
||||
}
|
||||
|
||||
instance = mgr
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInstance returns the singleton instance of the database manager.
|
||||
// Returns an error if SetupManager has not been called yet.
|
||||
func GetInstance() (Manager, error) {
|
||||
instanceMu.RLock()
|
||||
defer instanceMu.RUnlock()
|
||||
|
||||
if instance == nil {
|
||||
return nil, fmt.Errorf("manager not initialized: call SetupManager first")
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// ResetInstance resets the singleton instance (primarily for testing purposes).
|
||||
// WARNING: This should only be used in tests. Calling this in production code
|
||||
// while the manager is in use can lead to undefined behavior.
|
||||
func ResetInstance() {
|
||||
instanceMu.Lock()
|
||||
defer instanceMu.Unlock()
|
||||
|
||||
if instance != nil {
|
||||
_ = instance.Close()
|
||||
}
|
||||
instance = nil
|
||||
}
|
||||
|
||||
// NewManager creates a new database connection manager
|
||||
func NewManager(cfg ManagerConfig) (Manager, error) {
|
||||
// Apply defaults and validate configuration
|
||||
cfg.ApplyDefaults()
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: make(map[string]Connection),
|
||||
config: cfg,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// Get retrieves a named connection
|
||||
func (m *connectionManager) Get(name string) (Connection, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
conn, ok := m.connections[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: %s", ErrConnectionNotFound, name)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// GetDefault retrieves the default connection
|
||||
func (m *connectionManager) GetDefault() (Connection, error) {
|
||||
m.mu.RLock()
|
||||
defaultName := m.config.DefaultConnection
|
||||
m.mu.RUnlock()
|
||||
|
||||
if defaultName == "" {
|
||||
return nil, ErrNoDefaultConnection
|
||||
}
|
||||
|
||||
return m.Get(defaultName)
|
||||
}
|
||||
|
||||
// GetAll returns all connections
|
||||
func (m *connectionManager) GetAll() map[string]Connection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Create a copy to avoid concurrent access issues
|
||||
result := make(map[string]Connection, len(m.connections))
|
||||
for name, conn := range m.connections {
|
||||
result[name] = conn
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetDefaultDatabase returns the common.Database interface from the default connection
|
||||
func (m *connectionManager) GetDefaultDatabase() (common.Database, error) {
|
||||
conn, err := m.GetDefault()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get database from default connection: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// SetDefaultDatabase sets the default database connection by name
|
||||
func (m *connectionManager) SetDefaultDatabase(name string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Verify the connection exists
|
||||
if _, ok := m.connections[name]; !ok {
|
||||
return fmt.Errorf("%w: %s", ErrConnectionNotFound, name)
|
||||
}
|
||||
|
||||
m.config.DefaultConnection = name
|
||||
logger.Info("Default database connection changed: name=%s", name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect establishes all configured database connections
|
||||
func (m *connectionManager) Connect(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Create connections from configuration
|
||||
for name := range m.config.Connections {
|
||||
// Get a copy of the connection config
|
||||
connCfg := m.config.Connections[name]
|
||||
// Apply global defaults to connection config
|
||||
connCfg.ApplyDefaults(&m.config)
|
||||
connCfg.Name = name
|
||||
|
||||
// Create connection using factory
|
||||
conn, err := createConnection(connCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create connection '%s': %w", name, err)
|
||||
}
|
||||
|
||||
// Connect
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("failed to connect '%s': %w", name, err)
|
||||
}
|
||||
|
||||
m.connections[name] = conn
|
||||
logger.Info("Database connection established: name=%s, type=%s", name, connCfg.Type)
|
||||
}
|
||||
|
||||
// Start background health checks if enabled
|
||||
if m.config.EnableAutoReconnect && m.config.HealthCheckInterval > 0 {
|
||||
m.startHealthChecker()
|
||||
}
|
||||
|
||||
logger.Info("Database manager initialized: connections=%d", len(m.connections))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes all database connections
|
||||
func (m *connectionManager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Stop health checker
|
||||
m.stopHealthChecker()
|
||||
|
||||
// Close all connections
|
||||
var errors []error
|
||||
for name, conn := range m.connections {
|
||||
if err := conn.Close(); err != nil {
|
||||
errors = append(errors, fmt.Errorf("failed to close connection '%s': %w", name, err))
|
||||
logger.Error("Failed to close connection", "name", name, "error", err)
|
||||
} else {
|
||||
logger.Info("Connection closed: name=%s", name)
|
||||
}
|
||||
}
|
||||
|
||||
m.connections = make(map[string]Connection)
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("errors closing connections: %v", errors)
|
||||
}
|
||||
|
||||
logger.Info("Database manager closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck performs health checks on all connections
|
||||
func (m *connectionManager) HealthCheck(ctx context.Context) error {
|
||||
m.mu.RLock()
|
||||
connections := make(map[string]Connection, len(m.connections))
|
||||
for name, conn := range m.connections {
|
||||
connections[name] = conn
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
var errors []error
|
||||
for name, conn := range connections {
|
||||
if err := conn.HealthCheck(ctx); err != nil {
|
||||
errors = append(errors, fmt.Errorf("connection '%s': %w", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("health check failed for %d connections: %v", len(errors), errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns statistics for all connections
|
||||
func (m *connectionManager) Stats() *ManagerStats {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
stats := &ManagerStats{
|
||||
TotalConnections: len(m.connections),
|
||||
ConnectionStats: make(map[string]*ConnectionStats),
|
||||
}
|
||||
|
||||
for name, conn := range m.connections {
|
||||
connStats := conn.Stats()
|
||||
stats.ConnectionStats[name] = connStats
|
||||
|
||||
if connStats.Connected && connStats.HealthCheckStatus == "healthy" {
|
||||
stats.HealthyCount++
|
||||
} else {
|
||||
stats.UnhealthyCount++
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// startHealthChecker starts background health checking
|
||||
func (m *connectionManager) startHealthChecker() {
|
||||
if m.healthTicker != nil {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
m.healthTicker = time.NewTicker(m.config.HealthCheckInterval)
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
logger.Info("Health checker started: interval=%v", m.config.HealthCheckInterval)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.healthTicker.C:
|
||||
m.performHealthCheck()
|
||||
case <-m.stopChan:
|
||||
logger.Info("Health checker stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// stopHealthChecker stops background health checking
|
||||
func (m *connectionManager) stopHealthChecker() {
|
||||
if m.healthTicker != nil {
|
||||
m.healthTicker.Stop()
|
||||
close(m.stopChan)
|
||||
m.wg.Wait()
|
||||
m.healthTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck performs a health check on all connections
|
||||
func (m *connectionManager) performHealthCheck() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
m.mu.RLock()
|
||||
connections := make([]struct {
|
||||
name string
|
||||
conn Connection
|
||||
}, 0, len(m.connections))
|
||||
for name, conn := range m.connections {
|
||||
connections = append(connections, struct {
|
||||
name string
|
||||
conn Connection
|
||||
}{name, conn})
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
for _, item := range connections {
|
||||
if err := item.conn.HealthCheck(ctx); err != nil {
|
||||
logger.Warn("Health check failed",
|
||||
"connection", item.name,
|
||||
"error", err)
|
||||
|
||||
// Attempt reconnection if enabled
|
||||
if m.config.EnableAutoReconnect {
|
||||
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||
if err := item.conn.Reconnect(ctx); err != nil {
|
||||
logger.Error("Reconnection failed",
|
||||
"connection", item.name,
|
||||
"error", err)
|
||||
} else {
|
||||
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
136
pkg/dbmanager/metrics.go
Normal file
136
pkg/dbmanager/metrics.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
var (
|
||||
// connectionsTotal tracks the total number of configured database connections
|
||||
connectionsTotal = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "dbmanager_connections_total",
|
||||
Help: "Total number of configured database connections",
|
||||
},
|
||||
[]string{"type"},
|
||||
)
|
||||
|
||||
// connectionStatus tracks connection health status (1=healthy, 0=unhealthy)
|
||||
connectionStatus = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "dbmanager_connection_status",
|
||||
Help: "Connection status (1=healthy, 0=unhealthy)",
|
||||
},
|
||||
[]string{"name", "type"},
|
||||
)
|
||||
|
||||
// connectionPoolSize tracks connection pool sizes
|
||||
connectionPoolSize = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "dbmanager_connection_pool_size",
|
||||
Help: "Current connection pool size",
|
||||
},
|
||||
[]string{"name", "type", "state"}, // state: open, idle, in_use
|
||||
)
|
||||
|
||||
// connectionWaitCount tracks how many times connections had to wait for availability
|
||||
connectionWaitCount = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "dbmanager_connection_wait_count",
|
||||
Help: "Number of times connections had to wait for availability",
|
||||
},
|
||||
[]string{"name", "type"},
|
||||
)
|
||||
|
||||
// connectionWaitDuration tracks total time connections spent waiting
|
||||
connectionWaitDuration = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "dbmanager_connection_wait_duration_seconds",
|
||||
Help: "Total time connections spent waiting for availability",
|
||||
},
|
||||
[]string{"name", "type"},
|
||||
)
|
||||
|
||||
// reconnectAttempts tracks reconnection attempts and their outcomes
|
||||
reconnectAttempts = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "dbmanager_reconnect_attempts_total",
|
||||
Help: "Total number of reconnection attempts",
|
||||
},
|
||||
[]string{"name", "type", "result"}, // result: success, failure
|
||||
)
|
||||
|
||||
// connectionLifetimeClosed tracks connections closed due to max lifetime
|
||||
connectionLifetimeClosed = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "dbmanager_connection_lifetime_closed_total",
|
||||
Help: "Total connections closed due to exceeding max lifetime",
|
||||
},
|
||||
[]string{"name", "type"},
|
||||
)
|
||||
|
||||
// connectionIdleClosed tracks connections closed due to max idle time
|
||||
connectionIdleClosed = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "dbmanager_connection_idle_closed_total",
|
||||
Help: "Total connections closed due to exceeding max idle time",
|
||||
},
|
||||
[]string{"name", "type"},
|
||||
)
|
||||
)
|
||||
|
||||
// PublishMetrics publishes current metrics for all connections
|
||||
func (m *connectionManager) PublishMetrics() {
|
||||
stats := m.Stats()
|
||||
|
||||
// Count connections by type
|
||||
typeCount := make(map[DatabaseType]int)
|
||||
for _, connStats := range stats.ConnectionStats {
|
||||
typeCount[connStats.Type]++
|
||||
}
|
||||
|
||||
// Update total connections gauge
|
||||
for dbType, count := range typeCount {
|
||||
connectionsTotal.WithLabelValues(string(dbType)).Set(float64(count))
|
||||
}
|
||||
|
||||
// Update per-connection metrics
|
||||
for name, connStats := range stats.ConnectionStats {
|
||||
labels := prometheus.Labels{
|
||||
"name": name,
|
||||
"type": string(connStats.Type),
|
||||
}
|
||||
|
||||
// Connection status
|
||||
status := float64(0)
|
||||
if connStats.Connected && connStats.HealthCheckStatus == "healthy" {
|
||||
status = 1
|
||||
}
|
||||
connectionStatus.With(labels).Set(status)
|
||||
|
||||
// Pool size metrics (SQL databases only)
|
||||
if connStats.Type != DatabaseTypeMongoDB {
|
||||
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "open").Set(float64(connStats.OpenConnections))
|
||||
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "idle").Set(float64(connStats.Idle))
|
||||
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "in_use").Set(float64(connStats.InUse))
|
||||
|
||||
// Wait stats
|
||||
connectionWaitCount.With(labels).Set(float64(connStats.WaitCount))
|
||||
connectionWaitDuration.With(labels).Set(connStats.WaitDuration.Seconds())
|
||||
|
||||
// Lifetime/idle closed
|
||||
connectionLifetimeClosed.With(labels).Set(float64(connStats.MaxLifetimeClosed))
|
||||
connectionIdleClosed.With(labels).Set(float64(connStats.MaxIdleClosed))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordReconnectAttempt records a reconnection attempt
|
||||
func RecordReconnectAttempt(name string, dbType DatabaseType, success bool) {
|
||||
result := "failure"
|
||||
if success {
|
||||
result = "success"
|
||||
}
|
||||
|
||||
reconnectAttempts.WithLabelValues(name, string(dbType), result).Inc()
|
||||
}
|
||||
319
pkg/dbmanager/providers/POSTGRES_NOTIFY_LISTEN.md
Normal file
319
pkg/dbmanager/providers/POSTGRES_NOTIFY_LISTEN.md
Normal file
@@ -0,0 +1,319 @@
|
||||
# PostgreSQL NOTIFY/LISTEN Support
|
||||
|
||||
The `dbmanager` package provides built-in support for PostgreSQL's NOTIFY/LISTEN functionality through the `PostgresListener` type.
|
||||
|
||||
## Overview
|
||||
|
||||
PostgreSQL NOTIFY/LISTEN is a simple pub/sub mechanism that allows database clients to:
|
||||
- **LISTEN** on named channels to receive notifications
|
||||
- **NOTIFY** channels to send messages to all listeners
|
||||
- Receive asynchronous notifications without polling
|
||||
|
||||
## Features
|
||||
|
||||
- ✅ Subscribe to multiple channels simultaneously
|
||||
- ✅ Callback-based notification handling
|
||||
- ✅ Automatic reconnection on connection loss
|
||||
- ✅ Automatic resubscription after reconnection
|
||||
- ✅ Thread-safe operations
|
||||
- ✅ Panic recovery in notification handlers
|
||||
- ✅ Dedicated connection for listening (doesn't interfere with queries)
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create PostgreSQL provider
|
||||
cfg := &providers.Config{
|
||||
Name: "primary",
|
||||
Type: "postgres",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "password",
|
||||
Database: "myapp",
|
||||
}
|
||||
|
||||
provider := providers.NewPostgresProvider()
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Subscribe to a channel
|
||||
err = listener.Listen("events", func(channel, payload string) {
|
||||
fmt.Printf("Received on %s: %s\n", channel, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
err = listener.Notify(ctx, "events", "Hello, World!")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Keep the program running
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
```
|
||||
|
||||
### Multiple Channels
|
||||
|
||||
```go
|
||||
listener, _ := provider.GetListener(ctx)
|
||||
|
||||
// Listen to different channels with different handlers
|
||||
listener.Listen("user_events", func(channel, payload string) {
|
||||
fmt.Printf("User event: %s\n", payload)
|
||||
})
|
||||
|
||||
listener.Listen("order_events", func(channel, payload string) {
|
||||
fmt.Printf("Order event: %s\n", payload)
|
||||
})
|
||||
|
||||
listener.Listen("payment_events", func(channel, payload string) {
|
||||
fmt.Printf("Payment event: %s\n", payload)
|
||||
})
|
||||
```
|
||||
|
||||
### Unsubscribing
|
||||
|
||||
```go
|
||||
// Stop listening to a specific channel
|
||||
err := listener.Unlisten("user_events")
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to unlisten: %v\n", err)
|
||||
}
|
||||
```
|
||||
|
||||
### Checking Active Channels
|
||||
|
||||
```go
|
||||
// Get list of channels currently being listened to
|
||||
channels := listener.Channels()
|
||||
fmt.Printf("Listening to: %v\n", channels)
|
||||
```
|
||||
|
||||
### Checking Connection Status
|
||||
|
||||
```go
|
||||
if listener.IsConnected() {
|
||||
fmt.Println("Listener is connected")
|
||||
} else {
|
||||
fmt.Println("Listener is disconnected")
|
||||
}
|
||||
```
|
||||
|
||||
## Integration with DBManager
|
||||
|
||||
When using the DBManager, you can access the listener through the PostgreSQL provider:
|
||||
|
||||
```go
|
||||
// Initialize DBManager
|
||||
mgr, err := dbmanager.NewManager(dbmanager.FromConfig(cfg.DBManager))
|
||||
mgr.Connect(ctx)
|
||||
defer mgr.Close()
|
||||
|
||||
// Get PostgreSQL connection
|
||||
conn, err := mgr.Get("primary")
|
||||
|
||||
// Note: You'll need to cast to the underlying provider type
|
||||
// This requires exposing the provider through the Connection interface
|
||||
// or providing a helper method
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### Cache Invalidation
|
||||
|
||||
```go
|
||||
listener.Listen("cache_invalidation", func(channel, payload string) {
|
||||
// Parse the payload to determine what to invalidate
|
||||
cache.Invalidate(payload)
|
||||
})
|
||||
```
|
||||
|
||||
### Real-time Updates
|
||||
|
||||
```go
|
||||
listener.Listen("data_updates", func(channel, payload string) {
|
||||
// Broadcast update to WebSocket clients
|
||||
websocketBroadcast(payload)
|
||||
})
|
||||
```
|
||||
|
||||
### Configuration Reload
|
||||
|
||||
```go
|
||||
listener.Listen("config_reload", func(channel, payload string) {
|
||||
// Reload application configuration
|
||||
config.Reload()
|
||||
})
|
||||
```
|
||||
|
||||
### Distributed Locking
|
||||
|
||||
```go
|
||||
listener.Listen("lock_released", func(channel, payload string) {
|
||||
// Attempt to acquire the lock
|
||||
tryAcquireLock(payload)
|
||||
})
|
||||
```
|
||||
|
||||
## Automatic Reconnection
|
||||
|
||||
The listener automatically handles connection failures:
|
||||
|
||||
1. When a connection error is detected, the listener initiates reconnection
|
||||
2. Once reconnected, it automatically resubscribes to all previous channels
|
||||
3. Notification handlers remain active throughout the reconnection process
|
||||
|
||||
No manual intervention is required for reconnection.
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Handler Panics
|
||||
|
||||
If a notification handler panics, the panic is recovered and logged. The listener continues to function normally:
|
||||
|
||||
```go
|
||||
listener.Listen("events", func(channel, payload string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("Handler panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Your event processing logic
|
||||
processEvent(payload)
|
||||
})
|
||||
```
|
||||
|
||||
### Connection Errors
|
||||
|
||||
Connection errors trigger automatic reconnection. Check logs for reconnection events when `EnableLogging` is true.
|
||||
|
||||
## Thread Safety
|
||||
|
||||
All `PostgresListener` methods are thread-safe and can be called concurrently from multiple goroutines.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
1. **Dedicated Connection**: The listener uses a dedicated PostgreSQL connection separate from the query connection pool
|
||||
2. **Asynchronous Handlers**: Notification handlers run in separate goroutines to avoid blocking
|
||||
3. **Lightweight**: NOTIFY/LISTEN has minimal overhead compared to polling
|
||||
|
||||
## Comparison with Polling
|
||||
|
||||
| Feature | NOTIFY/LISTEN | Polling |
|
||||
|---------|---------------|---------|
|
||||
| Latency | Low (near real-time) | High (depends on poll interval) |
|
||||
| Database Load | Minimal | High (constant queries) |
|
||||
| Scalability | Excellent | Poor |
|
||||
| Complexity | Simple | Moderate |
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **PostgreSQL Only**: This feature is specific to PostgreSQL and not available for other databases
|
||||
2. **No Message Persistence**: Notifications are not stored; if no listener is connected, the message is lost
|
||||
3. **Payload Limit**: Notification payload is limited to 8000 bytes in PostgreSQL
|
||||
4. **No Guaranteed Delivery**: If a listener disconnects, in-flight notifications may be lost
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Keep Handlers Fast**: Notification handlers should be quick; for heavy processing, send work to a queue
|
||||
2. **Use JSON Payloads**: Encode structured data as JSON for easy parsing
|
||||
3. **Handle Errors Gracefully**: Always recover from panics in handlers
|
||||
4. **Close Properly**: Always close the provider to ensure the listener is properly shut down
|
||||
5. **Monitor Connection Status**: Use `IsConnected()` for health checks
|
||||
|
||||
## Example: Real-World Application
|
||||
|
||||
```go
|
||||
// Subscribe to various application events
|
||||
listener, _ := provider.GetListener(ctx)
|
||||
|
||||
// User registration events
|
||||
listener.Listen("user_registered", func(channel, payload string) {
|
||||
var event UserRegisteredEvent
|
||||
json.Unmarshal([]byte(payload), &event)
|
||||
|
||||
// Send welcome email
|
||||
sendWelcomeEmail(event.UserID)
|
||||
|
||||
// Invalidate user count cache
|
||||
cache.Delete("user_count")
|
||||
})
|
||||
|
||||
// Order placement events
|
||||
listener.Listen("order_placed", func(channel, payload string) {
|
||||
var event OrderPlacedEvent
|
||||
json.Unmarshal([]byte(payload), &event)
|
||||
|
||||
// Notify warehouse system
|
||||
warehouse.ProcessOrder(event.OrderID)
|
||||
|
||||
// Update inventory cache
|
||||
cache.Invalidate("inventory:" + event.ProductID)
|
||||
})
|
||||
|
||||
// Configuration changes
|
||||
listener.Listen("config_updated", func(channel, payload string) {
|
||||
// Reload configuration from database
|
||||
appConfig.Reload()
|
||||
})
|
||||
```
|
||||
|
||||
## Triggering Notifications from SQL
|
||||
|
||||
You can trigger notifications directly from PostgreSQL triggers or functions:
|
||||
|
||||
```sql
|
||||
-- Example trigger to notify on new user
|
||||
CREATE OR REPLACE FUNCTION notify_user_registered()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
PERFORM pg_notify('user_registered',
|
||||
json_build_object(
|
||||
'user_id', NEW.id,
|
||||
'email', NEW.email,
|
||||
'timestamp', NOW()
|
||||
)::text
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER user_registered_trigger
|
||||
AFTER INSERT ON users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION notify_user_registered();
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [PostgreSQL NOTIFY Documentation](https://www.postgresql.org/docs/current/sql-notify.html)
|
||||
- [PostgreSQL LISTEN Documentation](https://www.postgresql.org/docs/current/sql-listen.html)
|
||||
- [pgx Driver Documentation](https://github.com/jackc/pgx)
|
||||
214
pkg/dbmanager/providers/mongodb.go
Normal file
214
pkg/dbmanager/providers/mongodb.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// MongoProvider implements Provider for MongoDB databases
|
||||
type MongoProvider struct {
|
||||
client *mongo.Client
|
||||
config ConnectionConfig
|
||||
}
|
||||
|
||||
// NewMongoProvider creates a new MongoDB provider
|
||||
func NewMongoProvider() *MongoProvider {
|
||||
return &MongoProvider{}
|
||||
}
|
||||
|
||||
// Connect establishes a MongoDB connection
|
||||
func (p *MongoProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||
// Build DSN
|
||||
dsn, err := cfg.BuildDSN()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build DSN: %w", err)
|
||||
}
|
||||
|
||||
// Create client options
|
||||
clientOpts := options.Client().ApplyURI(dsn)
|
||||
|
||||
// Set connection pool size
|
||||
if cfg.GetMaxOpenConns() != nil {
|
||||
maxPoolSize := uint64(*cfg.GetMaxOpenConns())
|
||||
clientOpts.SetMaxPoolSize(maxPoolSize)
|
||||
}
|
||||
|
||||
if cfg.GetMaxIdleConns() != nil {
|
||||
minPoolSize := uint64(*cfg.GetMaxIdleConns())
|
||||
clientOpts.SetMinPoolSize(minPoolSize)
|
||||
}
|
||||
|
||||
// Set timeouts
|
||||
clientOpts.SetConnectTimeout(cfg.GetConnectTimeout())
|
||||
if cfg.GetQueryTimeout() > 0 {
|
||||
clientOpts.SetTimeout(cfg.GetQueryTimeout())
|
||||
}
|
||||
|
||||
// Set read preference if specified
|
||||
if cfg.GetReadPreference() != "" {
|
||||
rp, err := parseReadPreference(cfg.GetReadPreference())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid read preference: %w", err)
|
||||
}
|
||||
clientOpts.SetReadPreference(rp)
|
||||
}
|
||||
|
||||
// Connect with retry logic
|
||||
var client *mongo.Client
|
||||
var lastErr error
|
||||
|
||||
retryAttempts := 3
|
||||
retryDelay := 1 * time.Second
|
||||
|
||||
for attempt := 0; attempt < retryAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Info("Retrying MongoDB connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Create MongoDB client
|
||||
client, err = mongo.Connect(ctx, clientOpts)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to connect to MongoDB", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Ping the database to verify connection
|
||||
pingCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
|
||||
err = client.Ping(pingCtx, readpref.Primary())
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
_ = client.Disconnect(ctx)
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to ping MongoDB", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Connection successful
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
|
||||
}
|
||||
|
||||
p.client = client
|
||||
p.config = cfg
|
||||
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Info("MongoDB connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the MongoDB connection
|
||||
func (p *MongoProvider) Close() error {
|
||||
if p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := p.client.Disconnect(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close MongoDB connection: %w", err)
|
||||
}
|
||||
|
||||
if p.config.GetEnableLogging() {
|
||||
logger.Info("MongoDB connection closed: name=%s", p.config.GetName())
|
||||
}
|
||||
|
||||
p.client = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck verifies the MongoDB connection is alive
|
||||
func (p *MongoProvider) HealthCheck(ctx context.Context) error {
|
||||
if p.client == nil {
|
||||
return fmt.Errorf("MongoDB client is nil")
|
||||
}
|
||||
|
||||
// Use a short timeout for health checks
|
||||
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := p.client.Ping(healthCtx, readpref.Primary()); err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNative returns an error for MongoDB (not a SQL database)
|
||||
func (p *MongoProvider) GetNative() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// GetMongo returns the MongoDB client
|
||||
func (p *MongoProvider) GetMongo() (*mongo.Client, error) {
|
||||
if p.client == nil {
|
||||
return nil, fmt.Errorf("MongoDB client is not initialized")
|
||||
}
|
||||
return p.client, nil
|
||||
}
|
||||
|
||||
// Stats returns connection statistics for MongoDB
|
||||
func (p *MongoProvider) Stats() *ConnectionStats {
|
||||
if p.client == nil {
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "mongodb",
|
||||
Connected: false,
|
||||
}
|
||||
}
|
||||
|
||||
// MongoDB doesn't expose detailed connection pool stats like sql.DB
|
||||
// We return basic stats
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "mongodb",
|
||||
Connected: true,
|
||||
}
|
||||
}
|
||||
|
||||
// parseReadPreference parses a read preference string into a readpref.ReadPref
|
||||
func parseReadPreference(rp string) (*readpref.ReadPref, error) {
|
||||
switch rp {
|
||||
case "primary":
|
||||
return readpref.Primary(), nil
|
||||
case "primaryPreferred":
|
||||
return readpref.PrimaryPreferred(), nil
|
||||
case "secondary":
|
||||
return readpref.Secondary(), nil
|
||||
case "secondaryPreferred":
|
||||
return readpref.SecondaryPreferred(), nil
|
||||
case "nearest":
|
||||
return readpref.Nearest(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown read preference: %s", rp)
|
||||
}
|
||||
}
|
||||
184
pkg/dbmanager/providers/mssql.go
Normal file
184
pkg/dbmanager/providers/mssql.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/microsoft/go-mssqldb" // MSSQL driver
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// MSSQLProvider implements Provider for Microsoft SQL Server databases
|
||||
type MSSQLProvider struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
}
|
||||
|
||||
// NewMSSQLProvider creates a new MSSQL provider
|
||||
func NewMSSQLProvider() *MSSQLProvider {
|
||||
return &MSSQLProvider{}
|
||||
}
|
||||
|
||||
// Connect establishes a MSSQL connection
|
||||
func (p *MSSQLProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||
// Build DSN
|
||||
dsn, err := cfg.BuildDSN()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build DSN: %w", err)
|
||||
}
|
||||
|
||||
// Connect with retry logic
|
||||
var db *sql.DB
|
||||
var lastErr error
|
||||
|
||||
retryAttempts := 3 // Default retry attempts
|
||||
retryDelay := 1 * time.Second
|
||||
|
||||
for attempt := 0; attempt < retryAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Info("Retrying MSSQL connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err = sql.Open("sqlserver", dsn)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to open MSSQL connection", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Test the connection with context timeout
|
||||
connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
|
||||
err = db.PingContext(connectCtx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
db.Close()
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to ping MSSQL database", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Connection successful
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
if cfg.GetMaxOpenConns() != nil {
|
||||
db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
|
||||
}
|
||||
if cfg.GetMaxIdleConns() != nil {
|
||||
db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
|
||||
}
|
||||
if cfg.GetConnMaxLifetime() != nil {
|
||||
db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
|
||||
}
|
||||
if cfg.GetConnMaxIdleTime() != nil {
|
||||
db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
|
||||
}
|
||||
|
||||
p.db = db
|
||||
p.config = cfg
|
||||
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Info("MSSQL connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the MSSQL connection
|
||||
func (p *MSSQLProvider) Close() error {
|
||||
if p.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.db.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close MSSQL connection: %w", err)
|
||||
}
|
||||
|
||||
if p.config.GetEnableLogging() {
|
||||
logger.Info("MSSQL connection closed: name=%s", p.config.GetName())
|
||||
}
|
||||
|
||||
p.db = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck verifies the MSSQL connection is alive
|
||||
func (p *MSSQLProvider) HealthCheck(ctx context.Context) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
// Use a short timeout for health checks
|
||||
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := p.db.PingContext(healthCtx); err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNative returns the native *sql.DB connection
|
||||
func (p *MSSQLProvider) GetNative() (*sql.DB, error) {
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("database connection is not initialized")
|
||||
}
|
||||
return p.db, nil
|
||||
}
|
||||
|
||||
// GetMongo returns an error for MSSQL (not a MongoDB connection)
|
||||
func (p *MSSQLProvider) GetMongo() (*mongo.Client, error) {
|
||||
return nil, ErrNotMongoDB
|
||||
}
|
||||
|
||||
// Stats returns connection pool statistics
|
||||
func (p *MSSQLProvider) Stats() *ConnectionStats {
|
||||
if p.db == nil {
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "mssql",
|
||||
Connected: false,
|
||||
}
|
||||
}
|
||||
|
||||
stats := p.db.Stats()
|
||||
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "mssql",
|
||||
Connected: true,
|
||||
OpenConnections: stats.OpenConnections,
|
||||
InUse: stats.InUse,
|
||||
Idle: stats.Idle,
|
||||
WaitCount: stats.WaitCount,
|
||||
WaitDuration: stats.WaitDuration,
|
||||
MaxIdleClosed: stats.MaxIdleClosed,
|
||||
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||
}
|
||||
}
|
||||
231
pkg/dbmanager/providers/postgres.go
Normal file
231
pkg/dbmanager/providers/postgres.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// PostgresProvider implements Provider for PostgreSQL databases
|
||||
type PostgresProvider struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
listener *PostgresListener
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewPostgresProvider creates a new PostgreSQL provider
|
||||
func NewPostgresProvider() *PostgresProvider {
|
||||
return &PostgresProvider{}
|
||||
}
|
||||
|
||||
// Connect establishes a PostgreSQL connection
|
||||
func (p *PostgresProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||
// Build DSN
|
||||
dsn, err := cfg.BuildDSN()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build DSN: %w", err)
|
||||
}
|
||||
|
||||
// Connect with retry logic
|
||||
var db *sql.DB
|
||||
var lastErr error
|
||||
|
||||
retryAttempts := 3 // Default retry attempts
|
||||
retryDelay := 1 * time.Second
|
||||
|
||||
for attempt := 0; attempt < retryAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Info("Retrying PostgreSQL connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err = sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to open PostgreSQL connection", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Test the connection with context timeout
|
||||
connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
|
||||
err = db.PingContext(connectCtx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
db.Close()
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to ping PostgreSQL database", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Connection successful
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
if cfg.GetMaxOpenConns() != nil {
|
||||
db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
|
||||
}
|
||||
if cfg.GetMaxIdleConns() != nil {
|
||||
db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
|
||||
}
|
||||
if cfg.GetConnMaxLifetime() != nil {
|
||||
db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
|
||||
}
|
||||
if cfg.GetConnMaxIdleTime() != nil {
|
||||
db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
|
||||
}
|
||||
|
||||
p.db = db
|
||||
p.config = cfg
|
||||
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Info("PostgreSQL connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the PostgreSQL connection
|
||||
func (p *PostgresProvider) Close() error {
|
||||
// Close listener if it exists
|
||||
p.mu.Lock()
|
||||
if p.listener != nil {
|
||||
if err := p.listener.Close(); err != nil {
|
||||
p.mu.Unlock()
|
||||
return fmt.Errorf("failed to close listener: %w", err)
|
||||
}
|
||||
p.listener = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if p.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.db.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close PostgreSQL connection: %w", err)
|
||||
}
|
||||
|
||||
if p.config.GetEnableLogging() {
|
||||
logger.Info("PostgreSQL connection closed: name=%s", p.config.GetName())
|
||||
}
|
||||
|
||||
p.db = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck verifies the PostgreSQL connection is alive
|
||||
func (p *PostgresProvider) HealthCheck(ctx context.Context) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
// Use a short timeout for health checks
|
||||
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := p.db.PingContext(healthCtx); err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNative returns the native *sql.DB connection
|
||||
func (p *PostgresProvider) GetNative() (*sql.DB, error) {
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("database connection is not initialized")
|
||||
}
|
||||
return p.db, nil
|
||||
}
|
||||
|
||||
// GetMongo returns an error for PostgreSQL (not a MongoDB connection)
|
||||
func (p *PostgresProvider) GetMongo() (*mongo.Client, error) {
|
||||
return nil, ErrNotMongoDB
|
||||
}
|
||||
|
||||
// Stats returns connection pool statistics
|
||||
func (p *PostgresProvider) Stats() *ConnectionStats {
|
||||
if p.db == nil {
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "postgres",
|
||||
Connected: false,
|
||||
}
|
||||
}
|
||||
|
||||
stats := p.db.Stats()
|
||||
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "postgres",
|
||||
Connected: true,
|
||||
OpenConnections: stats.OpenConnections,
|
||||
InUse: stats.InUse,
|
||||
Idle: stats.Idle,
|
||||
WaitCount: stats.WaitCount,
|
||||
WaitDuration: stats.WaitDuration,
|
||||
MaxIdleClosed: stats.MaxIdleClosed,
|
||||
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||
}
|
||||
}
|
||||
|
||||
// GetListener returns a PostgreSQL listener for NOTIFY/LISTEN functionality
|
||||
// The listener is lazily initialized on first call and reused for subsequent calls
|
||||
func (p *PostgresProvider) GetListener(ctx context.Context) (*PostgresListener, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Return existing listener if already created
|
||||
if p.listener != nil {
|
||||
return p.listener, nil
|
||||
}
|
||||
|
||||
// Create new listener
|
||||
listener := NewPostgresListener(p.config)
|
||||
|
||||
// Connect the listener
|
||||
if err := listener.Connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect listener: %w", err)
|
||||
}
|
||||
|
||||
p.listener = listener
|
||||
return p.listener, nil
|
||||
}
|
||||
|
||||
// calculateBackoff calculates exponential backoff delay
|
||||
func calculateBackoff(attempt int, initial, maxDelay time.Duration) time.Duration {
|
||||
delay := initial * time.Duration(math.Pow(2, float64(attempt)))
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
401
pkg/dbmanager/providers/postgres_listener.go
Normal file
401
pkg/dbmanager/providers/postgres_listener.go
Normal file
@@ -0,0 +1,401 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// NotificationHandler is called when a notification is received
|
||||
type NotificationHandler func(channel string, payload string)
|
||||
|
||||
// PostgresListener manages PostgreSQL LISTEN/NOTIFY functionality
|
||||
type PostgresListener struct {
|
||||
config ConnectionConfig
|
||||
conn *pgx.Conn
|
||||
|
||||
// Channel subscriptions
|
||||
channels map[string]NotificationHandler
|
||||
mu sync.RWMutex
|
||||
|
||||
// Lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closed bool
|
||||
closeMu sync.Mutex
|
||||
reconnectC chan struct{}
|
||||
}
|
||||
|
||||
// NewPostgresListener creates a new PostgreSQL listener
|
||||
func NewPostgresListener(cfg ConnectionConfig) *PostgresListener {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &PostgresListener{
|
||||
config: cfg,
|
||||
channels: make(map[string]NotificationHandler),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
reconnectC: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a dedicated connection for listening
|
||||
func (l *PostgresListener) Connect(ctx context.Context) error {
|
||||
dsn, err := l.config.BuildDSN()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build DSN: %w", err)
|
||||
}
|
||||
|
||||
// Parse connection config
|
||||
connConfig, err := pgx.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse connection config: %w", err)
|
||||
}
|
||||
|
||||
// Connect with retry logic
|
||||
var conn *pgx.Conn
|
||||
var lastErr error
|
||||
|
||||
retryAttempts := 3
|
||||
retryDelay := 1 * time.Second
|
||||
|
||||
for attempt := 0; attempt < retryAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Info("Retrying PostgreSQL listener connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
conn, err = pgx.ConnectConfig(ctx, connConfig)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Warn("Failed to connect PostgreSQL listener", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Test the connection
|
||||
if err = conn.Ping(ctx); err != nil {
|
||||
lastErr = err
|
||||
conn.Close(ctx)
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Warn("Failed to ping PostgreSQL listener", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Connection successful
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect listener after %d attempts: %w", retryAttempts, lastErr)
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
l.conn = conn
|
||||
l.mu.Unlock()
|
||||
|
||||
// Start notification handler
|
||||
go l.handleNotifications()
|
||||
|
||||
// Start reconnection handler
|
||||
go l.handleReconnection()
|
||||
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Info("PostgreSQL listener connected: name=%s", l.config.GetName())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Listen subscribes to a PostgreSQL notification channel
|
||||
func (l *PostgresListener) Listen(channel string, handler NotificationHandler) error {
|
||||
l.closeMu.Lock()
|
||||
if l.closed {
|
||||
l.closeMu.Unlock()
|
||||
return fmt.Errorf("listener is closed")
|
||||
}
|
||||
l.closeMu.Unlock()
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if l.conn == nil {
|
||||
return fmt.Errorf("listener connection is not initialized")
|
||||
}
|
||||
|
||||
// Execute LISTEN command
|
||||
_, err := l.conn.Exec(l.ctx, fmt.Sprintf("LISTEN %s", pgx.Identifier{channel}.Sanitize()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on channel %s: %w", channel, err)
|
||||
}
|
||||
|
||||
// Store the handler
|
||||
l.channels[channel] = handler
|
||||
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Info("Listening on channel: name=%s, channel=%s", l.config.GetName(), channel)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlisten unsubscribes from a PostgreSQL notification channel
|
||||
func (l *PostgresListener) Unlisten(channel string) error {
|
||||
l.closeMu.Lock()
|
||||
if l.closed {
|
||||
l.closeMu.Unlock()
|
||||
return fmt.Errorf("listener is closed")
|
||||
}
|
||||
l.closeMu.Unlock()
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if l.conn == nil {
|
||||
return fmt.Errorf("listener connection is not initialized")
|
||||
}
|
||||
|
||||
// Execute UNLISTEN command
|
||||
_, err := l.conn.Exec(l.ctx, fmt.Sprintf("UNLISTEN %s", pgx.Identifier{channel}.Sanitize()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlisten from channel %s: %w", channel, err)
|
||||
}
|
||||
|
||||
// Remove the handler
|
||||
delete(l.channels, channel)
|
||||
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Info("Unlistened from channel: name=%s, channel=%s", l.config.GetName(), channel)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Notify sends a notification to a PostgreSQL channel
|
||||
func (l *PostgresListener) Notify(ctx context.Context, channel string, payload string) error {
|
||||
l.closeMu.Lock()
|
||||
if l.closed {
|
||||
l.closeMu.Unlock()
|
||||
return fmt.Errorf("listener is closed")
|
||||
}
|
||||
l.closeMu.Unlock()
|
||||
|
||||
l.mu.RLock()
|
||||
conn := l.conn
|
||||
l.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("listener connection is not initialized")
|
||||
}
|
||||
|
||||
// Execute NOTIFY command
|
||||
_, err := conn.Exec(ctx, "SELECT pg_notify($1, $2)", channel, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to notify channel %s: %w", channel, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the listener and all subscriptions
|
||||
func (l *PostgresListener) Close() error {
|
||||
l.closeMu.Lock()
|
||||
if l.closed {
|
||||
l.closeMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
l.closed = true
|
||||
l.closeMu.Unlock()
|
||||
|
||||
// Cancel context to stop background goroutines
|
||||
l.cancel()
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if l.conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlisten from all channels
|
||||
for channel := range l.channels {
|
||||
_, _ = l.conn.Exec(context.Background(), fmt.Sprintf("UNLISTEN %s", pgx.Identifier{channel}.Sanitize()))
|
||||
}
|
||||
|
||||
// Close connection
|
||||
err := l.conn.Close(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close listener connection: %w", err)
|
||||
}
|
||||
|
||||
l.conn = nil
|
||||
l.channels = make(map[string]NotificationHandler)
|
||||
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Info("PostgreSQL listener closed: name=%s", l.config.GetName())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNotifications processes incoming notifications
|
||||
func (l *PostgresListener) handleNotifications() {
|
||||
for {
|
||||
select {
|
||||
case <-l.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
l.mu.RLock()
|
||||
conn := l.conn
|
||||
l.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
// Connection not available, wait for reconnection
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Wait for notification with timeout
|
||||
ctx, cancel := context.WithTimeout(l.ctx, 5*time.Second)
|
||||
notification, err := conn.WaitForNotification(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
// Check if context was cancelled
|
||||
if l.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's a connection error
|
||||
if pgconn.Timeout(err) {
|
||||
// Timeout is normal, continue waiting
|
||||
continue
|
||||
}
|
||||
|
||||
// Connection error, trigger reconnection
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Warn("Notification error, triggering reconnection", "error", err)
|
||||
}
|
||||
select {
|
||||
case l.reconnectC <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process notification
|
||||
l.mu.RLock()
|
||||
handler, exists := l.channels[notification.Channel]
|
||||
l.mu.RUnlock()
|
||||
|
||||
if exists && handler != nil {
|
||||
// Call handler in a goroutine to avoid blocking
|
||||
go func(ch, payload string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Error("Notification handler panic: channel=%s, error=%v", ch, r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
handler(ch, payload)
|
||||
}(notification.Channel, notification.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleReconnection manages automatic reconnection
|
||||
func (l *PostgresListener) handleReconnection() {
|
||||
for {
|
||||
select {
|
||||
case <-l.ctx.Done():
|
||||
return
|
||||
case <-l.reconnectC:
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Info("Attempting to reconnect listener: name=%s", l.config.GetName())
|
||||
}
|
||||
|
||||
// Close existing connection
|
||||
l.mu.Lock()
|
||||
if l.conn != nil {
|
||||
l.conn.Close(context.Background())
|
||||
l.conn = nil
|
||||
}
|
||||
|
||||
// Save current subscriptions
|
||||
channels := make(map[string]NotificationHandler)
|
||||
for ch, handler := range l.channels {
|
||||
channels[ch] = handler
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
// Attempt reconnection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
err := l.Connect(ctx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Error("Failed to reconnect listener: name=%s, error=%v", l.config.GetName(), err)
|
||||
}
|
||||
// Retry after delay
|
||||
time.Sleep(5 * time.Second)
|
||||
select {
|
||||
case l.reconnectC <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Resubscribe to all channels
|
||||
for channel, handler := range channels {
|
||||
if err := l.Listen(channel, handler); err != nil {
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Error("Failed to resubscribe to channel: name=%s, channel=%s, error=%v", l.config.GetName(), channel, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if l.config.GetEnableLogging() {
|
||||
logger.Info("Listener reconnected successfully: name=%s", l.config.GetName())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected returns true if the listener is connected
|
||||
func (l *PostgresListener) IsConnected() bool {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
return l.conn != nil
|
||||
}
|
||||
|
||||
// Channels returns the list of channels currently being listened to
|
||||
func (l *PostgresListener) Channels() []string {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
channels := make([]string, 0, len(l.channels))
|
||||
for ch := range l.channels {
|
||||
channels = append(channels, ch)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
228
pkg/dbmanager/providers/postgres_listener_example_test.go
Normal file
228
pkg/dbmanager/providers/postgres_listener_example_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package providers_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
)
|
||||
|
||||
// ExamplePostgresListener_basic demonstrates basic LISTEN/NOTIFY usage
|
||||
func ExamplePostgresListener_basic() {
|
||||
// Create a connection config
|
||||
cfg := &dbmanager.ConnectionConfig{
|
||||
Name: "example",
|
||||
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "password",
|
||||
Database: "testdb",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
EnableLogging: true,
|
||||
}
|
||||
|
||||
// Create and connect PostgreSQL provider
|
||||
provider := providers.NewPostgresProvider()
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
}
|
||||
|
||||
// Subscribe to a channel with a handler
|
||||
err = listener.Listen("user_events", func(channel, payload string) {
|
||||
fmt.Printf("Received notification on %s: %s\n", channel, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen: %v", err))
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to notify: %v", err))
|
||||
}
|
||||
|
||||
// Wait for notification to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Unsubscribe from the channel
|
||||
if err := listener.Unlisten("user_events"); err != nil {
|
||||
panic(fmt.Sprintf("Failed to unlisten: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
// ExamplePostgresListener_multipleChannels demonstrates listening to multiple channels
|
||||
func ExamplePostgresListener_multipleChannels() {
|
||||
cfg := &dbmanager.ConnectionConfig{
|
||||
Name: "example",
|
||||
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "password",
|
||||
Database: "testdb",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
EnableLogging: false,
|
||||
}
|
||||
|
||||
provider := providers.NewPostgresProvider()
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
}
|
||||
|
||||
// Listen to multiple channels
|
||||
channels := []string{"orders", "payments", "notifications"}
|
||||
for _, ch := range channels {
|
||||
channel := ch // Capture for closure
|
||||
err := listener.Listen(channel, func(ch, payload string) {
|
||||
fmt.Printf("[%s] %s\n", ch, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
|
||||
}
|
||||
}
|
||||
|
||||
// Send notifications to different channels
|
||||
listener.Notify(ctx, "orders", "New order #12345")
|
||||
listener.Notify(ctx, "payments", "Payment received $99.99")
|
||||
listener.Notify(ctx, "notifications", "Welcome email sent")
|
||||
|
||||
// Wait for notifications
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Check active channels
|
||||
activeChannels := listener.Channels()
|
||||
fmt.Printf("Listening to %d channels: %v\n", len(activeChannels), activeChannels)
|
||||
}
|
||||
|
||||
// ExamplePostgresListener_withDBManager demonstrates usage with DBManager
|
||||
func ExamplePostgresListener_withDBManager() {
|
||||
// This example shows how to use the listener with the full DBManager
|
||||
|
||||
// Assume we have a DBManager instance and get a connection
|
||||
// conn, _ := dbMgr.Get("primary")
|
||||
|
||||
// Get the underlying provider (this would need to be exposed via the Connection interface)
|
||||
// For now, this is a conceptual example
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create provider directly for demonstration
|
||||
cfg := &dbmanager.ConnectionConfig{
|
||||
Name: "primary",
|
||||
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "password",
|
||||
Database: "myapp",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
provider := providers.NewPostgresProvider()
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Subscribe to application events
|
||||
listener.Listen("cache_invalidation", func(channel, payload string) {
|
||||
fmt.Printf("Cache invalidation request: %s\n", payload)
|
||||
// Handle cache invalidation logic here
|
||||
})
|
||||
|
||||
listener.Listen("config_reload", func(channel, payload string) {
|
||||
fmt.Printf("Configuration reload request: %s\n", payload)
|
||||
// Handle configuration reload logic here
|
||||
})
|
||||
|
||||
// Simulate receiving notifications
|
||||
listener.Notify(ctx, "cache_invalidation", "user:123")
|
||||
listener.Notify(ctx, "config_reload", "database")
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// ExamplePostgresListener_errorHandling demonstrates error handling and reconnection
|
||||
func ExamplePostgresListener_errorHandling() {
|
||||
cfg := &dbmanager.ConnectionConfig{
|
||||
Name: "example",
|
||||
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "postgres",
|
||||
Password: "password",
|
||||
Database: "testdb",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
EnableLogging: true,
|
||||
}
|
||||
|
||||
provider := providers.NewPostgresProvider()
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
}
|
||||
|
||||
// The listener automatically reconnects if the connection is lost
|
||||
// Subscribe with error handling in the callback
|
||||
err = listener.Listen("critical_events", func(channel, payload string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
fmt.Printf("Handler panic recovered: %v\n", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Process the event
|
||||
fmt.Printf("Processing critical event: %s\n", payload)
|
||||
|
||||
// If processing fails, the panic will be caught by the defer above
|
||||
// The listener will continue to function normally
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to listen: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if listener is connected
|
||||
if listener.IsConnected() {
|
||||
fmt.Println("Listener is connected and ready")
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
listener.Notify(ctx, "critical_events", "system_alert")
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
83
pkg/dbmanager/providers/provider.go
Normal file
83
pkg/dbmanager/providers/provider.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||
ErrNotSQLDatabase = errors.New("not a SQL database")
|
||||
|
||||
// ErrNotMongoDB is returned when attempting MongoDB operations on a non-MongoDB connection
|
||||
ErrNotMongoDB = errors.New("not a MongoDB connection")
|
||||
)
|
||||
|
||||
// ConnectionStats contains statistics about a database connection
|
||||
type ConnectionStats struct {
|
||||
Name string
|
||||
Type string // Database type as string to avoid circular dependency
|
||||
Connected bool
|
||||
LastHealthCheck time.Time
|
||||
HealthCheckStatus string
|
||||
|
||||
// SQL connection pool stats
|
||||
OpenConnections int
|
||||
InUse int
|
||||
Idle int
|
||||
WaitCount int64
|
||||
WaitDuration time.Duration
|
||||
MaxIdleClosed int64
|
||||
MaxLifetimeClosed int64
|
||||
}
|
||||
|
||||
// ConnectionConfig is a minimal interface for configuration
|
||||
// The actual implementation is in dbmanager package
|
||||
type ConnectionConfig interface {
|
||||
BuildDSN() (string, error)
|
||||
GetName() string
|
||||
GetType() string
|
||||
GetHost() string
|
||||
GetPort() int
|
||||
GetUser() string
|
||||
GetPassword() string
|
||||
GetDatabase() string
|
||||
GetFilePath() string
|
||||
GetConnectTimeout() time.Duration
|
||||
GetQueryTimeout() time.Duration
|
||||
GetEnableLogging() bool
|
||||
GetEnableMetrics() bool
|
||||
GetMaxOpenConns() *int
|
||||
GetMaxIdleConns() *int
|
||||
GetConnMaxLifetime() *time.Duration
|
||||
GetConnMaxIdleTime() *time.Duration
|
||||
GetReadPreference() string
|
||||
}
|
||||
|
||||
// Provider creates and manages the underlying database connection
|
||||
type Provider interface {
|
||||
// Connect establishes the database connection
|
||||
Connect(ctx context.Context, cfg ConnectionConfig) error
|
||||
|
||||
// Close closes the connection
|
||||
Close() error
|
||||
|
||||
// HealthCheck verifies the connection is alive
|
||||
HealthCheck(ctx context.Context) error
|
||||
|
||||
// GetNative returns the native *sql.DB (SQL databases only)
|
||||
// Returns an error for non-SQL databases
|
||||
GetNative() (*sql.DB, error)
|
||||
|
||||
// GetMongo returns the MongoDB client (MongoDB only)
|
||||
// Returns an error for non-MongoDB databases
|
||||
GetMongo() (*mongo.Client, error)
|
||||
|
||||
// Stats returns connection statistics
|
||||
Stats() *ConnectionStats
|
||||
}
|
||||
177
pkg/dbmanager/providers/sqlite.go
Normal file
177
pkg/dbmanager/providers/sqlite.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// SQLiteProvider implements Provider for SQLite databases
|
||||
type SQLiteProvider struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
}
|
||||
|
||||
// NewSQLiteProvider creates a new SQLite provider
|
||||
func NewSQLiteProvider() *SQLiteProvider {
|
||||
return &SQLiteProvider{}
|
||||
}
|
||||
|
||||
// Connect establishes a SQLite connection
|
||||
func (p *SQLiteProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||
// Build DSN
|
||||
dsn, err := cfg.BuildDSN()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build DSN: %w", err)
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open SQLite connection: %w", err)
|
||||
}
|
||||
|
||||
// Test the connection with context timeout
|
||||
connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
|
||||
err = db.PingContext(connectCtx)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to ping SQLite database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
// Note: SQLite works best with MaxOpenConns=1 for write operations
|
||||
// but can handle multiple readers
|
||||
if cfg.GetMaxOpenConns() != nil {
|
||||
db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
|
||||
} else {
|
||||
// Default to 1 for SQLite to avoid "database is locked" errors
|
||||
db.SetMaxOpenConns(1)
|
||||
}
|
||||
|
||||
if cfg.GetMaxIdleConns() != nil {
|
||||
db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
|
||||
}
|
||||
if cfg.GetConnMaxLifetime() != nil {
|
||||
db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
|
||||
}
|
||||
if cfg.GetConnMaxIdleTime() != nil {
|
||||
db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
|
||||
}
|
||||
|
||||
// Enable WAL mode for better concurrent access
|
||||
_, err = db.ExecContext(ctx, "PRAGMA journal_mode=WAL")
|
||||
if err != nil {
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to enable WAL mode for SQLite", "error", err)
|
||||
}
|
||||
// Don't fail connection if WAL mode cannot be enabled
|
||||
}
|
||||
|
||||
// Set busy timeout to handle locked database
|
||||
_, err = db.ExecContext(ctx, "PRAGMA busy_timeout=5000")
|
||||
if err != nil {
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to set busy timeout for SQLite", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
p.db = db
|
||||
p.config = cfg
|
||||
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Info("SQLite connection established: name=%s, filepath=%s", cfg.GetName(), cfg.GetFilePath())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the SQLite connection
|
||||
func (p *SQLiteProvider) Close() error {
|
||||
if p.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.db.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close SQLite connection: %w", err)
|
||||
}
|
||||
|
||||
if p.config.GetEnableLogging() {
|
||||
logger.Info("SQLite connection closed: name=%s", p.config.GetName())
|
||||
}
|
||||
|
||||
p.db = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck verifies the SQLite connection is alive
|
||||
func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
// Use a short timeout for health checks
|
||||
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Execute a simple query to verify the database is accessible
|
||||
var result int
|
||||
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
|
||||
if result != 1 {
|
||||
return fmt.Errorf("health check returned unexpected result: %d", result)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNative returns the native *sql.DB connection
|
||||
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("database connection is not initialized")
|
||||
}
|
||||
return p.db, nil
|
||||
}
|
||||
|
||||
// GetMongo returns an error for SQLite (not a MongoDB connection)
|
||||
func (p *SQLiteProvider) GetMongo() (*mongo.Client, error) {
|
||||
return nil, ErrNotMongoDB
|
||||
}
|
||||
|
||||
// Stats returns connection pool statistics
|
||||
func (p *SQLiteProvider) Stats() *ConnectionStats {
|
||||
if p.db == nil {
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "sqlite",
|
||||
Connected: false,
|
||||
}
|
||||
}
|
||||
|
||||
stats := p.db.Stats()
|
||||
|
||||
return &ConnectionStats{
|
||||
Name: p.config.GetName(),
|
||||
Type: "sqlite",
|
||||
Connected: true,
|
||||
OpenConnections: stats.OpenConnections,
|
||||
InUse: stats.InUse,
|
||||
Idle: stats.Idle,
|
||||
WaitCount: stats.WaitCount,
|
||||
WaitDuration: stats.WaitDuration,
|
||||
MaxIdleClosed: stats.MaxIdleClosed,
|
||||
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||
}
|
||||
}
|
||||
150
pkg/errortracking/README.md
Normal file
150
pkg/errortracking/README.md
Normal file
@@ -0,0 +1,150 @@
|
||||
# Error Tracking
|
||||
|
||||
This package provides error tracking integration for ResolveSpec, with built-in support for Sentry.
|
||||
|
||||
## Features
|
||||
|
||||
- **Provider Interface**: Flexible design supporting multiple error tracking backends
|
||||
- **Sentry Integration**: Full-featured Sentry support with automatic error, warning, and panic tracking
|
||||
- **Automatic Logger Integration**: All `logger.Error()` and `logger.Warn()` calls are automatically sent to the error tracker
|
||||
- **Panic Tracking**: Automatic panic capture with stack traces
|
||||
- **NoOp Provider**: Zero-overhead when error tracking is disabled
|
||||
|
||||
## Configuration
|
||||
|
||||
Add error tracking configuration to your config file:
|
||||
|
||||
```yaml
|
||||
error_tracking:
|
||||
enabled: true
|
||||
provider: "sentry" # Currently supports: "sentry" or "noop"
|
||||
dsn: "https://your-sentry-dsn@sentry.io/project-id"
|
||||
environment: "production" # e.g., production, staging, development
|
||||
release: "v1.0.0" # Your application version
|
||||
debug: false
|
||||
sample_rate: 1.0 # Error sample rate (0.0-1.0)
|
||||
traces_sample_rate: 0.1 # Traces sample rate (0.0-1.0)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Initialization
|
||||
|
||||
Initialize error tracking in your application startup:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load your configuration
|
||||
cfg := config.Config{
|
||||
ErrorTracking: config.ErrorTrackingConfig{
|
||||
Enabled: true,
|
||||
Provider: "sentry",
|
||||
DSN: "https://your-sentry-dsn@sentry.io/project-id",
|
||||
Environment: "production",
|
||||
Release: "v1.0.0",
|
||||
SampleRate: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
logger.Init(false)
|
||||
|
||||
// Initialize error tracking
|
||||
provider, err := errortracking.NewProviderFromConfig(cfg.ErrorTracking)
|
||||
if err != nil {
|
||||
logger.Error("Failed to initialize error tracking: %v", err)
|
||||
} else {
|
||||
logger.InitErrorTracking(provider)
|
||||
}
|
||||
|
||||
// Your application code...
|
||||
|
||||
// Cleanup on shutdown
|
||||
defer logger.CloseErrorTracking()
|
||||
}
|
||||
```
|
||||
|
||||
### Automatic Tracking
|
||||
|
||||
Once initialized, all logger errors and warnings are automatically sent to the error tracker:
|
||||
|
||||
```go
|
||||
// This will be logged AND sent to Sentry
|
||||
logger.Error("Database connection failed: %v", err)
|
||||
|
||||
// This will also be logged AND sent to Sentry
|
||||
logger.Warn("Cache miss for key: %s", key)
|
||||
```
|
||||
|
||||
### Panic Tracking
|
||||
|
||||
Panics are automatically captured when using the logger's panic handlers:
|
||||
|
||||
```go
|
||||
// Using CatchPanic
|
||||
defer logger.CatchPanic("MyFunction")()
|
||||
|
||||
// Using CatchPanicCallback
|
||||
defer logger.CatchPanicCallback("MyFunction", func(err any) {
|
||||
// Custom cleanup
|
||||
})()
|
||||
|
||||
// Using HandlePanic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("MyMethod", r)
|
||||
}
|
||||
}()
|
||||
```
|
||||
|
||||
### Manual Tracking
|
||||
|
||||
You can also use the provider directly for custom error tracking:
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
func someFunction() {
|
||||
tracker := logger.GetErrorTracker()
|
||||
if tracker != nil {
|
||||
// Capture an error
|
||||
tracker.CaptureError(context.Background(), err, errortracking.SeverityError, map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"request_id": requestID,
|
||||
})
|
||||
|
||||
// Capture a message
|
||||
tracker.CaptureMessage(context.Background(), "Important event occurred", errortracking.SeverityInfo, map[string]interface{}{
|
||||
"event_type": "user_signup",
|
||||
})
|
||||
|
||||
// Capture a panic
|
||||
tracker.CapturePanic(context.Background(), recovered, stackTrace, map[string]interface{}{
|
||||
"context": "background_job",
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Severity Levels
|
||||
|
||||
The package supports the following severity levels:
|
||||
|
||||
- `SeverityError`: For errors that should be tracked and investigated
|
||||
- `SeverityWarning`: For warnings that may indicate potential issues
|
||||
- `SeverityInfo`: For informational messages
|
||||
- `SeverityDebug`: For debug-level information
|
||||
|
||||
```
|
||||
67
pkg/errortracking/errortracking_test.go
Normal file
67
pkg/errortracking/errortracking_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package errortracking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNoOpProvider(t *testing.T) {
|
||||
provider := NewNoOpProvider()
|
||||
|
||||
// Test that all methods can be called without panicking
|
||||
t.Run("CaptureError", func(t *testing.T) {
|
||||
provider.CaptureError(context.Background(), errors.New("test error"), SeverityError, nil)
|
||||
})
|
||||
|
||||
t.Run("CaptureMessage", func(t *testing.T) {
|
||||
provider.CaptureMessage(context.Background(), "test message", SeverityWarning, nil)
|
||||
})
|
||||
|
||||
t.Run("CapturePanic", func(t *testing.T) {
|
||||
provider.CapturePanic(context.Background(), "panic!", []byte("stack trace"), nil)
|
||||
})
|
||||
|
||||
t.Run("Flush", func(t *testing.T) {
|
||||
result := provider.Flush(5)
|
||||
if !result {
|
||||
t.Error("Expected Flush to return true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close", func(t *testing.T) {
|
||||
err := provider.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Close to return nil, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSeverityLevels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
severity Severity
|
||||
expected string
|
||||
}{
|
||||
{"Error", SeverityError, "error"},
|
||||
{"Warning", SeverityWarning, "warning"},
|
||||
{"Info", SeverityInfo, "info"},
|
||||
{"Debug", SeverityDebug, "debug"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if string(tt.severity) != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, string(tt.severity))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderInterface(t *testing.T) {
|
||||
// Test that NoOpProvider implements Provider interface
|
||||
var _ Provider = (*NoOpProvider)(nil)
|
||||
|
||||
// Test that SentryProvider implements Provider interface
|
||||
var _ Provider = (*SentryProvider)(nil)
|
||||
}
|
||||
33
pkg/errortracking/factory.go
Normal file
33
pkg/errortracking/factory.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package errortracking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
)
|
||||
|
||||
// NewProviderFromConfig creates an error tracking provider based on the configuration
|
||||
func NewProviderFromConfig(cfg config.ErrorTrackingConfig) (Provider, error) {
|
||||
if !cfg.Enabled {
|
||||
return NewNoOpProvider(), nil
|
||||
}
|
||||
|
||||
switch cfg.Provider {
|
||||
case "sentry":
|
||||
if cfg.DSN == "" {
|
||||
return nil, fmt.Errorf("sentry DSN is required when error tracking is enabled")
|
||||
}
|
||||
return NewSentryProvider(SentryConfig{
|
||||
DSN: cfg.DSN,
|
||||
Environment: cfg.Environment,
|
||||
Release: cfg.Release,
|
||||
Debug: cfg.Debug,
|
||||
SampleRate: cfg.SampleRate,
|
||||
TracesSampleRate: cfg.TracesSampleRate,
|
||||
})
|
||||
case "noop", "":
|
||||
return NewNoOpProvider(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown error tracking provider: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
33
pkg/errortracking/interfaces.go
Normal file
33
pkg/errortracking/interfaces.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package errortracking
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Severity represents the severity level of an error
|
||||
type Severity string
|
||||
|
||||
const (
|
||||
SeverityError Severity = "error"
|
||||
SeverityWarning Severity = "warning"
|
||||
SeverityInfo Severity = "info"
|
||||
SeverityDebug Severity = "debug"
|
||||
)
|
||||
|
||||
// Provider defines the interface for error tracking providers
|
||||
type Provider interface {
|
||||
// CaptureError captures an error with the given severity and additional context
|
||||
CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{})
|
||||
|
||||
// CaptureMessage captures a message with the given severity and additional context
|
||||
CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{})
|
||||
|
||||
// CapturePanic captures a panic with stack trace
|
||||
CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{})
|
||||
|
||||
// Flush waits for all events to be sent (useful for graceful shutdown)
|
||||
Flush(timeout int) bool
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
Close() error
|
||||
}
|
||||
37
pkg/errortracking/noop.go
Normal file
37
pkg/errortracking/noop.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package errortracking
|
||||
|
||||
import "context"
|
||||
|
||||
// NoOpProvider is a no-op implementation of the Provider interface
|
||||
// Used when error tracking is disabled
|
||||
type NoOpProvider struct{}
|
||||
|
||||
// NewNoOpProvider creates a new NoOp provider
|
||||
func NewNoOpProvider() *NoOpProvider {
|
||||
return &NoOpProvider{}
|
||||
}
|
||||
|
||||
// CaptureError does nothing
|
||||
func (n *NoOpProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
|
||||
// No-op
|
||||
}
|
||||
|
||||
// CaptureMessage does nothing
|
||||
func (n *NoOpProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
|
||||
// No-op
|
||||
}
|
||||
|
||||
// CapturePanic does nothing
|
||||
func (n *NoOpProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
|
||||
// No-op
|
||||
}
|
||||
|
||||
// Flush does nothing and returns true
|
||||
func (n *NoOpProvider) Flush(timeout int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Close does nothing
|
||||
func (n *NoOpProvider) Close() error {
|
||||
return nil
|
||||
}
|
||||
154
pkg/errortracking/sentry.go
Normal file
154
pkg/errortracking/sentry.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package errortracking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
)
|
||||
|
||||
// SentryProvider implements the Provider interface using Sentry
|
||||
type SentryProvider struct {
|
||||
hub *sentry.Hub
|
||||
}
|
||||
|
||||
// SentryConfig holds the configuration for Sentry
|
||||
type SentryConfig struct {
|
||||
DSN string
|
||||
Environment string
|
||||
Release string
|
||||
Debug bool
|
||||
SampleRate float64
|
||||
TracesSampleRate float64
|
||||
}
|
||||
|
||||
// NewSentryProvider creates a new Sentry provider
|
||||
func NewSentryProvider(config SentryConfig) (*SentryProvider, error) {
|
||||
err := sentry.Init(sentry.ClientOptions{
|
||||
Dsn: config.DSN,
|
||||
Environment: config.Environment,
|
||||
Release: config.Release,
|
||||
Debug: config.Debug,
|
||||
AttachStacktrace: true,
|
||||
SampleRate: config.SampleRate,
|
||||
TracesSampleRate: config.TracesSampleRate,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize Sentry: %w", err)
|
||||
}
|
||||
|
||||
return &SentryProvider{
|
||||
hub: sentry.CurrentHub(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CaptureError captures an error with the given severity and additional context
|
||||
func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hub := sentry.GetHubFromContext(ctx)
|
||||
if hub == nil {
|
||||
hub = s.hub
|
||||
}
|
||||
|
||||
event := sentry.NewEvent()
|
||||
event.Level = s.convertSeverity(severity)
|
||||
event.Message = err.Error()
|
||||
event.Exception = []sentry.Exception{
|
||||
{
|
||||
Value: err.Error(),
|
||||
Type: fmt.Sprintf("%T", err),
|
||||
Stacktrace: sentry.ExtractStacktrace(err),
|
||||
},
|
||||
}
|
||||
|
||||
if extra != nil {
|
||||
event.Extra = extra
|
||||
}
|
||||
|
||||
hub.CaptureEvent(event)
|
||||
}
|
||||
|
||||
// CaptureMessage captures a message with the given severity and additional context
|
||||
func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
|
||||
if message == "" {
|
||||
return
|
||||
}
|
||||
|
||||
hub := sentry.GetHubFromContext(ctx)
|
||||
if hub == nil {
|
||||
hub = s.hub
|
||||
}
|
||||
|
||||
event := sentry.NewEvent()
|
||||
event.Level = s.convertSeverity(severity)
|
||||
event.Message = message
|
||||
|
||||
if extra != nil {
|
||||
event.Extra = extra
|
||||
}
|
||||
|
||||
hub.CaptureEvent(event)
|
||||
}
|
||||
|
||||
// CapturePanic captures a panic with stack trace
|
||||
func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hub := sentry.GetHubFromContext(ctx)
|
||||
if hub == nil {
|
||||
hub = s.hub
|
||||
}
|
||||
|
||||
event := sentry.NewEvent()
|
||||
event.Level = sentry.LevelError
|
||||
event.Message = fmt.Sprintf("Panic: %v", recovered)
|
||||
event.Exception = []sentry.Exception{
|
||||
{
|
||||
Value: fmt.Sprintf("%v", recovered),
|
||||
Type: "panic",
|
||||
},
|
||||
}
|
||||
|
||||
if extra != nil {
|
||||
event.Extra = extra
|
||||
}
|
||||
|
||||
if stackTrace != nil {
|
||||
event.Extra["stack_trace"] = string(stackTrace)
|
||||
}
|
||||
|
||||
hub.CaptureEvent(event)
|
||||
}
|
||||
|
||||
// Flush waits for all events to be sent (useful for graceful shutdown)
|
||||
func (s *SentryProvider) Flush(timeout int) bool {
|
||||
return sentry.Flush(time.Duration(timeout) * time.Second)
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (s *SentryProvider) Close() error {
|
||||
sentry.Flush(2 * time.Second)
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertSeverity converts our Severity to Sentry's Level
|
||||
func (s *SentryProvider) convertSeverity(severity Severity) sentry.Level {
|
||||
switch severity {
|
||||
case SeverityError:
|
||||
return sentry.LevelError
|
||||
case SeverityWarning:
|
||||
return sentry.LevelWarning
|
||||
case SeverityInfo:
|
||||
return sentry.LevelInfo
|
||||
case SeverityDebug:
|
||||
return sentry.LevelDebug
|
||||
default:
|
||||
return sentry.LevelError
|
||||
}
|
||||
}
|
||||
353
pkg/eventbroker/IMPLEMENTATION_PLAN.md
Normal file
353
pkg/eventbroker/IMPLEMENTATION_PLAN.md
Normal file
@@ -0,0 +1,353 @@
|
||||
# Event Broker System Implementation Plan
|
||||
|
||||
## Overview
|
||||
Implement a comprehensive event handler/broker system for ResolveSpec that follows existing architectural patterns (Provider interface, Hook system, Config management, Graceful shutdown).
|
||||
|
||||
## Requirements Met
|
||||
- ✅ Events with sources (database, websocket, frontend, system)
|
||||
- ✅ Event statuses (pending, processing, completed, failed)
|
||||
- ✅ Timestamps, JSON payloads, user IDs, session IDs
|
||||
- ✅ Program instance IDs for tracking server instances
|
||||
- ✅ Both sync and async processing modes
|
||||
- ✅ Multiple provider backends (in-memory, Redis, NATS, database)
|
||||
- ✅ Cross-instance pub/sub support
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
**Event Structure** (with full metadata):
|
||||
```go
|
||||
type Event struct {
|
||||
ID string // UUID
|
||||
Source EventSource // database, websocket, system, frontend
|
||||
Type string // Pattern: schema.entity.operation
|
||||
Status EventStatus // pending, processing, completed, failed
|
||||
Payload json.RawMessage // JSON payload
|
||||
UserID int
|
||||
SessionID string
|
||||
InstanceID string // Server instance identifier
|
||||
Schema string
|
||||
Entity string
|
||||
Operation string // create, update, delete, read
|
||||
CreatedAt time.Time
|
||||
ProcessedAt *time.Time
|
||||
CompletedAt *time.Time
|
||||
Error string
|
||||
Metadata map[string]interface{}
|
||||
RetryCount int
|
||||
}
|
||||
```
|
||||
|
||||
**Provider Pattern** (like cache.Provider):
|
||||
```go
|
||||
type Provider interface {
|
||||
Store(ctx context.Context, event *Event) error
|
||||
Get(ctx context.Context, id string) (*Event, error)
|
||||
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
|
||||
UpdateStatus(ctx context.Context, id string, status EventStatus, error string) error
|
||||
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
|
||||
Publish(ctx context.Context, event *Event) error
|
||||
Close() error
|
||||
Stats(ctx context.Context) (*ProviderStats, error)
|
||||
}
|
||||
```
|
||||
|
||||
**Broker Interface**:
|
||||
```go
|
||||
type Broker interface {
|
||||
Publish(ctx context.Context, event *Event) error // Mode-dependent
|
||||
PublishSync(ctx context.Context, event *Event) error // Blocks
|
||||
PublishAsync(ctx context.Context, event *Event) error // Non-blocking
|
||||
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
|
||||
Unsubscribe(id SubscriptionID) error
|
||||
Start(ctx context.Context) error
|
||||
Stop(ctx context.Context) error
|
||||
Stats(ctx context.Context) (*BrokerStats, error)
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### Phase 1: Core Foundation (Files: 1-5)
|
||||
|
||||
**1. Create `pkg/eventbroker/event.go`**
|
||||
- Event struct with all required fields (status, timestamps, user, instance ID, etc.)
|
||||
- EventSource enum (database, websocket, frontend, system, internal)
|
||||
- EventStatus enum (pending, processing, completed, failed)
|
||||
- Helper: `EventType(schema, entity, operation string) string`
|
||||
- Helper: `NewEvent()` constructor with UUID generation
|
||||
|
||||
**2. Create `pkg/eventbroker/provider.go`**
|
||||
- Provider interface definition
|
||||
- EventFilter struct for queries
|
||||
- ProviderStats struct
|
||||
|
||||
**3. Create `pkg/eventbroker/handler.go`**
|
||||
- EventHandler interface
|
||||
- EventHandlerFunc adapter type
|
||||
|
||||
**4. Create `pkg/eventbroker/broker.go`**
|
||||
- Broker interface definition
|
||||
- EventBroker struct implementation
|
||||
- ProcessingMode enum (sync, async)
|
||||
- Options struct with functional options (WithProvider, WithMode, WithWorkerCount, etc.)
|
||||
- NewBroker() constructor
|
||||
- Sync processing implementation
|
||||
|
||||
**5. Create `pkg/eventbroker/subscription.go`**
|
||||
- Pattern matching using glob syntax (e.g., "public.users.*", "*.*.create")
|
||||
- subscriptionManager struct
|
||||
- SubscriptionID type
|
||||
- Subscribe/Unsubscribe logic
|
||||
|
||||
### Phase 2: Configuration & Integration (Files: 6-8)
|
||||
|
||||
**6. Create `pkg/eventbroker/config.go`**
|
||||
- EventBrokerConfig struct
|
||||
- RedisConfig, NATSConfig, DatabaseConfig structs
|
||||
- RetryPolicyConfig struct
|
||||
|
||||
**7. Update `pkg/config/config.go`**
|
||||
- Add `EventBroker EventBrokerConfig` field to Config struct
|
||||
|
||||
**8. Update `pkg/config/manager.go`**
|
||||
- Add event broker defaults to `setDefaults()`:
|
||||
```go
|
||||
v.SetDefault("event_broker.enabled", false)
|
||||
v.SetDefault("event_broker.provider", "memory")
|
||||
v.SetDefault("event_broker.mode", "async")
|
||||
v.SetDefault("event_broker.worker_count", 10)
|
||||
v.SetDefault("event_broker.buffer_size", 1000)
|
||||
```
|
||||
|
||||
### Phase 3: Memory Provider (Files: 9)
|
||||
|
||||
**9. Create `pkg/eventbroker/provider_memory.go`**
|
||||
- MemoryProvider struct with mutex-protected map
|
||||
- In-memory event storage
|
||||
- Pattern matching for subscriptions
|
||||
- Channel-based streaming for real-time events
|
||||
- LRU eviction when max size reached
|
||||
- Cleanup goroutine for old completed events
|
||||
- **Note**: Single-instance only (no cross-instance pub/sub)
|
||||
|
||||
### Phase 4: Async Processing (Update File: 4)
|
||||
|
||||
**10. Update `pkg/eventbroker/broker.go`** (add async support)
|
||||
- workerPool struct with configurable worker count
|
||||
- Buffered channel for event queue
|
||||
- Worker goroutines that process events
|
||||
- PublishAsync() queues to channel
|
||||
- Graceful shutdown: stop accepting events, drain queue, wait for workers
|
||||
- Retry logic with exponential backoff
|
||||
|
||||
### Phase 5: Hook Integration (Files: 11)
|
||||
|
||||
**11. Create `pkg/eventbroker/hooks.go`**
|
||||
- `RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry)`
|
||||
- Registers AfterCreate, AfterUpdate, AfterDelete, AfterRead hooks
|
||||
- Extracts UserContext from hook context
|
||||
- Creates Event with proper metadata
|
||||
- Calls `broker.PublishAsync()` to not block CRUD operations
|
||||
|
||||
### Phase 6: Global Singleton & Factory (Files: 12-13)
|
||||
|
||||
**12. Create `pkg/eventbroker/eventbroker.go`**
|
||||
- Global `defaultBroker` variable
|
||||
- `Initialize(config *config.Config) error` - creates broker from config
|
||||
- `SetDefaultBroker(broker Broker)`
|
||||
- `GetDefaultBroker() Broker`
|
||||
- Helper functions: `Publish()`, `PublishAsync()`, `PublishSync()`, `Subscribe()`
|
||||
- `RegisterShutdown(broker Broker)` - registers with server.RegisterShutdownCallback()
|
||||
|
||||
**13. Create `pkg/eventbroker/factory.go`**
|
||||
- `NewProviderFromConfig(config EventBrokerConfig) (Provider, error)`
|
||||
- Provider selection logic (memory, redis, nats, database)
|
||||
- Returns appropriate provider based on config
|
||||
|
||||
### Phase 7: Redis Provider (Files: 14)
|
||||
|
||||
**14. Create `pkg/eventbroker/provider_redis.go`**
|
||||
- RedisProvider using Redis Streams (XADD, XREAD, XGROUP)
|
||||
- Consumer group for distributed processing
|
||||
- Cross-instance pub/sub support
|
||||
- Stream(pattern) subscribes to consumer group
|
||||
- Publish() uses XADD to append to stream
|
||||
- Graceful shutdown: acknowledge pending messages
|
||||
|
||||
**Dependencies**: `github.com/redis/go-redis/v9`
|
||||
|
||||
### Phase 8: NATS Provider (Files: 15)
|
||||
|
||||
**15. Create `pkg/eventbroker/provider_nats.go`**
|
||||
- NATSProvider using NATS JetStream
|
||||
- Subject-based routing: `events.{source}.{type}`
|
||||
- Wildcard subscriptions support
|
||||
- Durable consumers for replay
|
||||
- At-least-once delivery semantics
|
||||
|
||||
**Dependencies**: `github.com/nats-io/nats.go`
|
||||
|
||||
### Phase 9: Database Provider (Files: 16)
|
||||
|
||||
**16. Create `pkg/eventbroker/provider_database.go`**
|
||||
- DatabaseProvider using `common.Database` interface
|
||||
- Table schema creation (events table with indexes)
|
||||
- Polling-based event consumption (configurable interval)
|
||||
- Full SQL query support via List(filter)
|
||||
- Transaction support for atomic operations
|
||||
- Good for audit trails and debugging
|
||||
|
||||
### Phase 10: Metrics Integration (Files: 17)
|
||||
|
||||
**17. Create `pkg/eventbroker/metrics.go`**
|
||||
- Integrate with existing `metrics.Provider`
|
||||
- Record metrics:
|
||||
- `eventbroker_events_published_total{source, type}`
|
||||
- `eventbroker_events_processed_total{source, type, status}`
|
||||
- `eventbroker_event_processing_duration_seconds{source, type}`
|
||||
- `eventbroker_queue_size`
|
||||
- `eventbroker_workers_active`
|
||||
|
||||
**18. Update `pkg/metrics/interfaces.go`**
|
||||
- Add methods to Provider interface:
|
||||
```go
|
||||
RecordEventPublished(source, eventType string)
|
||||
RecordEventProcessed(source, eventType, status string, duration time.Duration)
|
||||
UpdateEventQueueSize(size int64)
|
||||
```
|
||||
|
||||
### Phase 11: Testing & Examples (Files: 19-20)
|
||||
|
||||
**19. Create `pkg/eventbroker/eventbroker_test.go`**
|
||||
- Unit tests for Event marshaling
|
||||
- Pattern matching tests
|
||||
- MemoryProvider tests
|
||||
- Sync vs async mode tests
|
||||
- Concurrent publish/subscribe tests
|
||||
- Graceful shutdown tests
|
||||
|
||||
**20. Create `pkg/eventbroker/example_usage.go`**
|
||||
- Basic publish example
|
||||
- Subscribe with patterns example
|
||||
- Hook integration example
|
||||
- Multiple handlers example
|
||||
- Error handling example
|
||||
|
||||
## Integration Points
|
||||
|
||||
### Hook System Integration
|
||||
```go
|
||||
// In application initialization (e.g., main.go)
|
||||
eventbroker.RegisterCRUDHooks(broker, handler.Hooks())
|
||||
```
|
||||
|
||||
This automatically publishes events for all CRUD operations:
|
||||
- `schema.entity.create` after inserts
|
||||
- `schema.entity.update` after updates
|
||||
- `schema.entity.delete` after deletes
|
||||
- `schema.entity.read` after reads
|
||||
|
||||
### Shutdown Integration
|
||||
```go
|
||||
// In application initialization
|
||||
eventbroker.RegisterShutdown(broker)
|
||||
```
|
||||
|
||||
Ensures event broker flushes pending events before shutdown.
|
||||
|
||||
### Configuration Example
|
||||
```yaml
|
||||
event_broker:
|
||||
enabled: true
|
||||
provider: redis # memory, redis, nats, database
|
||||
mode: async # sync, async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
instance_id: "${HOSTNAME}"
|
||||
|
||||
redis:
|
||||
stream_name: "resolvespec:events"
|
||||
consumer_group: "resolvespec-workers"
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Publishing Custom Events
|
||||
```go
|
||||
// WebSocket event
|
||||
event := &eventbroker.Event{
|
||||
Source: eventbroker.EventSourceWebSocket,
|
||||
Type: "chat.message",
|
||||
Payload: json.RawMessage(`{"room": "lobby", "msg": "Hello"}`),
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
eventbroker.PublishAsync(ctx, event)
|
||||
```
|
||||
|
||||
### Subscribing to Events
|
||||
```go
|
||||
// Subscribe to all user creation events
|
||||
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
log.Printf("New user created: %s", event.Payload)
|
||||
// Send welcome email, update cache, etc.
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Subscribe to all events from database
|
||||
eventbroker.Subscribe("*", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
if event.Source == eventbroker.EventSourceDatabase {
|
||||
// Audit logging
|
||||
}
|
||||
return nil
|
||||
},
|
||||
))
|
||||
```
|
||||
|
||||
## Critical Files Reference
|
||||
|
||||
**Patterns to Follow**:
|
||||
- `pkg/cache/provider.go` - Provider interface pattern
|
||||
- `pkg/restheadspec/hooks.go` - Hook system integration
|
||||
- `pkg/config/manager.go` - Configuration pattern
|
||||
- `pkg/server/shutdown.go` - Shutdown callbacks
|
||||
|
||||
**Files to Modify**:
|
||||
- `pkg/config/config.go` - Add EventBroker field
|
||||
- `pkg/config/manager.go` - Add defaults
|
||||
- `pkg/metrics/interfaces.go` - Add event broker metrics
|
||||
|
||||
**New Package**:
|
||||
- `pkg/eventbroker/` (20 files total)
|
||||
|
||||
## Provider Feature Matrix
|
||||
|
||||
| Feature | Memory | Redis | NATS | Database |
|
||||
|---------|--------|-------|------|----------|
|
||||
| Persistence | ❌ | ✅ | ✅ | ✅ |
|
||||
| Cross-instance | ❌ | ✅ | ✅ | ✅ |
|
||||
| Real-time | ✅ | ✅ | ✅ | ⚠️ (polling) |
|
||||
| Query history | Limited | Limited | ✅ (replay) | ✅ (SQL) |
|
||||
| External deps | None | Redis | NATS | None |
|
||||
| Complexity | Low | Medium | Medium | Low |
|
||||
|
||||
## Implementation Order Priority
|
||||
|
||||
1. **Core + Memory Provider** (Phase 1-3) - Functional in-process event system
|
||||
2. **Async + Hooks** (Phase 4-5) - Non-blocking event dispatch integrated with CRUD
|
||||
3. **Config + Singleton** (Phase 6) - Easy initialization and usage
|
||||
4. **Redis Provider** (Phase 7) - Production-ready distributed events
|
||||
5. **Metrics** (Phase 10) - Observability
|
||||
6. **NATS/Database** (Phase 8-9) - Alternative backends
|
||||
7. **Tests + Examples** (Phase 11) - Documentation and reliability
|
||||
|
||||
## Next Steps
|
||||
|
||||
After approval, implement in order of phases. Each phase builds on previous phases and can be tested independently.
|
||||
347
pkg/eventbroker/README.md
Normal file
347
pkg/eventbroker/README.md
Normal file
@@ -0,0 +1,347 @@
|
||||
# Event Broker System
|
||||
|
||||
A comprehensive event handler/broker system for ResolveSpec that provides real-time event publishing, subscription, and cross-instance communication.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Sources**: Events from database, websockets, frontend, system, and internal sources
|
||||
- **Event Status Tracking**: Pending, processing, completed, failed states with timestamps
|
||||
- **Rich Metadata**: User IDs, session IDs, instance IDs, JSON payloads, and custom metadata
|
||||
- **Sync & Async Modes**: Choose between synchronous or asynchronous event processing
|
||||
- **Pattern Matching**: Subscribe to events using glob-style patterns
|
||||
- **Multiple Providers**: In-memory, Redis Streams, NATS JetStream, PostgreSQL with NOTIFY
|
||||
- **Hook Integration**: Automatic CRUD event capture via restheadspec hooks
|
||||
- **Retry Logic**: Configurable retry policy with exponential backoff
|
||||
- **Metrics**: Prometheus-compatible metrics for monitoring
|
||||
- **Graceful Shutdown**: Proper cleanup and event flushing on shutdown
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Configuration
|
||||
|
||||
Add to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
enabled: true
|
||||
provider: memory # memory, redis, nats, database
|
||||
mode: async # sync, async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
instance_id: "${HOSTNAME}"
|
||||
```
|
||||
|
||||
### 2. Initialize
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
cfgMgr := config.NewManager()
|
||||
cfg, _ := cfgMgr.GetConfig()
|
||||
|
||||
// Initialize event broker
|
||||
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Subscribe to Events
|
||||
|
||||
```go
|
||||
// Subscribe to specific events
|
||||
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
log.Printf("New user created: %s", event.Payload)
|
||||
// Send welcome email, update cache, etc.
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Subscribe with patterns
|
||||
eventbroker.Subscribe("*.*.delete", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
log.Printf("Deleted: %s.%s", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
```
|
||||
|
||||
### 4. Publish Events
|
||||
|
||||
```go
|
||||
// Create and publish an event
|
||||
event := eventbroker.NewEvent(eventbroker.EventSourceDatabase, "public.users.update")
|
||||
event.InstanceID = eventbroker.GetDefaultBroker().InstanceID()
|
||||
event.UserID = 123
|
||||
event.SessionID = "session-456"
|
||||
event.Schema = "public"
|
||||
event.Entity = "users"
|
||||
event.Operation = "update"
|
||||
|
||||
event.SetPayload(map[string]interface{}{
|
||||
"id": 123,
|
||||
"name": "John Doe",
|
||||
})
|
||||
|
||||
// Async (non-blocking)
|
||||
eventbroker.PublishAsync(ctx, event)
|
||||
|
||||
// Sync (blocking)
|
||||
eventbroker.PublishSync(ctx, event)
|
||||
```
|
||||
|
||||
## Automatic CRUD Event Capture
|
||||
|
||||
Automatically capture database CRUD operations:
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
)
|
||||
|
||||
func setupHooks(handler *restheadspec.Handler) {
|
||||
broker := eventbroker.GetDefaultBroker()
|
||||
|
||||
// Configure which operations to capture
|
||||
config := eventbroker.DefaultCRUDHookConfig()
|
||||
config.EnableRead = false // Disable read events for performance
|
||||
|
||||
// Register hooks
|
||||
eventbroker.RegisterCRUDHooks(broker, handler.Hooks(), config)
|
||||
|
||||
// Now all create/update/delete operations automatically publish events!
|
||||
}
|
||||
```
|
||||
|
||||
## Event Structure
|
||||
|
||||
Every event contains:
|
||||
|
||||
```go
|
||||
type Event struct {
|
||||
ID string // UUID
|
||||
Source EventSource // database, websocket, system, frontend, internal
|
||||
Type string // Pattern: schema.entity.operation
|
||||
Status EventStatus // pending, processing, completed, failed
|
||||
Payload json.RawMessage // JSON payload
|
||||
UserID int // User who triggered the event
|
||||
SessionID string // Session identifier
|
||||
InstanceID string // Server instance identifier
|
||||
Schema string // Database schema
|
||||
Entity string // Database entity/table
|
||||
Operation string // create, update, delete, read
|
||||
CreatedAt time.Time // When event was created
|
||||
ProcessedAt *time.Time // When processing started
|
||||
CompletedAt *time.Time // When processing completed
|
||||
Error string // Error message if failed
|
||||
Metadata map[string]interface{} // Additional context
|
||||
RetryCount int // Number of retry attempts
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern Matching
|
||||
|
||||
Subscribe to events using glob-style patterns:
|
||||
|
||||
| Pattern | Matches | Example |
|
||||
|---------|---------|---------|
|
||||
| `*` | All events | Any event |
|
||||
| `public.users.*` | All user operations | `public.users.create`, `public.users.update` |
|
||||
| `*.*.create` | All create operations | `public.users.create`, `auth.sessions.create` |
|
||||
| `public.*.*` | All events in public schema | `public.users.create`, `public.posts.delete` |
|
||||
| `public.users.create` | Exact match | Only `public.users.create` |
|
||||
|
||||
## Providers
|
||||
|
||||
### Memory Provider (Default)
|
||||
|
||||
Best for: Development, single-instance deployments
|
||||
|
||||
- **Pros**: Fast, no dependencies, simple
|
||||
- **Cons**: Events lost on restart, single-instance only
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: memory
|
||||
```
|
||||
|
||||
### Redis Provider
|
||||
|
||||
Best for: Production, multi-instance deployments
|
||||
|
||||
- **Pros**: Persistent, cross-instance pub/sub, reliable, Redis Streams support
|
||||
- **Cons**: Requires Redis server
|
||||
- **Status**: ✅ Implemented
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: redis
|
||||
redis:
|
||||
stream_name: "resolvespec:events"
|
||||
consumer_group: "resolvespec-workers"
|
||||
max_len: 10000
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
```
|
||||
|
||||
### NATS Provider
|
||||
|
||||
Best for: High-performance, low-latency requirements
|
||||
|
||||
- **Pros**: Very fast, built-in clustering, durable, JetStream support
|
||||
- **Cons**: Requires NATS server
|
||||
- **Status**: ✅ Implemented
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: nats
|
||||
nats:
|
||||
url: "nats://localhost:4222"
|
||||
stream_name: "RESOLVESPEC_EVENTS"
|
||||
storage: "file" # or "memory"
|
||||
max_age: "24h"
|
||||
```
|
||||
|
||||
### Database Provider
|
||||
|
||||
Best for: Audit trails, event replay, SQL queries
|
||||
|
||||
- **Pros**: No additional infrastructure, full SQL query support, PostgreSQL NOTIFY for real-time
|
||||
- **Cons**: Slower than Redis/NATS, requires database connection
|
||||
- **Status**: ✅ Implemented
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: database
|
||||
database:
|
||||
table_name: "events"
|
||||
channel: "resolvespec_events"
|
||||
poll_interval: "1s"
|
||||
```
|
||||
|
||||
## Processing Modes
|
||||
|
||||
### Async Mode (Recommended)
|
||||
|
||||
Events are queued and processed by worker pool:
|
||||
|
||||
- Non-blocking event publishing
|
||||
- Configurable worker count
|
||||
- Better throughput
|
||||
- Events may be processed out of order
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
mode: async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
```
|
||||
|
||||
### Sync Mode
|
||||
|
||||
Events are processed immediately:
|
||||
|
||||
- Blocking event publishing
|
||||
- Guaranteed ordering
|
||||
- Immediate error feedback
|
||||
- Lower throughput
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
mode: sync
|
||||
```
|
||||
|
||||
## Retry Policy
|
||||
|
||||
Configure automatic retries for failed handlers:
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
retry_policy:
|
||||
max_retries: 3
|
||||
initial_delay: 1s
|
||||
max_delay: 30s
|
||||
backoff_factor: 2.0 # Exponential backoff
|
||||
```
|
||||
|
||||
## Metrics
|
||||
|
||||
The event broker exposes Prometheus metrics:
|
||||
|
||||
- `eventbroker_events_published_total{source, type}` - Total events published
|
||||
- `eventbroker_events_processed_total{source, type, status}` - Total events processed
|
||||
- `eventbroker_event_processing_duration_seconds{source, type}` - Event processing duration
|
||||
- `eventbroker_queue_size` - Current queue size (async mode)
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Async Mode**: For better performance, use async mode in production
|
||||
2. **Disable Read Events**: Read events can be high volume; disable if not needed
|
||||
3. **Pattern Matching**: Use specific patterns to avoid processing unnecessary events
|
||||
4. **Error Handling**: Always handle errors in event handlers; they won't fail the original operation
|
||||
5. **Idempotency**: Make handlers idempotent as events may be retried
|
||||
6. **Payload Size**: Keep payloads reasonable; avoid large objects
|
||||
7. **Monitoring**: Monitor metrics to detect issues early
|
||||
|
||||
## Examples
|
||||
|
||||
See `example_usage.go` for comprehensive examples including:
|
||||
- Basic event publishing and subscription
|
||||
- Hook integration
|
||||
- Error handling
|
||||
- Configuration
|
||||
- Pattern matching
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────┐
|
||||
│ Application │
|
||||
└────────┬────────┘
|
||||
│
|
||||
├─ Publish Events
|
||||
│
|
||||
┌────────▼────────┐ ┌──────────────┐
|
||||
│ Event Broker │◄────►│ Subscribers │
|
||||
└────────┬────────┘ └──────────────┘
|
||||
│
|
||||
├─ Store Events
|
||||
│
|
||||
┌────────▼────────┐
|
||||
│ Provider │
|
||||
│ (Memory/Redis │
|
||||
│ /NATS/DB) │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
## Implemented Features
|
||||
|
||||
- [x] Memory Provider (in-process, single-instance)
|
||||
- [x] Redis Streams Provider (distributed, persistent)
|
||||
- [x] NATS JetStream Provider (distributed, high-performance)
|
||||
- [x] Database Provider with PostgreSQL NOTIFY (SQL-queryable, audit-friendly)
|
||||
- [x] Sync and Async processing modes
|
||||
- [x] Pattern-based subscriptions
|
||||
- [x] Hook integration for automatic CRUD events
|
||||
- [x] Retry policy with exponential backoff
|
||||
- [x] Graceful shutdown
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] Event replay functionality from specific timestamp
|
||||
- [ ] Dead letter queue for failed events
|
||||
- [ ] Event filtering at provider level for performance
|
||||
- [ ] Batch publishing support
|
||||
- [ ] Event compression for large payloads
|
||||
- [ ] Schema versioning and migration
|
||||
- [ ] Event streaming to external systems (Kafka, RabbitMQ)
|
||||
- [ ] Event aggregation and analytics
|
||||
453
pkg/eventbroker/broker.go
Normal file
453
pkg/eventbroker/broker.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Broker is the main interface for event publishing and subscription
|
||||
type Broker interface {
|
||||
// Publish publishes an event (mode-dependent: sync or async)
|
||||
Publish(ctx context.Context, event *Event) error
|
||||
|
||||
// PublishSync publishes an event synchronously (blocks until all handlers complete)
|
||||
PublishSync(ctx context.Context, event *Event) error
|
||||
|
||||
// PublishAsync publishes an event asynchronously (returns immediately)
|
||||
PublishAsync(ctx context.Context, event *Event) error
|
||||
|
||||
// Subscribe registers a handler for events matching the pattern
|
||||
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
|
||||
|
||||
// Unsubscribe removes a subscription
|
||||
Unsubscribe(id SubscriptionID) error
|
||||
|
||||
// Start starts the broker (begins processing events)
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop stops the broker gracefully (flushes pending events)
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// Stats returns broker statistics
|
||||
Stats(ctx context.Context) (*BrokerStats, error)
|
||||
|
||||
// InstanceID returns the instance ID of this broker
|
||||
InstanceID() string
|
||||
}
|
||||
|
||||
// ProcessingMode determines how events are processed
|
||||
type ProcessingMode string
|
||||
|
||||
const (
|
||||
ProcessingModeSync ProcessingMode = "sync"
|
||||
ProcessingModeAsync ProcessingMode = "async"
|
||||
)
|
||||
|
||||
// BrokerStats contains broker statistics
|
||||
type BrokerStats struct {
|
||||
InstanceID string `json:"instance_id"`
|
||||
Mode ProcessingMode `json:"mode"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
TotalPublished int64 `json:"total_published"`
|
||||
TotalProcessed int64 `json:"total_processed"`
|
||||
TotalFailed int64 `json:"total_failed"`
|
||||
ActiveSubscribers int `json:"active_subscribers"`
|
||||
QueueSize int `json:"queue_size,omitempty"` // For async mode
|
||||
ActiveWorkers int `json:"active_workers,omitempty"` // For async mode
|
||||
ProviderStats *ProviderStats `json:"provider_stats,omitempty"`
|
||||
AdditionalStats map[string]interface{} `json:"additional_stats,omitempty"`
|
||||
}
|
||||
|
||||
// EventBroker implements the Broker interface
|
||||
type EventBroker struct {
|
||||
provider Provider
|
||||
subscriptions *subscriptionManager
|
||||
mode ProcessingMode
|
||||
instanceID string
|
||||
retryPolicy *RetryPolicy
|
||||
|
||||
// Async mode fields (initialized in Phase 4)
|
||||
workerPool *workerPool
|
||||
|
||||
// Runtime state
|
||||
isRunning atomic.Bool
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Statistics
|
||||
statsPublished atomic.Int64
|
||||
statsProcessed atomic.Int64
|
||||
statsFailed atomic.Int64
|
||||
}
|
||||
|
||||
// RetryPolicy defines how failed events should be retried
|
||||
type RetryPolicy struct {
|
||||
MaxRetries int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
BackoffFactor float64
|
||||
}
|
||||
|
||||
// DefaultRetryPolicy returns a sensible default retry policy
|
||||
func DefaultRetryPolicy() *RetryPolicy {
|
||||
return &RetryPolicy{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
}
|
||||
}
|
||||
|
||||
// Options for creating a new broker
|
||||
type Options struct {
|
||||
Provider Provider
|
||||
Mode ProcessingMode
|
||||
WorkerCount int // For async mode
|
||||
BufferSize int // For async mode
|
||||
RetryPolicy *RetryPolicy
|
||||
InstanceID string
|
||||
}
|
||||
|
||||
// NewBroker creates a new event broker with the given options
|
||||
func NewBroker(opts Options) (*EventBroker, error) {
|
||||
if opts.Provider == nil {
|
||||
return nil, fmt.Errorf("provider is required")
|
||||
}
|
||||
if opts.InstanceID == "" {
|
||||
return nil, fmt.Errorf("instance ID is required")
|
||||
}
|
||||
if opts.Mode == "" {
|
||||
opts.Mode = ProcessingModeAsync // Default to async
|
||||
}
|
||||
if opts.RetryPolicy == nil {
|
||||
opts.RetryPolicy = DefaultRetryPolicy()
|
||||
}
|
||||
|
||||
broker := &EventBroker{
|
||||
provider: opts.Provider,
|
||||
subscriptions: newSubscriptionManager(),
|
||||
mode: opts.Mode,
|
||||
instanceID: opts.InstanceID,
|
||||
retryPolicy: opts.RetryPolicy,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Worker pool will be initialized in Phase 4 for async mode
|
||||
if opts.Mode == ProcessingModeAsync {
|
||||
if opts.WorkerCount == 0 {
|
||||
opts.WorkerCount = 10 // Default
|
||||
}
|
||||
if opts.BufferSize == 0 {
|
||||
opts.BufferSize = 1000 // Default
|
||||
}
|
||||
broker.workerPool = newWorkerPool(opts.WorkerCount, opts.BufferSize, broker.processEvent)
|
||||
}
|
||||
|
||||
return broker, nil
|
||||
}
|
||||
|
||||
// Functional option pattern helpers
|
||||
func WithProvider(p Provider) func(*Options) {
|
||||
return func(o *Options) { o.Provider = p }
|
||||
}
|
||||
|
||||
func WithMode(m ProcessingMode) func(*Options) {
|
||||
return func(o *Options) { o.Mode = m }
|
||||
}
|
||||
|
||||
func WithWorkerCount(count int) func(*Options) {
|
||||
return func(o *Options) { o.WorkerCount = count }
|
||||
}
|
||||
|
||||
func WithBufferSize(size int) func(*Options) {
|
||||
return func(o *Options) { o.BufferSize = size }
|
||||
}
|
||||
|
||||
func WithRetryPolicy(policy *RetryPolicy) func(*Options) {
|
||||
return func(o *Options) { o.RetryPolicy = policy }
|
||||
}
|
||||
|
||||
func WithInstanceID(id string) func(*Options) {
|
||||
return func(o *Options) { o.InstanceID = id }
|
||||
}
|
||||
|
||||
// Start starts the broker
|
||||
func (b *EventBroker) Start(ctx context.Context) error {
|
||||
if b.isRunning.Load() {
|
||||
return fmt.Errorf("broker already running")
|
||||
}
|
||||
|
||||
b.isRunning.Store(true)
|
||||
|
||||
// Start worker pool for async mode
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
b.workerPool.Start()
|
||||
}
|
||||
|
||||
logger.Info("Event broker started (mode: %s, instance: %s)", b.mode, b.instanceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the broker gracefully
|
||||
func (b *EventBroker) Stop(ctx context.Context) error {
|
||||
var stopErr error
|
||||
|
||||
b.stopOnce.Do(func() {
|
||||
logger.Info("Stopping event broker...")
|
||||
|
||||
// Mark as not running
|
||||
b.isRunning.Store(false)
|
||||
|
||||
// Close the stop channel
|
||||
close(b.stopCh)
|
||||
|
||||
// Stop worker pool for async mode
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
if err := b.workerPool.Stop(ctx); err != nil {
|
||||
logger.Error("Error stopping worker pool: %v", err)
|
||||
stopErr = err
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
b.wg.Wait()
|
||||
|
||||
// Close provider
|
||||
if err := b.provider.Close(); err != nil {
|
||||
logger.Error("Error closing provider: %v", err)
|
||||
if stopErr == nil {
|
||||
stopErr = err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Event broker stopped")
|
||||
})
|
||||
|
||||
return stopErr
|
||||
}
|
||||
|
||||
// Publish publishes an event based on the broker's mode
|
||||
func (b *EventBroker) Publish(ctx context.Context, event *Event) error {
|
||||
if b.mode == ProcessingModeSync {
|
||||
return b.PublishSync(ctx, event)
|
||||
}
|
||||
return b.PublishAsync(ctx, event)
|
||||
}
|
||||
|
||||
// PublishSync publishes an event synchronously
|
||||
func (b *EventBroker) PublishSync(ctx context.Context, event *Event) error {
|
||||
if !b.isRunning.Load() {
|
||||
return fmt.Errorf("broker is not running")
|
||||
}
|
||||
|
||||
// Validate event
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid event: %w", err)
|
||||
}
|
||||
|
||||
// Store event in provider
|
||||
if err := b.provider.Publish(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to publish event: %w", err)
|
||||
}
|
||||
|
||||
b.statsPublished.Add(1)
|
||||
|
||||
// Record metrics
|
||||
recordEventPublished(event)
|
||||
|
||||
// Process event synchronously
|
||||
if err := b.processEvent(ctx, event); err != nil {
|
||||
logger.Error("Failed to process event %s: %v", event.ID, err)
|
||||
b.statsFailed.Add(1)
|
||||
return err
|
||||
}
|
||||
|
||||
b.statsProcessed.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishAsync publishes an event asynchronously
|
||||
func (b *EventBroker) PublishAsync(ctx context.Context, event *Event) error {
|
||||
if !b.isRunning.Load() {
|
||||
return fmt.Errorf("broker is not running")
|
||||
}
|
||||
|
||||
// Validate event
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid event: %w", err)
|
||||
}
|
||||
|
||||
// Store event in provider
|
||||
if err := b.provider.Publish(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to publish event: %w", err)
|
||||
}
|
||||
|
||||
b.statsPublished.Add(1)
|
||||
|
||||
// Record metrics
|
||||
recordEventPublished(event)
|
||||
|
||||
// Queue for async processing
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
// Update queue size metrics
|
||||
updateQueueSize(int64(b.workerPool.QueueSize()))
|
||||
return b.workerPool.Submit(ctx, event)
|
||||
}
|
||||
|
||||
// Fallback to sync if async not configured
|
||||
return b.processEvent(ctx, event)
|
||||
}
|
||||
|
||||
// Subscribe adds a subscription for events matching the pattern
|
||||
func (b *EventBroker) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||
return b.subscriptions.Subscribe(pattern, handler)
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription
|
||||
func (b *EventBroker) Unsubscribe(id SubscriptionID) error {
|
||||
return b.subscriptions.Unsubscribe(id)
|
||||
}
|
||||
|
||||
// processEvent processes an event by calling all matching handlers
|
||||
func (b *EventBroker) processEvent(ctx context.Context, event *Event) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// Get all handlers matching this event type
|
||||
handlers := b.subscriptions.GetMatching(event.Type)
|
||||
|
||||
if len(handlers) == 0 {
|
||||
logger.Debug("No handlers for event type: %s", event.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Processing event %s with %d handler(s)", event.ID, len(handlers))
|
||||
|
||||
// Mark event as processing
|
||||
event.MarkProcessing()
|
||||
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusProcessing, ""); err != nil {
|
||||
logger.Warn("Failed to update event status: %v", err)
|
||||
}
|
||||
|
||||
// Execute all handlers
|
||||
var lastErr error
|
||||
for i, handler := range handlers {
|
||||
if err := b.executeHandlerWithRetry(ctx, handler, event); err != nil {
|
||||
logger.Error("Handler %d failed for event %s: %v", i+1, event.ID, err)
|
||||
lastErr = err
|
||||
// Continue processing other handlers
|
||||
}
|
||||
}
|
||||
|
||||
// Update final status
|
||||
if lastErr != nil {
|
||||
event.MarkFailed(lastErr)
|
||||
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusFailed, lastErr.Error()); err != nil {
|
||||
logger.Warn("Failed to update event status: %v", err)
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
recordEventProcessed(event, time.Since(startTime))
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
event.MarkCompleted()
|
||||
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusCompleted, ""); err != nil {
|
||||
logger.Warn("Failed to update event status: %v", err)
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
recordEventProcessed(event, time.Since(startTime))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeHandlerWithRetry executes a handler with retry logic
|
||||
func (b *EventBroker) executeHandlerWithRetry(ctx context.Context, handler EventHandler, event *Event) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= b.retryPolicy.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Calculate backoff delay
|
||||
delay := b.calculateBackoff(attempt)
|
||||
logger.Debug("Retrying event %s (attempt %d/%d) after %v",
|
||||
event.ID, attempt, b.retryPolicy.MaxRetries, delay)
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
event.IncrementRetry()
|
||||
}
|
||||
|
||||
// Execute handler
|
||||
if err := handler.Handle(ctx, event); err != nil {
|
||||
lastErr = err
|
||||
logger.Warn("Handler failed for event %s (attempt %d): %v", event.ID, attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Success
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("handler failed after %d attempts: %w", b.retryPolicy.MaxRetries+1, lastErr)
|
||||
}
|
||||
|
||||
// calculateBackoff calculates the backoff delay for a retry attempt
|
||||
func (b *EventBroker) calculateBackoff(attempt int) time.Duration {
|
||||
delay := float64(b.retryPolicy.InitialDelay) * pow(b.retryPolicy.BackoffFactor, float64(attempt-1))
|
||||
if delay > float64(b.retryPolicy.MaxDelay) {
|
||||
delay = float64(b.retryPolicy.MaxDelay)
|
||||
}
|
||||
return time.Duration(delay)
|
||||
}
|
||||
|
||||
// pow is a simple integer power function
|
||||
func pow(base float64, exp float64) float64 {
|
||||
result := 1.0
|
||||
for i := 0.0; i < exp; i++ {
|
||||
result *= base
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Stats returns broker statistics
|
||||
func (b *EventBroker) Stats(ctx context.Context) (*BrokerStats, error) {
|
||||
providerStats, err := b.provider.Stats(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to get provider stats: %v", err)
|
||||
}
|
||||
|
||||
stats := &BrokerStats{
|
||||
InstanceID: b.instanceID,
|
||||
Mode: b.mode,
|
||||
IsRunning: b.isRunning.Load(),
|
||||
TotalPublished: b.statsPublished.Load(),
|
||||
TotalProcessed: b.statsProcessed.Load(),
|
||||
TotalFailed: b.statsFailed.Load(),
|
||||
ActiveSubscribers: b.subscriptions.Count(),
|
||||
ProviderStats: providerStats,
|
||||
}
|
||||
|
||||
// Add async-specific stats
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
stats.QueueSize = b.workerPool.QueueSize()
|
||||
stats.ActiveWorkers = b.workerPool.ActiveWorkers()
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// InstanceID returns the instance ID
|
||||
func (b *EventBroker) InstanceID() string {
|
||||
return b.instanceID
|
||||
}
|
||||
524
pkg/eventbroker/broker_test.go
Normal file
524
pkg/eventbroker/broker_test.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewBroker(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 1000,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts Options
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid options",
|
||||
opts: Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider",
|
||||
opts: Options{
|
||||
InstanceID: "test-instance",
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "missing instance ID",
|
||||
opts: Options{
|
||||
Provider: provider,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "async mode with defaults",
|
||||
opts: Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
broker, err := NewBroker(tt.opts)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("NewBroker() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
if err == nil && broker == nil {
|
||||
t.Error("Expected non-nil broker")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerStartStop(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, err := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create broker: %v", err)
|
||||
}
|
||||
|
||||
// Test Start
|
||||
if err := broker.Start(context.Background()); err != nil {
|
||||
t.Fatalf("Failed to start broker: %v", err)
|
||||
}
|
||||
|
||||
// Test double start (should fail)
|
||||
if err := broker.Start(context.Background()); err == nil {
|
||||
t.Error("Expected error on double start")
|
||||
}
|
||||
|
||||
// Test Stop
|
||||
if err := broker.Stop(context.Background()); err != nil {
|
||||
t.Fatalf("Failed to stop broker: %v", err)
|
||||
}
|
||||
|
||||
// Test double stop (should not fail)
|
||||
if err := broker.Stop(context.Background()); err != nil {
|
||||
t.Error("Double stop should not fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerPublishSync(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe to events
|
||||
called := false
|
||||
var receivedEvent *Event
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
receivedEvent = event
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
err := broker.PublishSync(context.Background(), event)
|
||||
if err != nil {
|
||||
t.Fatalf("PublishSync failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify handler was called
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
if receivedEvent == nil || receivedEvent.ID != event.ID {
|
||||
t.Error("Expected to receive the published event")
|
||||
}
|
||||
|
||||
// Verify event status
|
||||
if event.Status != EventStatusCompleted {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerPublishAsync(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 2,
|
||||
BufferSize: 10,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe to events
|
||||
var callCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish multiple events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
if err := broker.PublishAsync(context.Background(), event); err != nil {
|
||||
t.Fatalf("PublishAsync failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for events to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if callCount.Load() != 5 {
|
||||
t.Errorf("Expected 5 handler calls, got %d", callCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerPublishBeforeStart(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
err := broker.Publish(context.Background(), event)
|
||||
if err == nil {
|
||||
t.Error("Expected error when publishing before start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerHandlerError(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
RetryPolicy: &RetryPolicy{
|
||||
MaxRetries: 2,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
},
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe with failing handler
|
||||
var callCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
callCount.Add(1)
|
||||
return errors.New("handler error")
|
||||
}))
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
err := broker.PublishSync(context.Background(), event)
|
||||
|
||||
// Should fail after retries
|
||||
if err == nil {
|
||||
t.Error("Expected error from handler")
|
||||
}
|
||||
|
||||
// Should have been called MaxRetries+1 times (initial + retries)
|
||||
if callCount.Load() != 3 {
|
||||
t.Errorf("Expected 3 calls (1 initial + 2 retries), got %d", callCount.Load())
|
||||
}
|
||||
|
||||
// Event should be marked as failed
|
||||
if event.Status != EventStatusFailed {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerMultipleHandlers(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe multiple handlers
|
||||
var called1, called2, called3 bool
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called1 = true
|
||||
return nil
|
||||
}))
|
||||
broker.Subscribe("test.event", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called2 = true
|
||||
return nil
|
||||
}))
|
||||
broker.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called3 = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishSync(context.Background(), event)
|
||||
|
||||
// All handlers should be called
|
||||
if !called1 || !called2 || !called3 {
|
||||
t.Error("Expected all handlers to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerUnsubscribe(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe
|
||||
called := false
|
||||
id, _ := broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Unsubscribe
|
||||
if err := broker.Unsubscribe(id); err != nil {
|
||||
t.Fatalf("Unsubscribe failed: %v", err)
|
||||
}
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishSync(context.Background(), event)
|
||||
|
||||
// Handler should not be called
|
||||
if called {
|
||||
t.Error("Expected handler not to be called after unsubscribe")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerStats(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish events
|
||||
for i := 0; i < 3; i++ {
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishSync(context.Background(), event)
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats, err := broker.Stats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Stats failed: %v", err)
|
||||
}
|
||||
|
||||
if stats.InstanceID != "test-instance" {
|
||||
t.Errorf("Expected instance ID 'test-instance', got %s", stats.InstanceID)
|
||||
}
|
||||
if stats.TotalPublished != 3 {
|
||||
t.Errorf("Expected 3 published events, got %d", stats.TotalPublished)
|
||||
}
|
||||
if stats.TotalProcessed != 3 {
|
||||
t.Errorf("Expected 3 processed events, got %d", stats.TotalProcessed)
|
||||
}
|
||||
if stats.ActiveSubscribers != 1 {
|
||||
t.Errorf("Expected 1 active subscriber, got %d", stats.ActiveSubscribers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerInstanceID(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "my-instance",
|
||||
})
|
||||
|
||||
if broker.InstanceID() != "my-instance" {
|
||||
t.Errorf("Expected instance ID 'my-instance', got %s", broker.InstanceID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerConcurrentPublish(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 5,
|
||||
BufferSize: 100,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
var callCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish concurrently
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishAsync(context.Background(), event)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(200 * time.Millisecond) // Wait for async processing
|
||||
|
||||
if callCount.Load() != 50 {
|
||||
t.Errorf("Expected 50 handler calls, got %d", callCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerGracefulShutdown(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 2,
|
||||
BufferSize: 10,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
|
||||
var processedCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||
processedCount.Add(1)
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishAsync(context.Background(), event)
|
||||
}
|
||||
|
||||
// Stop broker (should wait for events to be processed)
|
||||
if err := broker.Stop(context.Background()); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
|
||||
// All events should be processed
|
||||
if processedCount.Load() != 5 {
|
||||
t.Errorf("Expected 5 processed events, got %d", processedCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerDefaultRetryPolicy(t *testing.T) {
|
||||
policy := DefaultRetryPolicy()
|
||||
|
||||
if policy.MaxRetries != 3 {
|
||||
t.Errorf("Expected MaxRetries 3, got %d", policy.MaxRetries)
|
||||
}
|
||||
if policy.InitialDelay != 1*time.Second {
|
||||
t.Errorf("Expected InitialDelay 1s, got %v", policy.InitialDelay)
|
||||
}
|
||||
if policy.BackoffFactor != 2.0 {
|
||||
t.Errorf("Expected BackoffFactor 2.0, got %f", policy.BackoffFactor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerProcessingModes(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode ProcessingMode
|
||||
}{
|
||||
{"sync mode", ProcessingModeSync},
|
||||
{"async mode", ProcessingModeAsync},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: tt.mode,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
called := false
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.Publish(context.Background(), event)
|
||||
|
||||
if tt.mode == ProcessingModeAsync {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
175
pkg/eventbroker/event.go
Normal file
175
pkg/eventbroker/event.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// EventSource represents where an event originated from
|
||||
type EventSource string
|
||||
|
||||
const (
|
||||
EventSourceDatabase EventSource = "database"
|
||||
EventSourceWebSocket EventSource = "websocket"
|
||||
EventSourceFrontend EventSource = "frontend"
|
||||
EventSourceSystem EventSource = "system"
|
||||
EventSourceInternal EventSource = "internal"
|
||||
)
|
||||
|
||||
// EventStatus represents the current state of an event
|
||||
type EventStatus string
|
||||
|
||||
const (
|
||||
EventStatusPending EventStatus = "pending"
|
||||
EventStatusProcessing EventStatus = "processing"
|
||||
EventStatusCompleted EventStatus = "completed"
|
||||
EventStatusFailed EventStatus = "failed"
|
||||
)
|
||||
|
||||
// Event represents a single event in the system with complete metadata
|
||||
type Event struct {
|
||||
// Identification
|
||||
ID string `json:"id" db:"id"`
|
||||
|
||||
// Source & Classification
|
||||
Source EventSource `json:"source" db:"source"`
|
||||
Type string `json:"type" db:"type"` // Pattern: schema.entity.operation
|
||||
|
||||
// Status Tracking
|
||||
Status EventStatus `json:"status" db:"status"`
|
||||
RetryCount int `json:"retry_count" db:"retry_count"`
|
||||
Error string `json:"error,omitempty" db:"error"`
|
||||
|
||||
// Payload
|
||||
Payload json.RawMessage `json:"payload" db:"payload"`
|
||||
|
||||
// Context Information
|
||||
UserID int `json:"user_id" db:"user_id"`
|
||||
SessionID string `json:"session_id" db:"session_id"`
|
||||
InstanceID string `json:"instance_id" db:"instance_id"`
|
||||
|
||||
// Database Context
|
||||
Schema string `json:"schema" db:"schema"`
|
||||
Entity string `json:"entity" db:"entity"`
|
||||
Operation string `json:"operation" db:"operation"` // create, update, delete, read
|
||||
|
||||
// Timestamps
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
ProcessedAt *time.Time `json:"processed_at,omitempty" db:"processed_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
|
||||
|
||||
// Extensibility
|
||||
Metadata map[string]interface{} `json:"metadata" db:"metadata"`
|
||||
}
|
||||
|
||||
// NewEvent creates a new event with defaults
|
||||
func NewEvent(source EventSource, eventType string) *Event {
|
||||
return &Event{
|
||||
ID: uuid.New().String(),
|
||||
Source: source,
|
||||
Type: eventType,
|
||||
Status: EventStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
Metadata: make(map[string]interface{}),
|
||||
RetryCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// EventType generates a type string from schema, entity, and operation
|
||||
// Pattern: schema.entity.operation (e.g., "public.users.create")
|
||||
func EventType(schema, entity, operation string) string {
|
||||
return fmt.Sprintf("%s.%s.%s", schema, entity, operation)
|
||||
}
|
||||
|
||||
// MarkProcessing marks the event as being processed
|
||||
func (e *Event) MarkProcessing() {
|
||||
e.Status = EventStatusProcessing
|
||||
now := time.Now()
|
||||
e.ProcessedAt = &now
|
||||
}
|
||||
|
||||
// MarkCompleted marks the event as successfully completed
|
||||
func (e *Event) MarkCompleted() {
|
||||
e.Status = EventStatusCompleted
|
||||
now := time.Now()
|
||||
e.CompletedAt = &now
|
||||
}
|
||||
|
||||
// MarkFailed marks the event as failed with an error message
|
||||
func (e *Event) MarkFailed(err error) {
|
||||
e.Status = EventStatusFailed
|
||||
e.Error = err.Error()
|
||||
now := time.Now()
|
||||
e.CompletedAt = &now
|
||||
}
|
||||
|
||||
// IncrementRetry increments the retry counter
|
||||
func (e *Event) IncrementRetry() {
|
||||
e.RetryCount++
|
||||
}
|
||||
|
||||
// SetPayload sets the event payload from any value by marshaling to JSON
|
||||
func (e *Event) SetPayload(v interface{}) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
e.Payload = data
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPayload unmarshals the payload into the provided value
|
||||
func (e *Event) GetPayload(v interface{}) error {
|
||||
if len(e.Payload) == 0 {
|
||||
return fmt.Errorf("payload is empty")
|
||||
}
|
||||
if err := json.Unmarshal(e.Payload, v); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal payload: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of the event
|
||||
func (e *Event) Clone() *Event {
|
||||
clone := *e
|
||||
|
||||
// Deep copy metadata
|
||||
if e.Metadata != nil {
|
||||
clone.Metadata = make(map[string]interface{})
|
||||
for k, v := range e.Metadata {
|
||||
clone.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy timestamps
|
||||
if e.ProcessedAt != nil {
|
||||
t := *e.ProcessedAt
|
||||
clone.ProcessedAt = &t
|
||||
}
|
||||
if e.CompletedAt != nil {
|
||||
t := *e.CompletedAt
|
||||
clone.CompletedAt = &t
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// Validate performs basic validation on the event
|
||||
func (e *Event) Validate() error {
|
||||
if e.ID == "" {
|
||||
return fmt.Errorf("event ID is required")
|
||||
}
|
||||
if e.Source == "" {
|
||||
return fmt.Errorf("event source is required")
|
||||
}
|
||||
if e.Type == "" {
|
||||
return fmt.Errorf("event type is required")
|
||||
}
|
||||
if e.InstanceID == "" {
|
||||
return fmt.Errorf("instance ID is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
314
pkg/eventbroker/event_test.go
Normal file
314
pkg/eventbroker/event_test.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewEvent(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
if event.ID == "" {
|
||||
t.Error("Expected event ID to be generated")
|
||||
}
|
||||
if event.Source != EventSourceDatabase {
|
||||
t.Errorf("Expected source %s, got %s", EventSourceDatabase, event.Source)
|
||||
}
|
||||
if event.Type != "public.users.create" {
|
||||
t.Errorf("Expected type 'public.users.create', got %s", event.Type)
|
||||
}
|
||||
if event.Status != EventStatusPending {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusPending, event.Status)
|
||||
}
|
||||
if event.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set")
|
||||
}
|
||||
if event.Metadata == nil {
|
||||
t.Error("Expected Metadata to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventType(t *testing.T) {
|
||||
tests := []struct {
|
||||
schema string
|
||||
entity string
|
||||
operation string
|
||||
expected string
|
||||
}{
|
||||
{"public", "users", "create", "public.users.create"},
|
||||
{"admin", "roles", "update", "admin.roles.update"},
|
||||
{"", "system", "start", ".system.start"}, // Empty schema results in leading dot
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := EventType(tt.schema, tt.entity, tt.operation)
|
||||
if result != tt.expected {
|
||||
t.Errorf("EventType(%q, %q, %q) = %q, expected %q",
|
||||
tt.schema, tt.entity, tt.operation, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
event *Event
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid event",
|
||||
event: func() *Event {
|
||||
e := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
e.InstanceID = "test-instance"
|
||||
return e
|
||||
}(),
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "missing ID",
|
||||
event: &Event{
|
||||
Source: EventSourceDatabase,
|
||||
Type: "public.users.create",
|
||||
Status: EventStatusPending,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "missing source",
|
||||
event: &Event{
|
||||
ID: "test-id",
|
||||
Type: "public.users.create",
|
||||
Status: EventStatusPending,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "missing type",
|
||||
event: &Event{
|
||||
ID: "test-id",
|
||||
Source: EventSourceDatabase,
|
||||
Status: EventStatusPending,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.event.Validate()
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("Event.Validate() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetPayload(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
}
|
||||
|
||||
err := event.SetPayload(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("SetPayload failed: %v", err)
|
||||
}
|
||||
|
||||
if event.Payload == nil {
|
||||
t.Fatal("Expected payload to be set")
|
||||
}
|
||||
|
||||
// Verify payload can be unmarshaled
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(event.Payload, &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal payload: %v", err)
|
||||
}
|
||||
|
||||
if result["name"] != "John Doe" {
|
||||
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventGetPayload(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"id": float64(1), // JSON unmarshals numbers as float64
|
||||
"name": "John Doe",
|
||||
}
|
||||
|
||||
if err := event.SetPayload(payload); err != nil {
|
||||
t.Fatalf("SetPayload failed: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := event.GetPayload(&result); err != nil {
|
||||
t.Fatalf("GetPayload failed: %v", err)
|
||||
}
|
||||
|
||||
if result["name"] != "John Doe" {
|
||||
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMarkProcessing(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.MarkProcessing()
|
||||
|
||||
if event.Status != EventStatusProcessing {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusProcessing, event.Status)
|
||||
}
|
||||
if event.ProcessedAt == nil {
|
||||
t.Error("Expected ProcessedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMarkCompleted(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.MarkCompleted()
|
||||
|
||||
if event.Status != EventStatusCompleted {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
|
||||
}
|
||||
if event.CompletedAt == nil {
|
||||
t.Error("Expected CompletedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMarkFailed(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
testErr := errors.New("test error")
|
||||
event.MarkFailed(testErr)
|
||||
|
||||
if event.Status != EventStatusFailed {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
|
||||
}
|
||||
if event.Error != "test error" {
|
||||
t.Errorf("Expected error %q, got %q", "test error", event.Error)
|
||||
}
|
||||
if event.CompletedAt == nil {
|
||||
t.Error("Expected CompletedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventIncrementRetry(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
initialCount := event.RetryCount
|
||||
event.IncrementRetry()
|
||||
|
||||
if event.RetryCount != initialCount+1 {
|
||||
t.Errorf("Expected retry count %d, got %d", initialCount+1, event.RetryCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventJSONMarshaling(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.UserID = 123
|
||||
event.SessionID = "session-123"
|
||||
event.InstanceID = "instance-1"
|
||||
event.Schema = "public"
|
||||
event.Entity = "users"
|
||||
event.Operation = "create"
|
||||
event.SetPayload(map[string]interface{}{"name": "Test"})
|
||||
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded Event
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal event: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.ID != event.ID {
|
||||
t.Errorf("Expected ID %s, got %s", event.ID, decoded.ID)
|
||||
}
|
||||
if decoded.Source != event.Source {
|
||||
t.Errorf("Expected source %s, got %s", event.Source, decoded.Source)
|
||||
}
|
||||
if decoded.UserID != event.UserID {
|
||||
t.Errorf("Expected UserID %d, got %d", event.UserID, decoded.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventStatusString(t *testing.T) {
|
||||
statuses := []EventStatus{
|
||||
EventStatusPending,
|
||||
EventStatusProcessing,
|
||||
EventStatusCompleted,
|
||||
EventStatusFailed,
|
||||
}
|
||||
|
||||
for _, status := range statuses {
|
||||
if string(status) == "" {
|
||||
t.Errorf("EventStatus %v has empty string representation", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSourceString(t *testing.T) {
|
||||
sources := []EventSource{
|
||||
EventSourceDatabase,
|
||||
EventSourceWebSocket,
|
||||
EventSourceFrontend,
|
||||
EventSourceSystem,
|
||||
EventSourceInternal,
|
||||
}
|
||||
|
||||
for _, source := range sources {
|
||||
if string(source) == "" {
|
||||
t.Errorf("EventSource %v has empty string representation", source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMetadata(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
// Test setting metadata
|
||||
event.Metadata["key1"] = "value1"
|
||||
event.Metadata["key2"] = 123
|
||||
|
||||
if event.Metadata["key1"] != "value1" {
|
||||
t.Errorf("Expected metadata key1 to be 'value1', got %v", event.Metadata["key1"])
|
||||
}
|
||||
if event.Metadata["key2"] != 123 {
|
||||
t.Errorf("Expected metadata key2 to be 123, got %v", event.Metadata["key2"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventTimestamps(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
createdAt := event.CreatedAt
|
||||
|
||||
// Wait a tiny bit to ensure timestamps differ
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
event.MarkProcessing()
|
||||
if event.ProcessedAt == nil {
|
||||
t.Fatal("ProcessedAt should be set")
|
||||
}
|
||||
if !event.ProcessedAt.After(createdAt) {
|
||||
t.Error("ProcessedAt should be after CreatedAt")
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
event.MarkCompleted()
|
||||
if event.CompletedAt == nil {
|
||||
t.Fatal("CompletedAt should be set")
|
||||
}
|
||||
if !event.CompletedAt.After(*event.ProcessedAt) {
|
||||
t.Error("CompletedAt should be after ProcessedAt")
|
||||
}
|
||||
}
|
||||
158
pkg/eventbroker/eventbroker.go
Normal file
158
pkg/eventbroker/eventbroker.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultBroker Broker
|
||||
brokerMu sync.RWMutex
|
||||
)
|
||||
|
||||
// Initialize initializes the global event broker from configuration
|
||||
func Initialize(cfg config.EventBrokerConfig) error {
|
||||
if !cfg.Enabled {
|
||||
logger.Info("Event broker is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create provider
|
||||
provider, err := NewProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create provider: %w", err)
|
||||
}
|
||||
|
||||
// Parse mode
|
||||
mode := ProcessingModeAsync
|
||||
if cfg.Mode == "sync" {
|
||||
mode = ProcessingModeSync
|
||||
}
|
||||
|
||||
// Convert retry policy
|
||||
retryPolicy := &RetryPolicy{
|
||||
MaxRetries: cfg.RetryPolicy.MaxRetries,
|
||||
InitialDelay: cfg.RetryPolicy.InitialDelay,
|
||||
MaxDelay: cfg.RetryPolicy.MaxDelay,
|
||||
BackoffFactor: cfg.RetryPolicy.BackoffFactor,
|
||||
}
|
||||
if retryPolicy.MaxRetries == 0 {
|
||||
retryPolicy = DefaultRetryPolicy()
|
||||
}
|
||||
|
||||
// Create broker options
|
||||
opts := Options{
|
||||
Provider: provider,
|
||||
Mode: mode,
|
||||
WorkerCount: cfg.WorkerCount,
|
||||
BufferSize: cfg.BufferSize,
|
||||
RetryPolicy: retryPolicy,
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
}
|
||||
|
||||
// Create broker
|
||||
broker, err := NewBroker(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create broker: %w", err)
|
||||
}
|
||||
|
||||
// Start broker
|
||||
if err := broker.Start(context.Background()); err != nil {
|
||||
return fmt.Errorf("failed to start broker: %w", err)
|
||||
}
|
||||
|
||||
// Set as default
|
||||
SetDefaultBroker(broker)
|
||||
|
||||
logger.Info("Event broker initialized successfully (provider: %s, mode: %s, instance: %s)",
|
||||
cfg.Provider, cfg.Mode, opts.InstanceID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultBroker sets the default global broker
|
||||
func SetDefaultBroker(broker Broker) {
|
||||
brokerMu.Lock()
|
||||
defer brokerMu.Unlock()
|
||||
defaultBroker = broker
|
||||
}
|
||||
|
||||
// GetDefaultBroker returns the default global broker
|
||||
func GetDefaultBroker() Broker {
|
||||
brokerMu.RLock()
|
||||
defer brokerMu.RUnlock()
|
||||
return defaultBroker
|
||||
}
|
||||
|
||||
// IsInitialized returns true if the default broker is initialized
|
||||
func IsInitialized() bool {
|
||||
return GetDefaultBroker() != nil
|
||||
}
|
||||
|
||||
// Publish publishes an event using the default broker
|
||||
func Publish(ctx context.Context, event *Event) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Publish(ctx, event)
|
||||
}
|
||||
|
||||
// PublishSync publishes an event synchronously using the default broker
|
||||
func PublishSync(ctx context.Context, event *Event) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.PublishSync(ctx, event)
|
||||
}
|
||||
|
||||
// PublishAsync publishes an event asynchronously using the default broker
|
||||
func PublishAsync(ctx context.Context, event *Event) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.PublishAsync(ctx, event)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to events using the default broker
|
||||
func Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return "", fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Subscribe(pattern, handler)
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes from events using the default broker
|
||||
func Unsubscribe(id SubscriptionID) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Unsubscribe(id)
|
||||
}
|
||||
|
||||
// Stats returns statistics from the default broker
|
||||
func Stats(ctx context.Context) (*BrokerStats, error) {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return nil, fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Stats(ctx)
|
||||
}
|
||||
|
||||
// RegisterShutdown registers the broker's shutdown with a server manager
|
||||
// Call this from your application initialization code
|
||||
// Example: serverMgr.RegisterShutdownCallback(eventbroker.MakeShutdownCallback(broker))
|
||||
func MakeShutdownCallback(broker Broker) func(context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
logger.Info("Shutting down event broker...")
|
||||
return broker.Stop(ctx)
|
||||
}
|
||||
}
|
||||
266
pkg/eventbroker/example_usage.go
Normal file
266
pkg/eventbroker/example_usage.go
Normal file
@@ -0,0 +1,266 @@
|
||||
// nolint
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Example demonstrates basic usage of the event broker
|
||||
func Example() {
|
||||
// 1. Create a memory provider
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "example-instance",
|
||||
MaxEvents: 1000,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
MaxAge: 1 * time.Hour,
|
||||
})
|
||||
|
||||
// 2. Create a broker
|
||||
broker, err := NewBroker(Options{
|
||||
Provider: provider,
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 5,
|
||||
BufferSize: 100,
|
||||
RetryPolicy: DefaultRetryPolicy(),
|
||||
InstanceID: "example-instance",
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to create broker: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Start the broker
|
||||
if err := broker.Start(context.Background()); err != nil {
|
||||
logger.Error("Failed to start broker: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err := broker.Stop(context.Background())
|
||||
if err != nil {
|
||||
logger.Error("Failed to stop broker: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 4. Subscribe to events
|
||||
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
logger.Info("User event: %s (operation: %s)", event.Type, event.Operation)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
logger.Info("Create event: %s.%s", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// 5. Publish events
|
||||
ctx := context.Background()
|
||||
|
||||
// Database event
|
||||
dbEvent := NewEvent(EventSourceDatabase, EventType("public", "users", "create"))
|
||||
dbEvent.InstanceID = "example-instance"
|
||||
dbEvent.UserID = 123
|
||||
dbEvent.SessionID = "session-456"
|
||||
dbEvent.Schema = "public"
|
||||
dbEvent.Entity = "users"
|
||||
dbEvent.Operation = "create"
|
||||
dbEvent.SetPayload(map[string]interface{}{
|
||||
"id": 123,
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
})
|
||||
|
||||
if err := broker.PublishAsync(ctx, dbEvent); err != nil {
|
||||
logger.Error("Failed to publish event: %v", err)
|
||||
}
|
||||
|
||||
// WebSocket event
|
||||
wsEvent := NewEvent(EventSourceWebSocket, "chat.message")
|
||||
wsEvent.InstanceID = "example-instance"
|
||||
wsEvent.UserID = 123
|
||||
wsEvent.SessionID = "session-456"
|
||||
wsEvent.SetPayload(map[string]interface{}{
|
||||
"room": "general",
|
||||
"message": "Hello, World!",
|
||||
})
|
||||
|
||||
if err := broker.PublishAsync(ctx, wsEvent); err != nil {
|
||||
logger.Error("Failed to publish event: %v", err)
|
||||
}
|
||||
|
||||
// 6. Get statistics
|
||||
time.Sleep(1 * time.Second) // Wait for processing
|
||||
stats, _ := broker.Stats(ctx)
|
||||
logger.Info("Broker stats: %d published, %d processed", stats.TotalPublished, stats.TotalProcessed)
|
||||
}
|
||||
|
||||
// ExampleWithHooks demonstrates integration with the hook system
|
||||
func ExampleWithHooks() {
|
||||
// This would typically be called in your main.go or initialization code
|
||||
// after setting up your restheadspec.Handler
|
||||
|
||||
// Pseudo-code (actual implementation would use real handler):
|
||||
/*
|
||||
broker := eventbroker.GetDefaultBroker()
|
||||
hookRegistry := handler.Hooks()
|
||||
|
||||
// Register CRUD hooks
|
||||
config := eventbroker.DefaultCRUDHookConfig()
|
||||
config.EnableRead = false // Disable read events for performance
|
||||
|
||||
if err := eventbroker.RegisterCRUDHooks(broker, hookRegistry, config); err != nil {
|
||||
logger.Error("Failed to register CRUD hooks: %v", err)
|
||||
}
|
||||
|
||||
// Now all CRUD operations will automatically publish events
|
||||
*/
|
||||
}
|
||||
|
||||
// ExampleSubscriptionPatterns demonstrates different subscription patterns
|
||||
func ExampleSubscriptionPatterns() {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Pattern 1: Subscribe to all events from a specific entity
|
||||
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("User event: %s\n", event.Operation)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Pattern 2: Subscribe to a specific operation across all entities
|
||||
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("Create event: %s.%s\n", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Pattern 3: Subscribe to all events in a schema
|
||||
broker.Subscribe("public.*.*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("Public schema event: %s.%s\n", event.Entity, event.Operation)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Pattern 4: Subscribe to everything (use with caution)
|
||||
broker.Subscribe("*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("Any event: %s\n", event.Type)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
// ExampleErrorHandling demonstrates error handling in event handlers
|
||||
func ExampleErrorHandling() {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Handler that may fail
|
||||
broker.Subscribe("public.users.create", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
// Simulate processing
|
||||
var user struct {
|
||||
ID int `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
if err := event.GetPayload(&user); err != nil {
|
||||
return fmt.Errorf("invalid payload: %w", err)
|
||||
}
|
||||
|
||||
// Validate
|
||||
if user.Email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
// Process (e.g., send email)
|
||||
logger.Info("Sending welcome email to %s", user.Email)
|
||||
|
||||
return nil
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
// ExampleConfiguration demonstrates initializing from configuration
|
||||
func ExampleConfiguration() {
|
||||
// This would typically be in your main.go
|
||||
|
||||
// Pseudo-code:
|
||||
/*
|
||||
// Load configuration
|
||||
cfgMgr := config.NewManager()
|
||||
if err := cfgMgr.Load(); err != nil {
|
||||
logger.Fatal("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := cfgMgr.GetConfig()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Initialize event broker
|
||||
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
|
||||
logger.Fatal("Failed to initialize event broker: %v", err)
|
||||
}
|
||||
|
||||
// Use the default broker
|
||||
eventbroker.Subscribe("*.*.create", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
logger.Info("Created: %s.%s", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
*/
|
||||
}
|
||||
|
||||
// ExampleYAMLConfiguration shows example YAML configuration
|
||||
const ExampleYAMLConfiguration = `
|
||||
event_broker:
|
||||
enabled: true
|
||||
provider: memory # memory, redis, nats, database
|
||||
mode: async # sync, async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
instance_id: "${HOSTNAME}"
|
||||
|
||||
# Memory provider is default, no additional config needed
|
||||
|
||||
# Redis provider (when provider: redis)
|
||||
redis:
|
||||
stream_name: "resolvespec:events"
|
||||
consumer_group: "resolvespec-workers"
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
|
||||
# NATS provider (when provider: nats)
|
||||
nats:
|
||||
url: "nats://localhost:4222"
|
||||
stream_name: "RESOLVESPEC_EVENTS"
|
||||
|
||||
# Database provider (when provider: database)
|
||||
database:
|
||||
table_name: "events"
|
||||
channel: "resolvespec_events"
|
||||
|
||||
# Retry policy
|
||||
retry_policy:
|
||||
max_retries: 3
|
||||
initial_delay: 1s
|
||||
max_delay: 30s
|
||||
backoff_factor: 2.0
|
||||
`
|
||||
74
pkg/eventbroker/factory.go
Normal file
74
pkg/eventbroker/factory.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
)
|
||||
|
||||
// NewProviderFromConfig creates a provider based on configuration
|
||||
func NewProviderFromConfig(cfg config.EventBrokerConfig) (Provider, error) {
|
||||
switch cfg.Provider {
|
||||
case "memory":
|
||||
cleanupInterval := 5 * time.Minute
|
||||
if cfg.Database.PollInterval > 0 {
|
||||
cleanupInterval = cfg.Database.PollInterval
|
||||
}
|
||||
|
||||
return NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
MaxEvents: 10000,
|
||||
CleanupInterval: cleanupInterval,
|
||||
}), nil
|
||||
|
||||
case "redis":
|
||||
return NewRedisProvider(RedisProviderConfig{
|
||||
Host: cfg.Redis.Host,
|
||||
Port: cfg.Redis.Port,
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
StreamName: cfg.Redis.StreamName,
|
||||
ConsumerGroup: cfg.Redis.ConsumerGroup,
|
||||
ConsumerName: getInstanceID(cfg.InstanceID),
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
MaxLen: cfg.Redis.MaxLen,
|
||||
})
|
||||
|
||||
case "nats":
|
||||
// NATS provider initialization
|
||||
// Note: Requires github.com/nats-io/nats.go dependency
|
||||
return NewNATSProvider(NATSProviderConfig{
|
||||
URL: cfg.NATS.URL,
|
||||
StreamName: cfg.NATS.StreamName,
|
||||
SubjectPrefix: "events",
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
MaxAge: cfg.NATS.MaxAge,
|
||||
Storage: cfg.NATS.Storage, // "file" or "memory"
|
||||
})
|
||||
|
||||
case "database":
|
||||
// Database provider requires a database connection
|
||||
// This should be provided externally
|
||||
return nil, fmt.Errorf("database provider requires a database connection to be configured separately")
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// getInstanceID returns the instance ID, defaulting to hostname if not specified
|
||||
func getInstanceID(configID string) string {
|
||||
if configID != "" {
|
||||
return configID
|
||||
}
|
||||
|
||||
// Try to get hostname
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
return hostname
|
||||
}
|
||||
|
||||
// Fallback to a default
|
||||
return "resolvespec-instance"
|
||||
}
|
||||
17
pkg/eventbroker/handler.go
Normal file
17
pkg/eventbroker/handler.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package eventbroker
|
||||
|
||||
import "context"
|
||||
|
||||
// EventHandler processes an event
|
||||
type EventHandler interface {
|
||||
Handle(ctx context.Context, event *Event) error
|
||||
}
|
||||
|
||||
// EventHandlerFunc is a function adapter for EventHandler
|
||||
// This allows using regular functions as event handlers
|
||||
type EventHandlerFunc func(ctx context.Context, event *Event) error
|
||||
|
||||
// Handle implements EventHandler
|
||||
func (f EventHandlerFunc) Handle(ctx context.Context, event *Event) error {
|
||||
return f(ctx, event)
|
||||
}
|
||||
137
pkg/eventbroker/hooks.go
Normal file
137
pkg/eventbroker/hooks.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// CRUDHookConfig configures which CRUD operations should trigger events
|
||||
type CRUDHookConfig struct {
|
||||
EnableCreate bool
|
||||
EnableRead bool
|
||||
EnableUpdate bool
|
||||
EnableDelete bool
|
||||
}
|
||||
|
||||
// DefaultCRUDHookConfig returns default configuration (all enabled)
|
||||
func DefaultCRUDHookConfig() *CRUDHookConfig {
|
||||
return &CRUDHookConfig{
|
||||
EnableCreate: true,
|
||||
EnableRead: false, // Typically disabled for performance
|
||||
EnableUpdate: true,
|
||||
EnableDelete: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterCRUDHooks registers event hooks for CRUD operations
|
||||
// This integrates with the restheadspec.HookRegistry to automatically
|
||||
// capture database events
|
||||
func RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry, config *CRUDHookConfig) error {
|
||||
if broker == nil {
|
||||
return fmt.Errorf("broker cannot be nil")
|
||||
}
|
||||
if hookRegistry == nil {
|
||||
return fmt.Errorf("hookRegistry cannot be nil")
|
||||
}
|
||||
if config == nil {
|
||||
config = DefaultCRUDHookConfig()
|
||||
}
|
||||
|
||||
// Create hook handler factory
|
||||
createHookHandler := func(operation string) restheadspec.HookFunc {
|
||||
return func(hookCtx *restheadspec.HookContext) error {
|
||||
// Get user context from Go context
|
||||
userCtx, ok := security.GetUserContext(hookCtx.Context)
|
||||
if !ok || userCtx == nil {
|
||||
logger.Debug("No user context found in hook")
|
||||
userCtx = &security.UserContext{} // Empty user context
|
||||
}
|
||||
|
||||
// Create event
|
||||
event := NewEvent(EventSourceDatabase, EventType(hookCtx.Schema, hookCtx.Entity, operation))
|
||||
event.InstanceID = broker.InstanceID()
|
||||
event.UserID = userCtx.UserID
|
||||
event.SessionID = userCtx.SessionID
|
||||
event.Schema = hookCtx.Schema
|
||||
event.Entity = hookCtx.Entity
|
||||
event.Operation = operation
|
||||
|
||||
// Set payload based on operation
|
||||
var payload interface{}
|
||||
switch operation {
|
||||
case "create":
|
||||
payload = hookCtx.Result
|
||||
case "read":
|
||||
payload = hookCtx.Result
|
||||
case "update":
|
||||
payload = map[string]interface{}{
|
||||
"id": hookCtx.ID,
|
||||
"data": hookCtx.Data,
|
||||
}
|
||||
case "delete":
|
||||
payload = map[string]interface{}{
|
||||
"id": hookCtx.ID,
|
||||
}
|
||||
}
|
||||
|
||||
if payload != nil {
|
||||
if err := event.SetPayload(payload); err != nil {
|
||||
logger.Error("Failed to set event payload: %v", err)
|
||||
payload = map[string]interface{}{"error": "failed to serialize payload"}
|
||||
event.Payload, _ = json.Marshal(payload)
|
||||
}
|
||||
}
|
||||
|
||||
// Add metadata
|
||||
if userCtx.UserName != "" {
|
||||
event.Metadata["user_name"] = userCtx.UserName
|
||||
}
|
||||
if userCtx.Email != "" {
|
||||
event.Metadata["user_email"] = userCtx.Email
|
||||
}
|
||||
if len(userCtx.Roles) > 0 {
|
||||
event.Metadata["user_roles"] = userCtx.Roles
|
||||
}
|
||||
event.Metadata["table_name"] = hookCtx.TableName
|
||||
|
||||
// Publish asynchronously to not block CRUD operation
|
||||
if err := broker.PublishAsync(hookCtx.Context, event); err != nil {
|
||||
logger.Error("Failed to publish %s event for %s.%s: %v",
|
||||
operation, hookCtx.Schema, hookCtx.Entity, err)
|
||||
// Don't fail the CRUD operation if event publishing fails
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Published %s event for %s.%s (ID: %s)",
|
||||
operation, hookCtx.Schema, hookCtx.Entity, event.ID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Register hooks based on configuration
|
||||
if config.EnableCreate {
|
||||
hookRegistry.Register(restheadspec.AfterCreate, createHookHandler("create"))
|
||||
logger.Info("Registered event hook for CREATE operations")
|
||||
}
|
||||
|
||||
if config.EnableRead {
|
||||
hookRegistry.Register(restheadspec.AfterRead, createHookHandler("read"))
|
||||
logger.Info("Registered event hook for READ operations")
|
||||
}
|
||||
|
||||
if config.EnableUpdate {
|
||||
hookRegistry.Register(restheadspec.AfterUpdate, createHookHandler("update"))
|
||||
logger.Info("Registered event hook for UPDATE operations")
|
||||
}
|
||||
|
||||
if config.EnableDelete {
|
||||
hookRegistry.Register(restheadspec.AfterDelete, createHookHandler("delete"))
|
||||
logger.Info("Registered event hook for DELETE operations")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
28
pkg/eventbroker/metrics.go
Normal file
28
pkg/eventbroker/metrics.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
// recordEventPublished records an event publication metric
|
||||
func recordEventPublished(event *Event) {
|
||||
if mp := metrics.GetProvider(); mp != nil {
|
||||
mp.RecordEventPublished(string(event.Source), event.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// recordEventProcessed records an event processing metric
|
||||
func recordEventProcessed(event *Event, duration time.Duration) {
|
||||
if mp := metrics.GetProvider(); mp != nil {
|
||||
mp.RecordEventProcessed(string(event.Source), event.Type, string(event.Status), duration)
|
||||
}
|
||||
}
|
||||
|
||||
// updateQueueSize updates the event queue size metric
|
||||
func updateQueueSize(size int64) {
|
||||
if mp := metrics.GetProvider(); mp != nil {
|
||||
mp.UpdateEventQueueSize(size)
|
||||
}
|
||||
}
|
||||
70
pkg/eventbroker/provider.go
Normal file
70
pkg/eventbroker/provider.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider defines the storage backend interface for events
|
||||
// Implementations: MemoryProvider, RedisProvider, NATSProvider, DatabaseProvider
|
||||
type Provider interface {
|
||||
// Store stores an event
|
||||
Store(ctx context.Context, event *Event) error
|
||||
|
||||
// Get retrieves an event by ID
|
||||
Get(ctx context.Context, id string) (*Event, error)
|
||||
|
||||
// List lists events with optional filters
|
||||
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error
|
||||
|
||||
// Delete deletes an event by ID
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
// Used for cross-instance pub/sub
|
||||
// The channel is closed when the context is canceled or an error occurs
|
||||
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
|
||||
|
||||
// Publish publishes an event to all subscribers (for distributed providers)
|
||||
// For in-memory provider, this is the same as Store
|
||||
// For Redis/NATS/Database, this triggers cross-instance delivery
|
||||
Publish(ctx context.Context, event *Event) error
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
Close() error
|
||||
|
||||
// Stats returns provider statistics
|
||||
Stats(ctx context.Context) (*ProviderStats, error)
|
||||
}
|
||||
|
||||
// EventFilter defines filter criteria for listing events
|
||||
type EventFilter struct {
|
||||
Source *EventSource
|
||||
Status *EventStatus
|
||||
UserID *int
|
||||
Schema string
|
||||
Entity string
|
||||
Operation string
|
||||
InstanceID string
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ProviderStats contains statistics about the provider
|
||||
type ProviderStats struct {
|
||||
ProviderType string `json:"provider_type"`
|
||||
TotalEvents int64 `json:"total_events"`
|
||||
PendingEvents int64 `json:"pending_events"`
|
||||
ProcessingEvents int64 `json:"processing_events"`
|
||||
CompletedEvents int64 `json:"completed_events"`
|
||||
FailedEvents int64 `json:"failed_events"`
|
||||
EventsPublished int64 `json:"events_published"`
|
||||
EventsConsumed int64 `json:"events_consumed"`
|
||||
ActiveSubscribers int `json:"active_subscribers"`
|
||||
ProviderSpecific map[string]interface{} `json:"provider_specific,omitempty"`
|
||||
}
|
||||
653
pkg/eventbroker/provider_database.go
Normal file
653
pkg/eventbroker/provider_database.go
Normal file
@@ -0,0 +1,653 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// DatabaseProvider implements Provider interface using SQL database
|
||||
// Features:
|
||||
// - Persistent event storage in database table
|
||||
// - Full SQL query support for event history
|
||||
// - PostgreSQL NOTIFY/LISTEN for real-time updates (optional)
|
||||
// - Polling-based consumption with configurable interval
|
||||
// - Good for audit trails and event replay
|
||||
type DatabaseProvider struct {
|
||||
db common.Database
|
||||
tableName string
|
||||
channel string // PostgreSQL NOTIFY channel name
|
||||
pollInterval time.Duration
|
||||
instanceID string
|
||||
useNotify bool // Whether to use PostgreSQL NOTIFY
|
||||
|
||||
// Subscriptions
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*dbSubscription
|
||||
|
||||
// Statistics
|
||||
stats DatabaseProviderStats
|
||||
|
||||
// Lifecycle
|
||||
stopPolling chan struct{}
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// DatabaseProviderStats contains statistics for the database provider
|
||||
type DatabaseProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
PollErrors atomic.Int64
|
||||
}
|
||||
|
||||
// dbSubscription represents a single database subscription
|
||||
type dbSubscription struct {
|
||||
pattern string
|
||||
ch chan *Event
|
||||
lastSeenID string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// DatabaseProviderConfig configures the database provider
|
||||
type DatabaseProviderConfig struct {
|
||||
DB common.Database
|
||||
TableName string
|
||||
Channel string // PostgreSQL NOTIFY channel (optional)
|
||||
PollInterval time.Duration
|
||||
InstanceID string
|
||||
UseNotify bool // Enable PostgreSQL NOTIFY/LISTEN
|
||||
}
|
||||
|
||||
// NewDatabaseProvider creates a new database event provider
|
||||
func NewDatabaseProvider(cfg DatabaseProviderConfig) (*DatabaseProvider, error) {
|
||||
// Apply defaults
|
||||
if cfg.TableName == "" {
|
||||
cfg.TableName = "events"
|
||||
}
|
||||
if cfg.Channel == "" {
|
||||
cfg.Channel = "resolvespec_events"
|
||||
}
|
||||
if cfg.PollInterval == 0 {
|
||||
cfg.PollInterval = 1 * time.Second
|
||||
}
|
||||
|
||||
dp := &DatabaseProvider{
|
||||
db: cfg.DB,
|
||||
tableName: cfg.TableName,
|
||||
channel: cfg.Channel,
|
||||
pollInterval: cfg.PollInterval,
|
||||
instanceID: cfg.InstanceID,
|
||||
useNotify: cfg.UseNotify,
|
||||
subscribers: make(map[string]*dbSubscription),
|
||||
stopPolling: make(chan struct{}),
|
||||
}
|
||||
|
||||
dp.isRunning.Store(true)
|
||||
|
||||
// Create table if it doesn't exist
|
||||
ctx := context.Background()
|
||||
if err := dp.createTable(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to create events table: %w", err)
|
||||
}
|
||||
|
||||
// Start polling goroutine for subscriptions
|
||||
dp.wg.Add(1)
|
||||
go dp.pollLoop()
|
||||
|
||||
logger.Info("Database provider initialized (table: %s, poll_interval: %v, notify: %v)",
|
||||
cfg.TableName, cfg.PollInterval, cfg.UseNotify)
|
||||
|
||||
return dp, nil
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (dp *DatabaseProvider) Store(ctx context.Context, event *Event) error {
|
||||
// Marshal metadata to JSON
|
||||
metadataJSON, err := json.Marshal(event.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
// Insert event
|
||||
query := fmt.Sprintf(`
|
||||
INSERT INTO %s (
|
||||
id, source, type, status, retry_count, error,
|
||||
payload, user_id, session_id, instance_id,
|
||||
schema, entity, operation,
|
||||
created_at, processed_at, completed_at, metadata
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6,
|
||||
$7, $8, $9, $10,
|
||||
$11, $12, $13,
|
||||
$14, $15, $16, $17
|
||||
)
|
||||
`, dp.tableName)
|
||||
|
||||
_, err = dp.db.Exec(ctx, query,
|
||||
event.ID, event.Source, event.Type, event.Status, event.RetryCount, event.Error,
|
||||
event.Payload, event.UserID, event.SessionID, event.InstanceID,
|
||||
event.Schema, event.Entity, event.Operation,
|
||||
event.CreatedAt, event.ProcessedAt, event.CompletedAt, metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert event: %w", err)
|
||||
}
|
||||
|
||||
dp.stats.TotalEvents.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
func (dp *DatabaseProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
event := &Event{}
|
||||
var metadataJSON []byte
|
||||
var processedAt, completedAt sql.NullTime
|
||||
|
||||
// Query into individual fields
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, source, type, status, retry_count, error,
|
||||
payload, user_id, session_id, instance_id,
|
||||
schema, entity, operation,
|
||||
created_at, processed_at, completed_at, metadata
|
||||
FROM %s
|
||||
WHERE id = $1
|
||||
`, dp.tableName)
|
||||
|
||||
var source, eventType, status, operation string
|
||||
|
||||
// Execute raw query
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query event: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
if err := rows.Scan(
|
||||
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
|
||||
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
|
||||
&event.Schema, &event.Entity, &operation,
|
||||
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan event: %w", err)
|
||||
}
|
||||
|
||||
// Set enum values
|
||||
event.Source = EventSource(source)
|
||||
event.Type = eventType
|
||||
event.Status = EventStatus(status)
|
||||
event.Operation = operation
|
||||
|
||||
// Handle nullable timestamps
|
||||
if processedAt.Valid {
|
||||
event.ProcessedAt = &processedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
event.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Unmarshal metadata
|
||||
if len(metadataJSON) > 0 {
|
||||
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
|
||||
logger.Warn("Failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
func (dp *DatabaseProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
query := fmt.Sprintf("SELECT id, source, type, status, retry_count, error, "+
|
||||
"payload, user_id, session_id, instance_id, "+
|
||||
"schema, entity, operation, "+
|
||||
"created_at, processed_at, completed_at, metadata "+
|
||||
"FROM %s WHERE 1=1", dp.tableName)
|
||||
|
||||
args := []interface{}{}
|
||||
argNum := 1
|
||||
|
||||
// Build WHERE clause
|
||||
if filter != nil {
|
||||
if filter.Source != nil {
|
||||
query += fmt.Sprintf(" AND source = $%d", argNum)
|
||||
args = append(args, string(*filter.Source))
|
||||
argNum++
|
||||
}
|
||||
if filter.Status != nil {
|
||||
query += fmt.Sprintf(" AND status = $%d", argNum)
|
||||
args = append(args, string(*filter.Status))
|
||||
argNum++
|
||||
}
|
||||
if filter.UserID != nil {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", argNum)
|
||||
args = append(args, *filter.UserID)
|
||||
argNum++
|
||||
}
|
||||
if filter.Schema != "" {
|
||||
query += fmt.Sprintf(" AND schema = $%d", argNum)
|
||||
args = append(args, filter.Schema)
|
||||
argNum++
|
||||
}
|
||||
if filter.Entity != "" {
|
||||
query += fmt.Sprintf(" AND entity = $%d", argNum)
|
||||
args = append(args, filter.Entity)
|
||||
argNum++
|
||||
}
|
||||
if filter.Operation != "" {
|
||||
query += fmt.Sprintf(" AND operation = $%d", argNum)
|
||||
args = append(args, filter.Operation)
|
||||
argNum++
|
||||
}
|
||||
if filter.InstanceID != "" {
|
||||
query += fmt.Sprintf(" AND instance_id = $%d", argNum)
|
||||
args = append(args, filter.InstanceID)
|
||||
argNum++
|
||||
}
|
||||
if filter.StartTime != nil {
|
||||
query += fmt.Sprintf(" AND created_at >= $%d", argNum)
|
||||
args = append(args, *filter.StartTime)
|
||||
argNum++
|
||||
}
|
||||
if filter.EndTime != nil {
|
||||
query += fmt.Sprintf(" AND created_at <= $%d", argNum)
|
||||
args = append(args, *filter.EndTime)
|
||||
argNum++
|
||||
}
|
||||
}
|
||||
|
||||
// Add ORDER BY
|
||||
query += " ORDER BY created_at DESC"
|
||||
|
||||
// Add LIMIT and OFFSET
|
||||
if filter != nil {
|
||||
if filter.Limit > 0 {
|
||||
query += fmt.Sprintf(" LIMIT $%d", argNum)
|
||||
args = append(args, filter.Limit)
|
||||
argNum++
|
||||
}
|
||||
if filter.Offset > 0 {
|
||||
query += fmt.Sprintf(" OFFSET $%d", argNum)
|
||||
args = append(args, filter.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*Event
|
||||
for rows.Next() {
|
||||
event := &Event{}
|
||||
var source, eventType, status, operation string
|
||||
var metadataJSON []byte
|
||||
var processedAt, completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
|
||||
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
|
||||
&event.Schema, &event.Entity, &operation,
|
||||
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to scan event: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set enum values
|
||||
event.Source = EventSource(source)
|
||||
event.Type = eventType
|
||||
event.Status = EventStatus(status)
|
||||
event.Operation = operation
|
||||
|
||||
// Handle nullable timestamps
|
||||
if processedAt.Valid {
|
||||
event.ProcessedAt = &processedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
event.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Unmarshal metadata
|
||||
if len(metadataJSON) > 0 {
|
||||
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
|
||||
logger.Warn("Failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, event)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
func (dp *DatabaseProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE %s
|
||||
SET status = $1, error = $2
|
||||
WHERE id = $3
|
||||
`, dp.tableName)
|
||||
|
||||
_, err := dp.db.Exec(ctx, query, string(status), errorMsg, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
func (dp *DatabaseProvider) Delete(ctx context.Context, id string) error {
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE id = $1", dp.tableName)
|
||||
|
||||
_, err := dp.db.Exec(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete event: %w", err)
|
||||
}
|
||||
|
||||
dp.stats.TotalEvents.Add(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
func (dp *DatabaseProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sub := &dbSubscription{
|
||||
pattern: pattern,
|
||||
ch: ch,
|
||||
lastSeenID: "",
|
||||
ctx: subCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
dp.mu.Lock()
|
||||
dp.subscribers[pattern] = sub
|
||||
dp.stats.ActiveSubscribers.Add(1)
|
||||
dp.mu.Unlock()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers
|
||||
func (dp *DatabaseProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := dp.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dp.stats.EventsPublished.Add(1)
|
||||
|
||||
// If using PostgreSQL NOTIFY, send notification
|
||||
if dp.useNotify {
|
||||
if err := dp.notify(ctx, event.ID); err != nil {
|
||||
logger.Warn("Failed to send NOTIFY: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (dp *DatabaseProvider) Close() error {
|
||||
if !dp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
dp.isRunning.Store(false)
|
||||
|
||||
// Cancel all subscriptions
|
||||
dp.mu.Lock()
|
||||
for _, sub := range dp.subscribers {
|
||||
sub.cancel()
|
||||
}
|
||||
dp.mu.Unlock()
|
||||
|
||||
// Stop polling
|
||||
close(dp.stopPolling)
|
||||
|
||||
// Wait for goroutines
|
||||
dp.wg.Wait()
|
||||
|
||||
logger.Info("Database provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (dp *DatabaseProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
// Get counts by status
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE status = 'pending') as pending,
|
||||
COUNT(*) FILTER (WHERE status = 'processing') as processing,
|
||||
COUNT(*) FILTER (WHERE status = 'completed') as completed,
|
||||
COUNT(*) FILTER (WHERE status = 'failed') as failed,
|
||||
COUNT(*) as total
|
||||
FROM %s
|
||||
`, dp.tableName)
|
||||
|
||||
var pending, processing, completed, failed, total int64
|
||||
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to get stats: %v", err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&pending, &processing, &completed, &failed, &total); err != nil {
|
||||
logger.Warn("Failed to scan stats: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ProviderStats{
|
||||
ProviderType: "database",
|
||||
TotalEvents: total,
|
||||
PendingEvents: pending,
|
||||
ProcessingEvents: processing,
|
||||
CompletedEvents: completed,
|
||||
FailedEvents: failed,
|
||||
EventsPublished: dp.stats.EventsPublished.Load(),
|
||||
EventsConsumed: dp.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(dp.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"table_name": dp.tableName,
|
||||
"poll_interval": dp.pollInterval.String(),
|
||||
"use_notify": dp.useNotify,
|
||||
"poll_errors": dp.stats.PollErrors.Load(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// pollLoop periodically polls for new events
|
||||
func (dp *DatabaseProvider) pollLoop() {
|
||||
defer dp.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(dp.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
dp.pollEvents()
|
||||
case <-dp.stopPolling:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pollEvents polls for new events and delivers to subscribers
|
||||
func (dp *DatabaseProvider) pollEvents() {
|
||||
dp.mu.RLock()
|
||||
subscribers := make([]*dbSubscription, 0, len(dp.subscribers))
|
||||
for _, sub := range dp.subscribers {
|
||||
subscribers = append(subscribers, sub)
|
||||
}
|
||||
dp.mu.RUnlock()
|
||||
|
||||
for _, sub := range subscribers {
|
||||
// Query for new events since last seen
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, source, type, status, retry_count, error,
|
||||
payload, user_id, session_id, instance_id,
|
||||
schema, entity, operation,
|
||||
created_at, processed_at, completed_at, metadata
|
||||
FROM %s
|
||||
WHERE id > $1
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 100
|
||||
`, dp.tableName)
|
||||
|
||||
lastSeenID := sub.lastSeenID
|
||||
if lastSeenID == "" {
|
||||
lastSeenID = "00000000-0000-0000-0000-000000000000"
|
||||
}
|
||||
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(sub.ctx, query, lastSeenID)
|
||||
if err != nil {
|
||||
dp.stats.PollErrors.Add(1)
|
||||
logger.Warn("Failed to poll events: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
event := &Event{}
|
||||
var source, eventType, status, operation string
|
||||
var metadataJSON []byte
|
||||
var processedAt, completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
|
||||
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
|
||||
&event.Schema, &event.Entity, &operation,
|
||||
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to scan event: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set enum values
|
||||
event.Source = EventSource(source)
|
||||
event.Type = eventType
|
||||
event.Status = EventStatus(status)
|
||||
event.Operation = operation
|
||||
|
||||
// Handle nullable timestamps
|
||||
if processedAt.Valid {
|
||||
event.ProcessedAt = &processedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
event.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Unmarshal metadata
|
||||
if len(metadataJSON) > 0 {
|
||||
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
|
||||
logger.Warn("Failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if event matches pattern
|
||||
if matchPattern(sub.pattern, event.Type) {
|
||||
select {
|
||||
case sub.ch <- event:
|
||||
dp.stats.EventsConsumed.Add(1)
|
||||
sub.lastSeenID = event.ID
|
||||
case <-sub.ctx.Done():
|
||||
rows.Close()
|
||||
return
|
||||
default:
|
||||
// Channel full, skip
|
||||
logger.Warn("Subscriber channel full for pattern: %s", sub.pattern)
|
||||
}
|
||||
}
|
||||
|
||||
sub.lastSeenID = event.ID
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// notify sends a PostgreSQL NOTIFY message
|
||||
func (dp *DatabaseProvider) notify(ctx context.Context, eventID string) error {
|
||||
query := fmt.Sprintf("NOTIFY %s, '%s'", dp.channel, eventID)
|
||||
_, err := dp.db.Exec(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// createTable creates the events table if it doesn't exist
|
||||
func (dp *DatabaseProvider) createTable(ctx context.Context) error {
|
||||
query := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
source VARCHAR(50) NOT NULL,
|
||||
type VARCHAR(255) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
error TEXT,
|
||||
payload JSONB,
|
||||
user_id INTEGER,
|
||||
session_id VARCHAR(255),
|
||||
instance_id VARCHAR(255),
|
||||
schema VARCHAR(255),
|
||||
entity VARCHAR(255),
|
||||
operation VARCHAR(50),
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
processed_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
metadata JSONB
|
||||
)
|
||||
`, dp.tableName)
|
||||
|
||||
if _, err := dp.db.Exec(ctx, query); err != nil {
|
||||
return fmt.Errorf("failed to create table: %w", err)
|
||||
}
|
||||
|
||||
// Create indexes
|
||||
indexes := []string{
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_source ON %s(source)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_type ON %s(type)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_status ON %s(status)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_created_at ON %s(created_at)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_instance_id ON %s(instance_id)", dp.tableName, dp.tableName),
|
||||
}
|
||||
|
||||
for _, indexQuery := range indexes {
|
||||
if _, err := dp.db.Exec(ctx, indexQuery); err != nil {
|
||||
logger.Warn("Failed to create index: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
446
pkg/eventbroker/provider_memory.go
Normal file
446
pkg/eventbroker/provider_memory.go
Normal file
@@ -0,0 +1,446 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// MemoryProvider implements Provider interface using in-memory storage
|
||||
// Features:
|
||||
// - Thread-safe event storage with RW mutex
|
||||
// - LRU eviction when max events reached
|
||||
// - In-process pub/sub (not cross-instance)
|
||||
// - Automatic cleanup of old completed events
|
||||
type MemoryProvider struct {
|
||||
mu sync.RWMutex
|
||||
events map[string]*Event
|
||||
eventOrder []string // For LRU tracking
|
||||
subscribers map[string][]chan *Event
|
||||
instanceID string
|
||||
maxEvents int
|
||||
cleanupInterval time.Duration
|
||||
maxAge time.Duration
|
||||
|
||||
// Statistics
|
||||
stats MemoryProviderStats
|
||||
|
||||
// Lifecycle
|
||||
stopCleanup chan struct{}
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// MemoryProviderStats contains statistics for the memory provider
|
||||
type MemoryProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
PendingEvents atomic.Int64
|
||||
ProcessingEvents atomic.Int64
|
||||
CompletedEvents atomic.Int64
|
||||
FailedEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
Evictions atomic.Int64
|
||||
}
|
||||
|
||||
// MemoryProviderOptions configures the memory provider
|
||||
type MemoryProviderOptions struct {
|
||||
InstanceID string
|
||||
MaxEvents int
|
||||
CleanupInterval time.Duration
|
||||
MaxAge time.Duration
|
||||
}
|
||||
|
||||
// NewMemoryProvider creates a new in-memory event provider
|
||||
func NewMemoryProvider(opts MemoryProviderOptions) *MemoryProvider {
|
||||
if opts.MaxEvents == 0 {
|
||||
opts.MaxEvents = 10000 // Default
|
||||
}
|
||||
if opts.CleanupInterval == 0 {
|
||||
opts.CleanupInterval = 5 * time.Minute // Default
|
||||
}
|
||||
if opts.MaxAge == 0 {
|
||||
opts.MaxAge = 24 * time.Hour // Default: keep events for 24 hours
|
||||
}
|
||||
|
||||
mp := &MemoryProvider{
|
||||
events: make(map[string]*Event),
|
||||
eventOrder: make([]string, 0),
|
||||
subscribers: make(map[string][]chan *Event),
|
||||
instanceID: opts.InstanceID,
|
||||
maxEvents: opts.MaxEvents,
|
||||
cleanupInterval: opts.CleanupInterval,
|
||||
maxAge: opts.MaxAge,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
mp.isRunning.Store(true)
|
||||
|
||||
// Start cleanup goroutine
|
||||
mp.wg.Add(1)
|
||||
go mp.cleanupLoop()
|
||||
|
||||
logger.Info("Memory provider initialized (max_events: %d, cleanup: %v, max_age: %v)",
|
||||
opts.MaxEvents, opts.CleanupInterval, opts.MaxAge)
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (mp *MemoryProvider) Store(ctx context.Context, event *Event) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Check if we need to evict oldest events
|
||||
if len(mp.events) >= mp.maxEvents {
|
||||
mp.evictOldestLocked()
|
||||
}
|
||||
|
||||
// Store event
|
||||
mp.events[event.ID] = event.Clone()
|
||||
mp.eventOrder = append(mp.eventOrder, event.ID)
|
||||
|
||||
// Update statistics
|
||||
mp.stats.TotalEvents.Add(1)
|
||||
mp.updateStatusCountsLocked(event.Status, 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
func (mp *MemoryProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
event, exists := mp.events[id]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
return event.Clone(), nil
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
func (mp *MemoryProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
var results []*Event
|
||||
|
||||
for _, event := range mp.events {
|
||||
if mp.matchesFilter(event, filter) {
|
||||
results = append(results, event.Clone())
|
||||
}
|
||||
}
|
||||
|
||||
// Apply limit and offset
|
||||
if filter != nil {
|
||||
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||
results = results[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||
results = results[:filter.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
func (mp *MemoryProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
event, exists := mp.events[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
// Update status counts
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
mp.updateStatusCountsLocked(status, 1)
|
||||
|
||||
// Update event
|
||||
event.Status = status
|
||||
if errorMsg != "" {
|
||||
event.Error = errorMsg
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
func (mp *MemoryProvider) Delete(ctx context.Context, id string) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
event, exists := mp.events[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
// Update counts
|
||||
mp.stats.TotalEvents.Add(-1)
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
|
||||
// Delete event
|
||||
delete(mp.events, id)
|
||||
|
||||
// Remove from order tracking
|
||||
for i, eid := range mp.eventOrder {
|
||||
if eid == id {
|
||||
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
// Note: This is in-process only, not cross-instance
|
||||
func (mp *MemoryProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Create buffered channel for events
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
// Store subscriber
|
||||
mp.subscribers[pattern] = append(mp.subscribers[pattern], ch)
|
||||
mp.stats.ActiveSubscribers.Add(1)
|
||||
|
||||
// Goroutine to clean up on context cancellation
|
||||
mp.wg.Add(1)
|
||||
go func() {
|
||||
defer mp.wg.Done()
|
||||
<-ctx.Done()
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Remove subscriber
|
||||
subs := mp.subscribers[pattern]
|
||||
for i, subCh := range subs {
|
||||
if subCh == ch {
|
||||
mp.subscribers[pattern] = append(subs[:i], subs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mp.stats.ActiveSubscribers.Add(-1)
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
logger.Debug("Stream created for pattern: %s", pattern)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers
|
||||
func (mp *MemoryProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := mp.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp.stats.EventsPublished.Add(1)
|
||||
|
||||
// Notify subscribers
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
for pattern, channels := range mp.subscribers {
|
||||
if matchPattern(pattern, event.Type) {
|
||||
for _, ch := range channels {
|
||||
select {
|
||||
case ch <- event.Clone():
|
||||
mp.stats.EventsConsumed.Add(1)
|
||||
default:
|
||||
// Channel full, skip
|
||||
logger.Warn("Subscriber channel full for pattern: %s", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (mp *MemoryProvider) Close() error {
|
||||
if !mp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
mp.isRunning.Store(false)
|
||||
|
||||
// Stop cleanup loop
|
||||
close(mp.stopCleanup)
|
||||
|
||||
// Wait for goroutines
|
||||
mp.wg.Wait()
|
||||
|
||||
// Close all subscriber channels
|
||||
mp.mu.Lock()
|
||||
for _, channels := range mp.subscribers {
|
||||
for _, ch := range channels {
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
mp.subscribers = make(map[string][]chan *Event)
|
||||
mp.mu.Unlock()
|
||||
|
||||
logger.Info("Memory provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (mp *MemoryProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
return &ProviderStats{
|
||||
ProviderType: "memory",
|
||||
TotalEvents: mp.stats.TotalEvents.Load(),
|
||||
PendingEvents: mp.stats.PendingEvents.Load(),
|
||||
ProcessingEvents: mp.stats.ProcessingEvents.Load(),
|
||||
CompletedEvents: mp.stats.CompletedEvents.Load(),
|
||||
FailedEvents: mp.stats.FailedEvents.Load(),
|
||||
EventsPublished: mp.stats.EventsPublished.Load(),
|
||||
EventsConsumed: mp.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(mp.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"max_events": mp.maxEvents,
|
||||
"cleanup_interval": mp.cleanupInterval.String(),
|
||||
"max_age": mp.maxAge.String(),
|
||||
"evictions": mp.stats.Evictions.Load(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up old completed events
|
||||
func (mp *MemoryProvider) cleanupLoop() {
|
||||
defer mp.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(mp.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mp.cleanup()
|
||||
case <-mp.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes old completed/failed events
|
||||
func (mp *MemoryProvider) cleanup() {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-mp.maxAge)
|
||||
removed := 0
|
||||
|
||||
for id, event := range mp.events {
|
||||
// Only clean up completed or failed events that are old
|
||||
if (event.Status == EventStatusCompleted || event.Status == EventStatusFailed) &&
|
||||
event.CreatedAt.Before(cutoff) {
|
||||
|
||||
delete(mp.events, id)
|
||||
mp.stats.TotalEvents.Add(-1)
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
|
||||
// Remove from order tracking
|
||||
for i, eid := range mp.eventOrder {
|
||||
if eid == id {
|
||||
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
logger.Debug("Cleanup removed %d old events", removed)
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldestLocked evicts the oldest event (LRU)
|
||||
// Caller must hold write lock
|
||||
func (mp *MemoryProvider) evictOldestLocked() {
|
||||
if len(mp.eventOrder) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Get oldest event ID
|
||||
oldestID := mp.eventOrder[0]
|
||||
mp.eventOrder = mp.eventOrder[1:]
|
||||
|
||||
// Remove event
|
||||
if event, exists := mp.events[oldestID]; exists {
|
||||
delete(mp.events, oldestID)
|
||||
mp.stats.TotalEvents.Add(-1)
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
mp.stats.Evictions.Add(1)
|
||||
|
||||
logger.Debug("Evicted oldest event: %s", oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
// matchesFilter checks if an event matches the filter criteria
|
||||
func (mp *MemoryProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if filter.Source != nil && event.Source != *filter.Source {
|
||||
return false
|
||||
}
|
||||
if filter.Status != nil && event.Status != *filter.Status {
|
||||
return false
|
||||
}
|
||||
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||
return false
|
||||
}
|
||||
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||
return false
|
||||
}
|
||||
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||
return false
|
||||
}
|
||||
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||
return false
|
||||
}
|
||||
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||
return false
|
||||
}
|
||||
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||
return false
|
||||
}
|
||||
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// updateStatusCountsLocked updates status statistics
|
||||
// Caller must hold write lock
|
||||
func (mp *MemoryProvider) updateStatusCountsLocked(status EventStatus, delta int64) {
|
||||
switch status {
|
||||
case EventStatusPending:
|
||||
mp.stats.PendingEvents.Add(delta)
|
||||
case EventStatusProcessing:
|
||||
mp.stats.ProcessingEvents.Add(delta)
|
||||
case EventStatusCompleted:
|
||||
mp.stats.CompletedEvents.Add(delta)
|
||||
case EventStatusFailed:
|
||||
mp.stats.FailedEvents.Add(delta)
|
||||
}
|
||||
}
|
||||
419
pkg/eventbroker/provider_memory_test.go
Normal file
419
pkg/eventbroker/provider_memory_test.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewMemoryProvider(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 100,
|
||||
CleanupInterval: 1 * time.Minute,
|
||||
})
|
||||
|
||||
if provider == nil {
|
||||
t.Fatal("Expected non-nil provider")
|
||||
}
|
||||
|
||||
stats, err := provider.Stats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Stats failed: %v", err)
|
||||
}
|
||||
|
||||
if stats.ProviderType != "memory" {
|
||||
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderPublishAndGet(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.UserID = 123
|
||||
|
||||
// Publish event
|
||||
if err := provider.Publish(context.Background(), event); err != nil {
|
||||
t.Fatalf("Publish failed: %v", err)
|
||||
}
|
||||
|
||||
// Get event
|
||||
retrieved, err := provider.Get(context.Background(), event.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.ID != event.ID {
|
||||
t.Errorf("Expected event ID %s, got %s", event.ID, retrieved.ID)
|
||||
}
|
||||
if retrieved.UserID != 123 {
|
||||
t.Errorf("Expected user ID 123, got %d", retrieved.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderGetNonExistent(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
_, err := provider.Get(context.Background(), "non-existent-id")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting non-existent event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderUpdateStatus(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event)
|
||||
|
||||
// Update status to processing
|
||||
err := provider.UpdateStatus(context.Background(), event.ID, EventStatusProcessing, "")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus failed: %v", err)
|
||||
}
|
||||
|
||||
retrieved, _ := provider.Get(context.Background(), event.ID)
|
||||
if retrieved.Status != EventStatusProcessing {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusProcessing, retrieved.Status)
|
||||
}
|
||||
|
||||
// Update status to failed with error
|
||||
err = provider.UpdateStatus(context.Background(), event.ID, EventStatusFailed, "test error")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus failed: %v", err)
|
||||
}
|
||||
|
||||
retrieved, _ = provider.Get(context.Background(), event.ID)
|
||||
if retrieved.Status != EventStatusFailed {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusFailed, retrieved.Status)
|
||||
}
|
||||
if retrieved.Error != "test error" {
|
||||
t.Errorf("Expected error 'test error', got %s", retrieved.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderList(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish multiple events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event)
|
||||
}
|
||||
|
||||
// List all events
|
||||
events, err := provider.List(context.Background(), &EventFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 5 {
|
||||
t.Errorf("Expected 5 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderListWithFilter(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish events with different types
|
||||
event1 := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event1)
|
||||
|
||||
event2 := NewEvent(EventSourceDatabase, "public.roles.create")
|
||||
provider.Publish(context.Background(), event2)
|
||||
|
||||
event3 := NewEvent(EventSourceWebSocket, "chat.message")
|
||||
provider.Publish(context.Background(), event3)
|
||||
|
||||
// Filter by source
|
||||
source := EventSourceDatabase
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
Source: &source,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 2 {
|
||||
t.Errorf("Expected 2 events with database source, got %d", len(events))
|
||||
}
|
||||
|
||||
// Filter by status
|
||||
status := EventStatusPending
|
||||
events, err = provider.List(context.Background(), &EventFilter{
|
||||
Status: &status,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 3 {
|
||||
t.Errorf("Expected 3 events with pending status, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderListWithLimit(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish multiple events
|
||||
for i := 0; i < 10; i++ {
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
}
|
||||
|
||||
// List with limit
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
Limit: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 5 {
|
||||
t.Errorf("Expected 5 events (limited), got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderDelete(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event)
|
||||
|
||||
// Delete event
|
||||
err := provider.Delete(context.Background(), event.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify deleted
|
||||
_, err = provider.Get(context.Background(), event.ID)
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting deleted event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderLRUEviction(t *testing.T) {
|
||||
// Create provider with small max events
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 3,
|
||||
})
|
||||
|
||||
// Publish 5 events
|
||||
events := make([]*Event, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
events[i] = NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), events[i])
|
||||
}
|
||||
|
||||
// First 2 events should be evicted
|
||||
_, err := provider.Get(context.Background(), events[0].ID)
|
||||
if err == nil {
|
||||
t.Error("Expected first event to be evicted")
|
||||
}
|
||||
|
||||
_, err = provider.Get(context.Background(), events[1].ID)
|
||||
if err == nil {
|
||||
t.Error("Expected second event to be evicted")
|
||||
}
|
||||
|
||||
// Last 3 events should still exist
|
||||
for i := 2; i < 5; i++ {
|
||||
_, err := provider.Get(context.Background(), events[i].ID)
|
||||
if err != nil {
|
||||
t.Errorf("Expected event %d to still exist", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderCleanup(t *testing.T) {
|
||||
// Create provider with short cleanup interval
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
MaxAge: 200 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Publish and complete an event
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
provider.UpdateStatus(context.Background(), event.ID, EventStatusCompleted, "")
|
||||
|
||||
// Wait for cleanup to run
|
||||
time.Sleep(400 * time.Millisecond)
|
||||
|
||||
// Event should be cleaned up
|
||||
_, err := provider.Get(context.Background(), event.ID)
|
||||
if err == nil {
|
||||
t.Error("Expected event to be cleaned up")
|
||||
}
|
||||
|
||||
provider.Close()
|
||||
}
|
||||
|
||||
func TestMemoryProviderStats(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 100,
|
||||
})
|
||||
|
||||
// Publish events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
}
|
||||
|
||||
stats, err := provider.Stats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Stats failed: %v", err)
|
||||
}
|
||||
|
||||
if stats.ProviderType != "memory" {
|
||||
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||
}
|
||||
if stats.TotalEvents != 5 {
|
||||
t.Errorf("Expected 5 total events, got %d", stats.TotalEvents)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderClose(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
|
||||
// Close provider
|
||||
err := provider.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup goroutine should be stopped
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestMemoryProviderConcurrency(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Concurrent publish
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all events were stored
|
||||
events, _ := provider.List(context.Background(), &EventFilter{})
|
||||
if len(events) != 10 {
|
||||
t.Errorf("Expected 10 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderStream(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Stream is implemented for memory provider (in-process pub/sub)
|
||||
ch, err := provider.Stream(context.Background(), "test.*")
|
||||
if err != nil {
|
||||
t.Fatalf("Stream failed: %v", err)
|
||||
}
|
||||
if ch == nil {
|
||||
t.Error("Expected non-nil channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderTimeRangeFilter(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish events at different times
|
||||
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event1)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event2)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
event3 := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event3)
|
||||
|
||||
// Filter by time range
|
||||
startTime := event2.CreatedAt.Add(-1 * time.Millisecond)
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
StartTime: &startTime,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
// Should get events 2 and 3
|
||||
if len(events) != 2 {
|
||||
t.Errorf("Expected 2 events after start time, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderInstanceIDFilter(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish events with different instance IDs
|
||||
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||
event1.InstanceID = "instance-1"
|
||||
provider.Publish(context.Background(), event1)
|
||||
|
||||
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||
event2.InstanceID = "instance-2"
|
||||
provider.Publish(context.Background(), event2)
|
||||
|
||||
// Filter by instance ID
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
InstanceID: "instance-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 1 {
|
||||
t.Errorf("Expected 1 event with instance-1, got %d", len(events))
|
||||
}
|
||||
if events[0].InstanceID != "instance-1" {
|
||||
t.Errorf("Expected instance ID 'instance-1', got %s", events[0].InstanceID)
|
||||
}
|
||||
}
|
||||
565
pkg/eventbroker/provider_nats.go
Normal file
565
pkg/eventbroker/provider_nats.go
Normal file
@@ -0,0 +1,565 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// NATSProvider implements Provider interface using NATS JetStream
|
||||
// Features:
|
||||
// - Persistent event storage using JetStream
|
||||
// - Cross-instance pub/sub using NATS subjects
|
||||
// - Wildcard subscription support
|
||||
// - Durable consumers for event replay
|
||||
// - At-least-once delivery semantics
|
||||
type NATSProvider struct {
|
||||
nc *nats.Conn
|
||||
js jetstream.JetStream
|
||||
stream jetstream.Stream
|
||||
streamName string
|
||||
subjectPrefix string
|
||||
instanceID string
|
||||
maxAge time.Duration
|
||||
|
||||
// Subscriptions
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*natsSubscription
|
||||
|
||||
// Statistics
|
||||
stats NATSProviderStats
|
||||
|
||||
// Lifecycle
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// NATSProviderStats contains statistics for the NATS provider
|
||||
type NATSProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
ConsumerErrors atomic.Int64
|
||||
}
|
||||
|
||||
// natsSubscription represents a single NATS subscription
|
||||
type natsSubscription struct {
|
||||
pattern string
|
||||
consumer jetstream.Consumer
|
||||
ch chan *Event
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NATSProviderConfig configures the NATS provider
|
||||
type NATSProviderConfig struct {
|
||||
URL string
|
||||
StreamName string
|
||||
SubjectPrefix string // e.g., "events"
|
||||
InstanceID string
|
||||
MaxAge time.Duration // How long to keep events
|
||||
Storage string // "file" or "memory"
|
||||
}
|
||||
|
||||
// NewNATSProvider creates a new NATS event provider
|
||||
func NewNATSProvider(cfg NATSProviderConfig) (*NATSProvider, error) {
|
||||
// Apply defaults
|
||||
if cfg.URL == "" {
|
||||
cfg.URL = nats.DefaultURL
|
||||
}
|
||||
if cfg.StreamName == "" {
|
||||
cfg.StreamName = "RESOLVESPEC_EVENTS"
|
||||
}
|
||||
if cfg.SubjectPrefix == "" {
|
||||
cfg.SubjectPrefix = "events"
|
||||
}
|
||||
if cfg.MaxAge == 0 {
|
||||
cfg.MaxAge = 7 * 24 * time.Hour // 7 days
|
||||
}
|
||||
if cfg.Storage == "" {
|
||||
cfg.Storage = "file"
|
||||
}
|
||||
|
||||
// Connect to NATS
|
||||
nc, err := nats.Connect(cfg.URL,
|
||||
nats.Name("resolvespec-eventbroker-"+cfg.InstanceID),
|
||||
nats.Timeout(5*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to NATS: %w", err)
|
||||
}
|
||||
|
||||
// Create JetStream context
|
||||
js, err := jetstream.New(nc)
|
||||
if err != nil {
|
||||
nc.Close()
|
||||
return nil, fmt.Errorf("failed to create JetStream context: %w", err)
|
||||
}
|
||||
|
||||
np := &NATSProvider{
|
||||
nc: nc,
|
||||
js: js,
|
||||
streamName: cfg.StreamName,
|
||||
subjectPrefix: cfg.SubjectPrefix,
|
||||
instanceID: cfg.InstanceID,
|
||||
maxAge: cfg.MaxAge,
|
||||
subscribers: make(map[string]*natsSubscription),
|
||||
}
|
||||
|
||||
np.isRunning.Store(true)
|
||||
|
||||
// Create or update stream
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Determine storage type
|
||||
var storage jetstream.StorageType
|
||||
if cfg.Storage == "memory" {
|
||||
storage = jetstream.MemoryStorage
|
||||
} else {
|
||||
storage = jetstream.FileStorage
|
||||
}
|
||||
|
||||
if err := np.ensureStream(ctx, storage); err != nil {
|
||||
nc.Close()
|
||||
return nil, fmt.Errorf("failed to create stream: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("NATS provider initialized (stream: %s, subject: %s.*, url: %s)",
|
||||
cfg.StreamName, cfg.SubjectPrefix, cfg.URL)
|
||||
|
||||
return np, nil
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (np *NATSProvider) Store(ctx context.Context, event *Event) error {
|
||||
// Marshal event to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal event: %w", err)
|
||||
}
|
||||
|
||||
// Publish to NATS subject
|
||||
// Subject format: events.{source}.{schema}.{entity}.{operation}
|
||||
subject := np.buildSubject(event)
|
||||
|
||||
msg := &nats.Msg{
|
||||
Subject: subject,
|
||||
Data: data,
|
||||
Header: nats.Header{
|
||||
"Event-ID": []string{event.ID},
|
||||
"Event-Type": []string{event.Type},
|
||||
"Event-Source": []string{string(event.Source)},
|
||||
"Event-Status": []string{string(event.Status)},
|
||||
"Instance-ID": []string{event.InstanceID},
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := np.js.PublishMsg(ctx, msg); err != nil {
|
||||
return fmt.Errorf("failed to publish event: %w", err)
|
||||
}
|
||||
|
||||
np.stats.TotalEvents.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
// Note: This is inefficient with JetStream - consider using a separate KV store for lookups
|
||||
func (np *NATSProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
// We need to scan messages which is not ideal
|
||||
// For production, consider using NATS KV store for fast lookups
|
||||
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
|
||||
Name: "get-" + id,
|
||||
FilterSubject: np.subjectPrefix + ".>",
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create consumer: %w", err)
|
||||
}
|
||||
|
||||
// Fetch messages in batches
|
||||
msgs, err := consumer.Fetch(1000, jetstream.FetchMaxWait(5*time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch messages: %w", err)
|
||||
}
|
||||
|
||||
for msg := range msgs.Messages() {
|
||||
if msg.Headers().Get("Event-ID") == id {
|
||||
var event Event
|
||||
if err := json.Unmarshal(msg.Data(), &event); err != nil {
|
||||
_ = msg.Nak()
|
||||
continue
|
||||
}
|
||||
_ = msg.Ack()
|
||||
|
||||
// Delete temporary consumer
|
||||
_ = np.stream.DeleteConsumer(ctx, "get-"+id)
|
||||
|
||||
return &event, nil
|
||||
}
|
||||
_ = msg.Ack()
|
||||
}
|
||||
|
||||
// Delete temporary consumer
|
||||
_ = np.stream.DeleteConsumer(ctx, "get-"+id)
|
||||
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
func (np *NATSProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
var results []*Event
|
||||
|
||||
// Create temporary consumer
|
||||
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
|
||||
Name: fmt.Sprintf("list-%d", time.Now().UnixNano()),
|
||||
FilterSubject: np.subjectPrefix + ".>",
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create consumer: %w", err)
|
||||
}
|
||||
|
||||
defer func() { _ = np.stream.DeleteConsumer(ctx, consumer.CachedInfo().Name) }()
|
||||
|
||||
// Fetch messages in batches
|
||||
msgs, err := consumer.Fetch(1000, jetstream.FetchMaxWait(5*time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch messages: %w", err)
|
||||
}
|
||||
|
||||
for msg := range msgs.Messages() {
|
||||
var event Event
|
||||
if err := json.Unmarshal(msg.Data(), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
_ = msg.Nak()
|
||||
continue
|
||||
}
|
||||
|
||||
if np.matchesFilter(&event, filter) {
|
||||
results = append(results, &event)
|
||||
}
|
||||
|
||||
_ = msg.Ack()
|
||||
}
|
||||
|
||||
// Apply limit and offset
|
||||
if filter != nil {
|
||||
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||
results = results[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||
results = results[:filter.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
// Note: NATS streams are append-only, so we publish a status update event
|
||||
func (np *NATSProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
// Publish a status update message
|
||||
subject := fmt.Sprintf("%s.status.%s", np.subjectPrefix, id)
|
||||
|
||||
statusUpdate := map[string]interface{}{
|
||||
"event_id": id,
|
||||
"status": string(status),
|
||||
"error": errorMsg,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(statusUpdate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal status update: %w", err)
|
||||
}
|
||||
|
||||
if _, err := np.js.Publish(ctx, subject, data); err != nil {
|
||||
return fmt.Errorf("failed to publish status update: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
// Note: NATS streams don't support deletion - this just marks it in a separate subject
|
||||
func (np *NATSProvider) Delete(ctx context.Context, id string) error {
|
||||
subject := fmt.Sprintf("%s.deleted.%s", np.subjectPrefix, id)
|
||||
|
||||
deleteMsg := map[string]interface{}{
|
||||
"event_id": id,
|
||||
"deleted_at": time.Now(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(deleteMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal delete message: %w", err)
|
||||
}
|
||||
|
||||
if _, err := np.js.Publish(ctx, subject, data); err != nil {
|
||||
return fmt.Errorf("failed to publish delete message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
func (np *NATSProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
// Convert glob pattern to NATS subject pattern
|
||||
natsSubject := np.patternToSubject(pattern)
|
||||
|
||||
// Create durable consumer
|
||||
consumerName := fmt.Sprintf("consumer-%s-%d", np.instanceID, time.Now().UnixNano())
|
||||
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
|
||||
Name: consumerName,
|
||||
FilterSubject: natsSubject,
|
||||
DeliverPolicy: jetstream.DeliverNewPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
AckWait: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create consumer: %w", err)
|
||||
}
|
||||
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sub := &natsSubscription{
|
||||
pattern: pattern,
|
||||
consumer: consumer,
|
||||
ch: ch,
|
||||
ctx: subCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
np.mu.Lock()
|
||||
np.subscribers[pattern] = sub
|
||||
np.stats.ActiveSubscribers.Add(1)
|
||||
np.mu.Unlock()
|
||||
|
||||
// Start consumer goroutine
|
||||
np.wg.Add(1)
|
||||
go np.consumeMessages(sub)
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers
|
||||
func (np *NATSProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := np.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
np.stats.EventsPublished.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (np *NATSProvider) Close() error {
|
||||
if !np.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
np.isRunning.Store(false)
|
||||
|
||||
// Cancel all subscriptions
|
||||
np.mu.Lock()
|
||||
for _, sub := range np.subscribers {
|
||||
sub.cancel()
|
||||
}
|
||||
np.mu.Unlock()
|
||||
|
||||
// Wait for goroutines
|
||||
np.wg.Wait()
|
||||
|
||||
// Close NATS connection
|
||||
np.nc.Close()
|
||||
|
||||
logger.Info("NATS provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (np *NATSProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
streamInfo, err := np.stream.Info(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to get stream info: %v", err)
|
||||
}
|
||||
|
||||
stats := &ProviderStats{
|
||||
ProviderType: "nats",
|
||||
TotalEvents: np.stats.TotalEvents.Load(),
|
||||
EventsPublished: np.stats.EventsPublished.Load(),
|
||||
EventsConsumed: np.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(np.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"stream_name": np.streamName,
|
||||
"subject_prefix": np.subjectPrefix,
|
||||
"max_age": np.maxAge.String(),
|
||||
"consumer_errors": np.stats.ConsumerErrors.Load(),
|
||||
},
|
||||
}
|
||||
|
||||
if streamInfo != nil {
|
||||
stats.ProviderSpecific["messages"] = streamInfo.State.Msgs
|
||||
stats.ProviderSpecific["bytes"] = streamInfo.State.Bytes
|
||||
stats.ProviderSpecific["consumers"] = streamInfo.State.Consumers
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// ensureStream creates or updates the JetStream stream
|
||||
func (np *NATSProvider) ensureStream(ctx context.Context, storage jetstream.StorageType) error {
|
||||
streamConfig := jetstream.StreamConfig{
|
||||
Name: np.streamName,
|
||||
Subjects: []string{np.subjectPrefix + ".>"},
|
||||
MaxAge: np.maxAge,
|
||||
Storage: storage,
|
||||
Retention: jetstream.LimitsPolicy,
|
||||
Discard: jetstream.DiscardOld,
|
||||
}
|
||||
|
||||
stream, err := np.js.CreateStream(ctx, streamConfig)
|
||||
if err != nil {
|
||||
// Try to update if already exists
|
||||
stream, err = np.js.UpdateStream(ctx, streamConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create/update stream: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
np.stream = stream
|
||||
return nil
|
||||
}
|
||||
|
||||
// consumeMessages consumes messages from NATS for a subscription
|
||||
func (np *NATSProvider) consumeMessages(sub *natsSubscription) {
|
||||
defer np.wg.Done()
|
||||
defer close(sub.ch)
|
||||
defer func() {
|
||||
np.mu.Lock()
|
||||
delete(np.subscribers, sub.pattern)
|
||||
np.stats.ActiveSubscribers.Add(-1)
|
||||
np.mu.Unlock()
|
||||
}()
|
||||
|
||||
logger.Debug("Starting NATS consumer for pattern: %s", sub.pattern)
|
||||
|
||||
// Consume messages
|
||||
cc, err := sub.consumer.Consume(func(msg jetstream.Msg) {
|
||||
var event Event
|
||||
if err := json.Unmarshal(msg.Data(), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
_ = msg.Nak()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if event matches pattern (additional filtering)
|
||||
if matchPattern(sub.pattern, event.Type) {
|
||||
select {
|
||||
case sub.ch <- &event:
|
||||
np.stats.EventsConsumed.Add(1)
|
||||
_ = msg.Ack()
|
||||
case <-sub.ctx.Done():
|
||||
_ = msg.Nak()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_ = msg.Ack()
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
np.stats.ConsumerErrors.Add(1)
|
||||
logger.Error("Failed to start consumer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for context cancellation
|
||||
<-sub.ctx.Done()
|
||||
|
||||
// Stop consuming
|
||||
cc.Stop()
|
||||
|
||||
logger.Debug("NATS consumer stopped for pattern: %s", sub.pattern)
|
||||
}
|
||||
|
||||
// buildSubject creates a NATS subject from an event
|
||||
// Format: events.{source}.{schema}.{entity}.{operation}
|
||||
func (np *NATSProvider) buildSubject(event *Event) string {
|
||||
return fmt.Sprintf("%s.%s.%s.%s.%s",
|
||||
np.subjectPrefix,
|
||||
event.Source,
|
||||
event.Schema,
|
||||
event.Entity,
|
||||
event.Operation,
|
||||
)
|
||||
}
|
||||
|
||||
// patternToSubject converts a glob pattern to NATS subject pattern
|
||||
// Examples:
|
||||
// - "*" -> "events.>"
|
||||
// - "public.users.*" -> "events.*.public.users.*"
|
||||
// - "public.*.*" -> "events.*.public.*.*"
|
||||
func (np *NATSProvider) patternToSubject(pattern string) string {
|
||||
if pattern == "*" {
|
||||
return np.subjectPrefix + ".>"
|
||||
}
|
||||
|
||||
// For specific patterns, we need to match the event type structure
|
||||
// Event type: schema.entity.operation
|
||||
// NATS subject: events.{source}.{schema}.{entity}.{operation}
|
||||
// We use wildcard for source since pattern doesn't include it
|
||||
return fmt.Sprintf("%s.*.%s", np.subjectPrefix, pattern)
|
||||
}
|
||||
|
||||
// matchesFilter checks if an event matches the filter criteria
|
||||
func (np *NATSProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if filter.Source != nil && event.Source != *filter.Source {
|
||||
return false
|
||||
}
|
||||
if filter.Status != nil && event.Status != *filter.Status {
|
||||
return false
|
||||
}
|
||||
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||
return false
|
||||
}
|
||||
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||
return false
|
||||
}
|
||||
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||
return false
|
||||
}
|
||||
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||
return false
|
||||
}
|
||||
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||
return false
|
||||
}
|
||||
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||
return false
|
||||
}
|
||||
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
541
pkg/eventbroker/provider_redis.go
Normal file
541
pkg/eventbroker/provider_redis.go
Normal file
@@ -0,0 +1,541 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// RedisProvider implements Provider interface using Redis Streams
|
||||
// Features:
|
||||
// - Persistent event storage using Redis Streams
|
||||
// - Cross-instance pub/sub using consumer groups
|
||||
// - Pattern-based subscription routing
|
||||
// - Automatic stream trimming to prevent unbounded growth
|
||||
type RedisProvider struct {
|
||||
client *redis.Client
|
||||
streamName string
|
||||
consumerGroup string
|
||||
consumerName string
|
||||
instanceID string
|
||||
maxLen int64
|
||||
|
||||
// Subscriptions
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*redisSubscription
|
||||
|
||||
// Statistics
|
||||
stats RedisProviderStats
|
||||
|
||||
// Lifecycle
|
||||
stopListeners chan struct{}
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// RedisProviderStats contains statistics for the Redis provider
|
||||
type RedisProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
ConsumerErrors atomic.Int64
|
||||
}
|
||||
|
||||
// redisSubscription represents a single subscription
|
||||
type redisSubscription struct {
|
||||
pattern string
|
||||
ch chan *Event
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// RedisProviderConfig configures the Redis provider
|
||||
type RedisProviderConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int
|
||||
StreamName string
|
||||
ConsumerGroup string
|
||||
ConsumerName string
|
||||
InstanceID string
|
||||
MaxLen int64 // Maximum stream length (0 = unlimited)
|
||||
}
|
||||
|
||||
// NewRedisProvider creates a new Redis event provider
|
||||
func NewRedisProvider(cfg RedisProviderConfig) (*RedisProvider, error) {
|
||||
// Apply defaults
|
||||
if cfg.Host == "" {
|
||||
cfg.Host = "localhost"
|
||||
}
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 6379
|
||||
}
|
||||
if cfg.StreamName == "" {
|
||||
cfg.StreamName = "resolvespec:events"
|
||||
}
|
||||
if cfg.ConsumerGroup == "" {
|
||||
cfg.ConsumerGroup = "resolvespec-workers"
|
||||
}
|
||||
if cfg.ConsumerName == "" {
|
||||
cfg.ConsumerName = cfg.InstanceID
|
||||
}
|
||||
if cfg.MaxLen == 0 {
|
||||
cfg.MaxLen = 10000 // Default max stream length
|
||||
}
|
||||
|
||||
// Create Redis client
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
PoolSize: 10,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
rp := &RedisProvider{
|
||||
client: client,
|
||||
streamName: cfg.StreamName,
|
||||
consumerGroup: cfg.ConsumerGroup,
|
||||
consumerName: cfg.ConsumerName,
|
||||
instanceID: cfg.InstanceID,
|
||||
maxLen: cfg.MaxLen,
|
||||
subscribers: make(map[string]*redisSubscription),
|
||||
stopListeners: make(chan struct{}),
|
||||
}
|
||||
|
||||
rp.isRunning.Store(true)
|
||||
|
||||
// Create consumer group if it doesn't exist
|
||||
if err := rp.ensureConsumerGroup(ctx); err != nil {
|
||||
logger.Warn("Failed to create consumer group: %v (may already exist)", err)
|
||||
}
|
||||
|
||||
logger.Info("Redis provider initialized (stream: %s, consumer_group: %s, consumer: %s)",
|
||||
cfg.StreamName, cfg.ConsumerGroup, cfg.ConsumerName)
|
||||
|
||||
return rp, nil
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (rp *RedisProvider) Store(ctx context.Context, event *Event) error {
|
||||
// Marshal event to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal event: %w", err)
|
||||
}
|
||||
|
||||
// Store in Redis Stream
|
||||
args := &redis.XAddArgs{
|
||||
Stream: rp.streamName,
|
||||
MaxLen: rp.maxLen,
|
||||
Approx: true, // Use approximate trimming for better performance
|
||||
Values: map[string]interface{}{
|
||||
"event": data,
|
||||
"id": event.ID,
|
||||
"type": event.Type,
|
||||
"source": string(event.Source),
|
||||
"status": string(event.Status),
|
||||
"instance_id": event.InstanceID,
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := rp.client.XAdd(ctx, args).Result(); err != nil {
|
||||
return fmt.Errorf("failed to add event to stream: %w", err)
|
||||
}
|
||||
|
||||
rp.stats.TotalEvents.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
// Note: This scans the stream which can be slow for large streams
|
||||
// Consider using a separate hash for fast lookups if needed
|
||||
func (rp *RedisProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
// Scan stream for event with matching ID
|
||||
args := &redis.XReadArgs{
|
||||
Streams: []string{rp.streamName, "0"},
|
||||
Count: 1000, // Read in batches
|
||||
}
|
||||
|
||||
for {
|
||||
streams, err := rp.client.XRead(ctx, args).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read stream: %w", err)
|
||||
}
|
||||
|
||||
if len(streams) == 0 {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
// Check if this is the event we're looking for
|
||||
if eventID, ok := message.Values["id"].(string); ok && eventID == id {
|
||||
// Parse event
|
||||
if eventData, ok := message.Values["event"].(string); ok {
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
|
||||
}
|
||||
return &event, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we've read messages, update start position for next iteration
|
||||
if len(stream.Messages) > 0 {
|
||||
args.Streams[1] = stream.Messages[len(stream.Messages)-1].ID
|
||||
} else {
|
||||
// No more messages
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
// Note: This scans the entire stream which can be slow
|
||||
// Consider using time-based or ID-based ranges for better performance
|
||||
func (rp *RedisProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
var results []*Event
|
||||
|
||||
// Read from stream
|
||||
args := &redis.XReadArgs{
|
||||
Streams: []string{rp.streamName, "0"},
|
||||
Count: 1000,
|
||||
}
|
||||
|
||||
for {
|
||||
streams, err := rp.client.XRead(ctx, args).Result()
|
||||
if err == redis.Nil {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read stream: %w", err)
|
||||
}
|
||||
|
||||
if len(streams) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
if eventData, ok := message.Values["event"].(string); ok {
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if rp.matchesFilter(&event, filter) {
|
||||
results = append(results, &event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update start position for next iteration
|
||||
if len(stream.Messages) > 0 {
|
||||
args.Streams[1] = stream.Messages[len(stream.Messages)-1].ID
|
||||
} else {
|
||||
// No more messages
|
||||
goto done
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
// Apply limit and offset
|
||||
if filter != nil {
|
||||
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||
results = results[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||
results = results[:filter.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
// Note: Redis Streams are append-only, so we need to store status updates separately
|
||||
// This uses a separate hash for status tracking
|
||||
func (rp *RedisProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
statusKey := fmt.Sprintf("%s:status:%s", rp.streamName, id)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"status": string(status),
|
||||
"updated_at": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
if errorMsg != "" {
|
||||
fields["error"] = errorMsg
|
||||
}
|
||||
|
||||
if err := rp.client.HSet(ctx, statusKey, fields).Err(); err != nil {
|
||||
return fmt.Errorf("failed to update status: %w", err)
|
||||
}
|
||||
|
||||
// Set TTL on status key to prevent unbounded growth
|
||||
rp.client.Expire(ctx, statusKey, 7*24*time.Hour) // 7 days
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
// Note: Redis Streams don't support deletion by field value
|
||||
// This marks the event as deleted in a separate set
|
||||
func (rp *RedisProvider) Delete(ctx context.Context, id string) error {
|
||||
deletedKey := fmt.Sprintf("%s:deleted", rp.streamName)
|
||||
|
||||
if err := rp.client.SAdd(ctx, deletedKey, id).Err(); err != nil {
|
||||
return fmt.Errorf("failed to mark event as deleted: %w", err)
|
||||
}
|
||||
|
||||
// Also delete the status hash if it exists
|
||||
statusKey := fmt.Sprintf("%s:status:%s", rp.streamName, id)
|
||||
rp.client.Del(ctx, statusKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
// Uses Redis Streams consumer group for distributed processing
|
||||
func (rp *RedisProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sub := &redisSubscription{
|
||||
pattern: pattern,
|
||||
ch: ch,
|
||||
ctx: subCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
rp.mu.Lock()
|
||||
rp.subscribers[pattern] = sub
|
||||
rp.stats.ActiveSubscribers.Add(1)
|
||||
rp.mu.Unlock()
|
||||
|
||||
// Start consumer goroutine
|
||||
rp.wg.Add(1)
|
||||
go rp.consumeStream(sub)
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers (cross-instance)
|
||||
func (rp *RedisProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := rp.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rp.stats.EventsPublished.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (rp *RedisProvider) Close() error {
|
||||
if !rp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
rp.isRunning.Store(false)
|
||||
|
||||
// Cancel all subscriptions
|
||||
rp.mu.Lock()
|
||||
for _, sub := range rp.subscribers {
|
||||
sub.cancel()
|
||||
}
|
||||
rp.mu.Unlock()
|
||||
|
||||
// Stop listeners
|
||||
close(rp.stopListeners)
|
||||
|
||||
// Wait for goroutines
|
||||
rp.wg.Wait()
|
||||
|
||||
// Close Redis client
|
||||
if err := rp.client.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close Redis client: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Redis provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (rp *RedisProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
// Get stream info
|
||||
streamInfo, err := rp.client.XInfoStream(ctx, rp.streamName).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
logger.Warn("Failed to get stream info: %v", err)
|
||||
}
|
||||
|
||||
stats := &ProviderStats{
|
||||
ProviderType: "redis",
|
||||
TotalEvents: rp.stats.TotalEvents.Load(),
|
||||
EventsPublished: rp.stats.EventsPublished.Load(),
|
||||
EventsConsumed: rp.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(rp.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"stream_name": rp.streamName,
|
||||
"consumer_group": rp.consumerGroup,
|
||||
"consumer_name": rp.consumerName,
|
||||
"max_len": rp.maxLen,
|
||||
"consumer_errors": rp.stats.ConsumerErrors.Load(),
|
||||
},
|
||||
}
|
||||
|
||||
if streamInfo != nil {
|
||||
stats.ProviderSpecific["stream_length"] = streamInfo.Length
|
||||
stats.ProviderSpecific["first_entry_id"] = streamInfo.FirstEntry.ID
|
||||
stats.ProviderSpecific["last_entry_id"] = streamInfo.LastEntry.ID
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// consumeStream consumes events from the Redis Stream for a subscription
|
||||
func (rp *RedisProvider) consumeStream(sub *redisSubscription) {
|
||||
defer rp.wg.Done()
|
||||
defer close(sub.ch)
|
||||
defer func() {
|
||||
rp.mu.Lock()
|
||||
delete(rp.subscribers, sub.pattern)
|
||||
rp.stats.ActiveSubscribers.Add(-1)
|
||||
rp.mu.Unlock()
|
||||
}()
|
||||
|
||||
logger.Debug("Starting stream consumer for pattern: %s", sub.pattern)
|
||||
|
||||
// Use consumer group for distributed processing
|
||||
for {
|
||||
select {
|
||||
case <-sub.ctx.Done():
|
||||
logger.Debug("Stream consumer stopped for pattern: %s", sub.pattern)
|
||||
return
|
||||
default:
|
||||
// Read from consumer group
|
||||
args := &redis.XReadGroupArgs{
|
||||
Group: rp.consumerGroup,
|
||||
Consumer: rp.consumerName,
|
||||
Streams: []string{rp.streamName, ">"},
|
||||
Count: 10,
|
||||
Block: 1 * time.Second,
|
||||
}
|
||||
|
||||
streams, err := rp.client.XReadGroup(sub.ctx, args).Result()
|
||||
if err == redis.Nil {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
if sub.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
rp.stats.ConsumerErrors.Add(1)
|
||||
logger.Warn("Failed to read from consumer group: %v", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
if eventData, ok := message.Values["event"].(string); ok {
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
// Acknowledge message anyway to prevent redelivery
|
||||
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if event matches pattern
|
||||
if matchPattern(sub.pattern, event.Type) {
|
||||
select {
|
||||
case sub.ch <- &event:
|
||||
rp.stats.EventsConsumed.Add(1)
|
||||
// Acknowledge message
|
||||
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
|
||||
case <-sub.ctx.Done():
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Acknowledge message even if it doesn't match pattern
|
||||
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureConsumerGroup creates the consumer group if it doesn't exist
|
||||
func (rp *RedisProvider) ensureConsumerGroup(ctx context.Context) error {
|
||||
// Try to create the stream and consumer group
|
||||
// MKSTREAM creates the stream if it doesn't exist
|
||||
err := rp.client.XGroupCreateMkStream(ctx, rp.streamName, rp.consumerGroup, "0").Err()
|
||||
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchesFilter checks if an event matches the filter criteria
|
||||
func (rp *RedisProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if filter.Source != nil && event.Source != *filter.Source {
|
||||
return false
|
||||
}
|
||||
if filter.Status != nil && event.Status != *filter.Status {
|
||||
return false
|
||||
}
|
||||
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||
return false
|
||||
}
|
||||
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||
return false
|
||||
}
|
||||
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||
return false
|
||||
}
|
||||
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||
return false
|
||||
}
|
||||
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||
return false
|
||||
}
|
||||
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||
return false
|
||||
}
|
||||
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
140
pkg/eventbroker/subscription.go
Normal file
140
pkg/eventbroker/subscription.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// SubscriptionID uniquely identifies a subscription
|
||||
type SubscriptionID string
|
||||
|
||||
// subscription represents a single subscription with its handler and pattern
|
||||
type subscription struct {
|
||||
id SubscriptionID
|
||||
pattern string
|
||||
handler EventHandler
|
||||
}
|
||||
|
||||
// subscriptionManager manages event subscriptions and pattern matching
|
||||
type subscriptionManager struct {
|
||||
mu sync.RWMutex
|
||||
subscriptions map[SubscriptionID]*subscription
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
// newSubscriptionManager creates a new subscription manager
|
||||
func newSubscriptionManager() *subscriptionManager {
|
||||
return &subscriptionManager{
|
||||
subscriptions: make(map[SubscriptionID]*subscription),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new subscription
|
||||
func (sm *subscriptionManager) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||
if pattern == "" {
|
||||
return "", fmt.Errorf("pattern cannot be empty")
|
||||
}
|
||||
if handler == nil {
|
||||
return "", fmt.Errorf("handler cannot be nil")
|
||||
}
|
||||
|
||||
id := SubscriptionID(fmt.Sprintf("sub-%d", sm.nextID.Add(1)))
|
||||
|
||||
sm.mu.Lock()
|
||||
sm.subscriptions[id] = &subscription{
|
||||
id: id,
|
||||
pattern: pattern,
|
||||
handler: handler,
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
logger.Info("Subscribed to pattern '%s' with ID: %s", pattern, id)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription
|
||||
func (sm *subscriptionManager) Unsubscribe(id SubscriptionID) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if _, exists := sm.subscriptions[id]; !exists {
|
||||
return fmt.Errorf("subscription not found: %s", id)
|
||||
}
|
||||
|
||||
delete(sm.subscriptions, id)
|
||||
logger.Info("Unsubscribed: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMatching returns all handlers that match the event type
|
||||
func (sm *subscriptionManager) GetMatching(eventType string) []EventHandler {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var handlers []EventHandler
|
||||
for _, sub := range sm.subscriptions {
|
||||
if matchPattern(sub.pattern, eventType) {
|
||||
handlers = append(handlers, sub.handler)
|
||||
}
|
||||
}
|
||||
|
||||
return handlers
|
||||
}
|
||||
|
||||
// Count returns the number of active subscriptions
|
||||
func (sm *subscriptionManager) Count() int {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return len(sm.subscriptions)
|
||||
}
|
||||
|
||||
// Clear removes all subscriptions
|
||||
func (sm *subscriptionManager) Clear() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.subscriptions = make(map[SubscriptionID]*subscription)
|
||||
logger.Info("Cleared all subscriptions")
|
||||
}
|
||||
|
||||
// matchPattern implements glob-style pattern matching for event types
|
||||
// Patterns:
|
||||
// - "*" matches any single segment
|
||||
// - "a.b.c" matches exactly "a.b.c"
|
||||
// - "a.*.c" matches "a.anything.c"
|
||||
// - "a.b.*" matches any operation on a.b
|
||||
// - "*" matches everything
|
||||
//
|
||||
// Event type format: schema.entity.operation (e.g., "public.users.create")
|
||||
func matchPattern(pattern, eventType string) bool {
|
||||
// Wildcard matches everything
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Exact match
|
||||
if pattern == eventType {
|
||||
return true
|
||||
}
|
||||
|
||||
// Split pattern and event type by dots
|
||||
patternParts := strings.Split(pattern, ".")
|
||||
eventParts := strings.Split(eventType, ".")
|
||||
|
||||
// Different number of parts can only match if pattern has wildcards
|
||||
if len(patternParts) != len(eventParts) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Match each part
|
||||
for i := range patternParts {
|
||||
if patternParts[i] != "*" && patternParts[i] != eventParts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
270
pkg/eventbroker/subscription_test.go
Normal file
270
pkg/eventbroker/subscription_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMatchPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
pattern string
|
||||
eventType string
|
||||
expected bool
|
||||
}{
|
||||
// Exact matches
|
||||
{"public.users.create", "public.users.create", true},
|
||||
{"public.users.create", "public.users.update", false},
|
||||
|
||||
// Wildcard matches
|
||||
{"*", "public.users.create", true},
|
||||
{"*", "anything", true},
|
||||
{"public.*", "public.users", true},
|
||||
{"public.*", "public.users.create", false}, // Different number of parts
|
||||
{"public.*", "admin.users", false},
|
||||
{"*.users.create", "public.users.create", true},
|
||||
{"*.users.create", "admin.users.create", true},
|
||||
{"*.users.create", "public.roles.create", false},
|
||||
{"public.*.create", "public.users.create", true},
|
||||
{"public.*.create", "public.roles.create", true},
|
||||
{"public.*.create", "public.users.update", false},
|
||||
|
||||
// Multiple wildcards
|
||||
{"*.*", "public.users", true},
|
||||
{"*.*", "public.users.create", false}, // Different number of parts
|
||||
{"*.*.create", "public.users.create", true},
|
||||
{"*.*.create", "admin.roles.create", true},
|
||||
{"*.*.create", "public.users.update", false},
|
||||
|
||||
// Edge cases
|
||||
{"", "", true},
|
||||
{"", "something", false},
|
||||
{"something", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern+"_vs_"+tt.eventType, func(t *testing.T) {
|
||||
result := matchPattern(tt.pattern, tt.eventType)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchPattern(%q, %q) = %v, expected %v",
|
||||
tt.pattern, tt.eventType, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManager(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
// Create test handler
|
||||
called := false
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// Test Subscribe
|
||||
id, err := manager.Subscribe("public.users.*", handler)
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Fatal("Expected non-empty subscription ID")
|
||||
}
|
||||
|
||||
// Test GetMatching
|
||||
handlers := manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 1 {
|
||||
t.Fatalf("Expected 1 handler, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Test handler execution
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
if err := handlers[0].Handle(context.Background(), event); err != nil {
|
||||
t.Fatalf("Handler execution failed: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
|
||||
// Test Count
|
||||
if manager.Count() != 1 {
|
||||
t.Errorf("Expected count 1, got %d", manager.Count())
|
||||
}
|
||||
|
||||
// Test Unsubscribe
|
||||
if err := manager.Unsubscribe(id); err != nil {
|
||||
t.Fatalf("Unsubscribe failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify unsubscribed
|
||||
handlers = manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 0 {
|
||||
t.Errorf("Expected 0 handlers after unsubscribe, got %d", len(handlers))
|
||||
}
|
||||
if manager.Count() != 0 {
|
||||
t.Errorf("Expected count 0 after unsubscribe, got %d", manager.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerMultipleHandlers(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
called1 := false
|
||||
handler1 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called1 = true
|
||||
return nil
|
||||
})
|
||||
|
||||
called2 := false
|
||||
handler2 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called2 = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// Subscribe multiple handlers
|
||||
id1, _ := manager.Subscribe("public.users.*", handler1)
|
||||
id2, _ := manager.Subscribe("*.users.*", handler2)
|
||||
|
||||
// Both should match
|
||||
handlers := manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 2 {
|
||||
t.Fatalf("Expected 2 handlers, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Execute all handlers
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
for _, h := range handlers {
|
||||
h.Handle(context.Background(), event)
|
||||
}
|
||||
|
||||
if !called1 || !called2 {
|
||||
t.Error("Expected both handlers to be called")
|
||||
}
|
||||
|
||||
// Unsubscribe one
|
||||
manager.Unsubscribe(id1)
|
||||
handlers = manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 1 {
|
||||
t.Errorf("Expected 1 handler after unsubscribe, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Unsubscribe remaining
|
||||
manager.Unsubscribe(id2)
|
||||
if manager.Count() != 0 {
|
||||
t.Errorf("Expected count 0 after all unsubscribe, got %d", manager.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerConcurrency(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Subscribe and unsubscribe concurrently
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
id, _ := manager.Subscribe("test.*", handler)
|
||||
manager.GetMatching("test.event")
|
||||
manager.Unsubscribe(id)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Should have no subscriptions left
|
||||
if manager.Count() != 0 {
|
||||
t.Errorf("Expected count 0 after concurrent operations, got %d", manager.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerUnsubscribeNonExistent(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
// Try to unsubscribe a non-existent ID
|
||||
err := manager.Unsubscribe("non-existent-id")
|
||||
if err == nil {
|
||||
t.Error("Expected error when unsubscribing non-existent ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionIDGeneration(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Subscribe multiple times and ensure unique IDs
|
||||
ids := make(map[SubscriptionID]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id, _ := manager.Subscribe("test.*", handler)
|
||||
if ids[id] {
|
||||
t.Fatalf("Duplicate subscription ID: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventHandlerFunc(t *testing.T) {
|
||||
called := false
|
||||
var receivedEvent *Event
|
||||
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
receivedEvent = event
|
||||
return nil
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
err := handler.Handle(context.Background(), event)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
if receivedEvent != event {
|
||||
t.Error("Expected to receive the same event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerPatternPriority(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
// More specific patterns should still match
|
||||
specificCalled := false
|
||||
genericCalled := false
|
||||
|
||||
manager.Subscribe("public.users.create", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
specificCalled = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
manager.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
genericCalled = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
handlers := manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 2 {
|
||||
t.Fatalf("Expected 2 matching handlers, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Execute all handlers
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
for _, h := range handlers {
|
||||
h.Handle(context.Background(), event)
|
||||
}
|
||||
|
||||
if !specificCalled || !genericCalled {
|
||||
t.Error("Expected both specific and generic handlers to be called")
|
||||
}
|
||||
}
|
||||
141
pkg/eventbroker/worker_pool.go
Normal file
141
pkg/eventbroker/worker_pool.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// workerPool manages a pool of workers for async event processing
|
||||
type workerPool struct {
|
||||
workerCount int
|
||||
bufferSize int
|
||||
eventQueue chan *Event
|
||||
processor func(context.Context, *Event) error
|
||||
|
||||
activeWorkers atomic.Int32
|
||||
isRunning atomic.Bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// newWorkerPool creates a new worker pool
|
||||
func newWorkerPool(workerCount, bufferSize int, processor func(context.Context, *Event) error) *workerPool {
|
||||
return &workerPool{
|
||||
workerCount: workerCount,
|
||||
bufferSize: bufferSize,
|
||||
eventQueue: make(chan *Event, bufferSize),
|
||||
processor: processor,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the worker pool
|
||||
func (wp *workerPool) Start() {
|
||||
if wp.isRunning.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
wp.isRunning.Store(true)
|
||||
|
||||
// Start workers
|
||||
for i := 0; i < wp.workerCount; i++ {
|
||||
wp.wg.Add(1)
|
||||
go wp.worker(i)
|
||||
}
|
||||
|
||||
logger.Info("Worker pool started with %d workers", wp.workerCount)
|
||||
}
|
||||
|
||||
// Stop stops the worker pool gracefully
|
||||
func (wp *workerPool) Stop(ctx context.Context) error {
|
||||
if !wp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
wp.isRunning.Store(false)
|
||||
|
||||
// Close event queue to signal workers
|
||||
close(wp.eventQueue)
|
||||
|
||||
// Wait for workers to finish with context timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wp.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
logger.Info("Worker pool stopped gracefully")
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
logger.Warn("Worker pool stop timed out, some events may be lost")
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Submit submits an event to the queue
|
||||
func (wp *workerPool) Submit(ctx context.Context, event *Event) error {
|
||||
if !wp.isRunning.Load() {
|
||||
return ErrWorkerPoolStopped
|
||||
}
|
||||
|
||||
select {
|
||||
case wp.eventQueue <- event:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return ErrQueueFull
|
||||
}
|
||||
}
|
||||
|
||||
// worker is a worker goroutine that processes events from the queue
|
||||
func (wp *workerPool) worker(id int) {
|
||||
defer wp.wg.Done()
|
||||
|
||||
logger.Debug("Worker %d started", id)
|
||||
|
||||
for event := range wp.eventQueue {
|
||||
wp.activeWorkers.Add(1)
|
||||
|
||||
// Process event with background context (detached from original request)
|
||||
ctx := context.Background()
|
||||
if err := wp.processor(ctx, event); err != nil {
|
||||
logger.Error("Worker %d failed to process event %s: %v", id, event.ID, err)
|
||||
}
|
||||
|
||||
wp.activeWorkers.Add(-1)
|
||||
}
|
||||
|
||||
logger.Debug("Worker %d stopped", id)
|
||||
}
|
||||
|
||||
// QueueSize returns the current queue size
|
||||
func (wp *workerPool) QueueSize() int {
|
||||
return len(wp.eventQueue)
|
||||
}
|
||||
|
||||
// ActiveWorkers returns the number of currently active workers
|
||||
func (wp *workerPool) ActiveWorkers() int {
|
||||
return int(wp.activeWorkers.Load())
|
||||
}
|
||||
|
||||
// Error definitions
|
||||
var (
|
||||
ErrWorkerPoolStopped = &BrokerError{Code: "worker_pool_stopped", Message: "worker pool is stopped"}
|
||||
ErrQueueFull = &BrokerError{Code: "queue_full", Message: "event queue is full"}
|
||||
)
|
||||
|
||||
// BrokerError represents an error from the event broker
|
||||
type BrokerError struct {
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *BrokerError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
@@ -20,8 +20,23 @@ import (
|
||||
|
||||
// Handler handles function-based SQL API requests
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
hooks *HookRegistry
|
||||
db common.Database
|
||||
hooks *HookRegistry
|
||||
variablesCallback func(r *http.Request) map[string]interface{}
|
||||
}
|
||||
|
||||
type SqlQueryOptions struct {
|
||||
NoCount bool
|
||||
BlankParams bool
|
||||
AllowFilter bool
|
||||
}
|
||||
|
||||
func NewSqlQueryOptions() SqlQueryOptions {
|
||||
return SqlQueryOptions{
|
||||
NoCount: false,
|
||||
BlankParams: true,
|
||||
AllowFilter: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new function API handler
|
||||
@@ -38,6 +53,14 @@ func (h *Handler) GetDatabase() common.Database {
|
||||
return h.db
|
||||
}
|
||||
|
||||
func (h *Handler) SetVariablesCallback(callback func(r *http.Request) map[string]interface{}) {
|
||||
h.variablesCallback = callback
|
||||
}
|
||||
|
||||
func (h *Handler) GetVariablesCallback() func(r *http.Request) map[string]interface{} {
|
||||
return h.variablesCallback
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry for this handler
|
||||
// Use this to register custom hooks for operations
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
@@ -48,7 +71,7 @@ func (h *Handler) Hooks() *HookRegistry {
|
||||
type HTTPFuncType func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
|
||||
func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType {
|
||||
func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFuncType {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -61,7 +84,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
// Create local copy to avoid modifying the captured parameter across requests
|
||||
sqlquery := sqlquery
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var dbobjlist []map[string]interface{}
|
||||
@@ -70,6 +93,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
inputvars := make([]string, 0)
|
||||
metainfo := make(map[string]interface{})
|
||||
variables := make(map[string]interface{})
|
||||
|
||||
complexAPI := false
|
||||
|
||||
// Get user context from security package
|
||||
@@ -93,9 +117,9 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
MetaInfo: metainfo,
|
||||
PropQry: propQry,
|
||||
UserContext: userCtx,
|
||||
NoCount: pNoCount,
|
||||
BlankParams: pBlankparms,
|
||||
AllowFilter: pAllowFilter,
|
||||
NoCount: options.NoCount,
|
||||
BlankParams: options.BlankParams,
|
||||
AllowFilter: options.AllowFilter,
|
||||
ComplexAPI: complexAPI,
|
||||
}
|
||||
|
||||
@@ -131,13 +155,13 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
complexAPI = reqParams.ComplexAPI
|
||||
|
||||
// Merge query string parameters
|
||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
|
||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, options.AllowFilter, propQry)
|
||||
|
||||
// Merge header parameters
|
||||
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
||||
|
||||
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
|
||||
if !pAllowFilter {
|
||||
if !options.AllowFilter {
|
||||
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
||||
}
|
||||
|
||||
@@ -149,7 +173,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
|
||||
// Override pNoCount if skipcount is specified
|
||||
if reqParams.SkipCount {
|
||||
pNoCount = true
|
||||
options.NoCount = true
|
||||
}
|
||||
|
||||
// Build metainfo
|
||||
@@ -164,7 +188,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
||||
|
||||
// Remove unused input variables
|
||||
if pBlankparms {
|
||||
if options.BlankParams {
|
||||
for _, kw := range inputvars {
|
||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||
@@ -205,7 +229,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
|
||||
}
|
||||
|
||||
if !pNoCount {
|
||||
if !options.NoCount {
|
||||
if limit > 0 && offset > 0 {
|
||||
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
|
||||
} else if limit > 0 {
|
||||
@@ -244,7 +268,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
// Normalize PostgreSQL types for proper JSON marshaling
|
||||
dbobjlist = normalizePostgresTypesList(rows)
|
||||
|
||||
if pNoCount {
|
||||
if options.NoCount {
|
||||
total = int64(len(dbobjlist))
|
||||
}
|
||||
|
||||
@@ -386,7 +410,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
}
|
||||
|
||||
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
|
||||
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncType {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -399,13 +423,14 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
// Create local copy to avoid modifying the captured parameter across requests
|
||||
sqlquery := sqlquery
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
propQry := make(map[string]string)
|
||||
inputvars := make([]string, 0)
|
||||
metainfo := make(map[string]interface{})
|
||||
variables := make(map[string]interface{})
|
||||
|
||||
dbobj := make(map[string]interface{})
|
||||
complexAPI := false
|
||||
|
||||
@@ -430,7 +455,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
MetaInfo: metainfo,
|
||||
PropQry: propQry,
|
||||
UserContext: userCtx,
|
||||
BlankParams: pBlankparms,
|
||||
BlankParams: options.BlankParams,
|
||||
ComplexAPI: complexAPI,
|
||||
}
|
||||
|
||||
@@ -497,17 +522,24 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
if strings.HasPrefix(kLower, "x-fieldfilter-") {
|
||||
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
|
||||
if strings.Contains(strings.ToLower(sqlquery), colname) {
|
||||
if val == "" || val == "0" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
switch val {
|
||||
case "0":
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
|
||||
case "":
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
|
||||
default:
|
||||
if IsNumeric(val) {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unused input variables
|
||||
if pBlankparms {
|
||||
if options.BlankParams {
|
||||
for _, kw := range inputvars {
|
||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||
@@ -631,8 +663,21 @@ func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) st
|
||||
|
||||
// mergePathParams merges URL path parameters into the SQL query
|
||||
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
|
||||
// Note: Path parameters would typically come from a router like gorilla/mux
|
||||
// For now, this is a placeholder for path parameter extraction
|
||||
|
||||
if h.GetVariablesCallback() != nil {
|
||||
pathVars := h.GetVariablesCallback()(r)
|
||||
for k, v := range pathVars {
|
||||
kword := fmt.Sprintf("[%s]", k)
|
||||
if strings.Contains(sqlquery, kword) {
|
||||
// Sanitize the value before replacing
|
||||
vStr := fmt.Sprintf("%v", v)
|
||||
sanitized := ValidSQL(vStr, "colvalue")
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
|
||||
}
|
||||
variables[k] = v
|
||||
|
||||
}
|
||||
}
|
||||
return sqlquery
|
||||
}
|
||||
|
||||
@@ -655,7 +700,9 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
|
||||
// Replace in SQL if placeholder exists
|
||||
if strings.Contains(sqlquery, kword) && len(val) > 0 {
|
||||
if strings.HasPrefix(parmk, "p-") {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
||||
// Sanitize the parameter value before replacing
|
||||
sanitized := ValidSQL(val, "colvalue")
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -667,15 +714,36 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
|
||||
// Apply filters if allowed
|
||||
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
|
||||
if len(parmv) > 1 {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ",")))
|
||||
// Sanitize each value in the IN clause with appropriate quoting
|
||||
sanitizedValues := make([]string, len(parmv))
|
||||
for i, v := range parmv {
|
||||
if IsNumeric(v) {
|
||||
// Numeric values don't need quotes
|
||||
sanitizedValues[i] = ValidSQL(v, "colvalue")
|
||||
} else {
|
||||
// String values need quotes
|
||||
sanitized := ValidSQL(v, "colvalue")
|
||||
sanitizedValues[i] = fmt.Sprintf("'%s'", sanitized)
|
||||
}
|
||||
}
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ",")))
|
||||
} else {
|
||||
if strings.Contains(val, "match=") {
|
||||
colval := strings.ReplaceAll(val, "match=", "")
|
||||
// Escape single quotes and backslashes for LIKE patterns
|
||||
// But don't escape wildcards % and _ which are intentional
|
||||
colval = strings.ReplaceAll(colval, "\\", "\\\\")
|
||||
colval = strings.ReplaceAll(colval, "'", "''")
|
||||
if colval != "*" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
|
||||
}
|
||||
} else if val == "" || val == "0" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
||||
// For empty/zero values, treat as literal 0 or empty string with quotes
|
||||
if val == "0" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = 0 OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
|
||||
}
|
||||
} else {
|
||||
if IsNumeric(val) {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
||||
@@ -708,16 +776,25 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
|
||||
|
||||
kword := fmt.Sprintf("[%s]", k)
|
||||
if strings.Contains(sqlquery, kword) {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
||||
// Sanitize the header value before replacing
|
||||
sanitized := ValidSQL(val, "colvalue")
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
|
||||
}
|
||||
|
||||
// Handle special headers
|
||||
if strings.Contains(k, "x-fieldfilter-") {
|
||||
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
|
||||
if val == "" || val == "0" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
switch val {
|
||||
case "0":
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
|
||||
case "":
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
|
||||
default:
|
||||
if IsNumeric(val) {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -747,12 +824,15 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
|
||||
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
|
||||
if strings.Contains(sqlquery, "[p_meta_default]") {
|
||||
data, _ := json.Marshal(metainfo)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
||||
dataStr := strings.ReplaceAll(string(data), "$META$", "/*META*/")
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("$META$%s$META$::jsonb", dataStr))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[json_variables]") {
|
||||
data, _ := json.Marshal(variables)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
||||
dataStr := strings.ReplaceAll(string(data), "$VAR$", "/*VAR*/")
|
||||
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("$VAR$%s$VAR$::jsonb", dataStr))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[rid_user]") {
|
||||
@@ -760,7 +840,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[user]") {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName))
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("$USR$%s$USR$", strings.ReplaceAll(userCtx.UserName, "$USR$", "/*USR*/")))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[rid_session]") {
|
||||
@@ -771,7 +851,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[method]") {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[method]", fmt.Sprintf("$M$%s$M$", strings.ReplaceAll(r.Method, "$M$", "/*M*/")))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[post_body]") {
|
||||
@@ -784,7 +864,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
|
||||
}
|
||||
}
|
||||
}
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr))
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("$PBODY$%s$PBODY$", strings.ReplaceAll(bodystr, "$PBODY$", "/*PBODY*/")))
|
||||
}
|
||||
|
||||
return sqlquery
|
||||
@@ -824,19 +904,23 @@ func ValidSQL(input, mode string) string {
|
||||
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
|
||||
return reg.ReplaceAllString(input, "")
|
||||
case "colvalue":
|
||||
// For column values, escape single quotes
|
||||
return strings.ReplaceAll(input, "'", "''")
|
||||
// For column values, escape single quotes and backslashes
|
||||
// Note: Backslashes must be escaped first, then single quotes
|
||||
result := strings.ReplaceAll(input, "\\", "\\\\")
|
||||
result = strings.ReplaceAll(result, "'", "''")
|
||||
return result
|
||||
case "select":
|
||||
// For SELECT clauses, be more permissive but still safe
|
||||
// Remove semicolons and common SQL injection patterns
|
||||
dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "}
|
||||
result := input
|
||||
for _, d := range dangerous {
|
||||
result = strings.ReplaceAll(result, d, "")
|
||||
result = strings.ReplaceAll(result, strings.ToLower(d), "")
|
||||
result = strings.ReplaceAll(result, strings.ToUpper(d), "")
|
||||
// Remove semicolons and common SQL injection patterns (case-insensitive)
|
||||
dangerous := []string{
|
||||
";", "--", "/\\*", "\\*/", "xp_", "sp_",
|
||||
"drop ", "delete ", "truncate ", "update ", "insert ",
|
||||
"exec ", "execute ", "union ", "declare ", "alter ", "create ",
|
||||
}
|
||||
return result
|
||||
// Build a single regex pattern with all dangerous keywords
|
||||
pattern := "(?i)(" + strings.Join(dangerous, "|") + ")"
|
||||
re := regexp.MustCompile(pattern)
|
||||
return re.ReplaceAllString(input, "")
|
||||
default:
|
||||
return input
|
||||
}
|
||||
|
||||
@@ -70,6 +70,10 @@ func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
return fn(m)
|
||||
}
|
||||
|
||||
func (m *MockDatabase) GetUnderlyingDB() interface{} {
|
||||
return m
|
||||
}
|
||||
|
||||
// MockResult implements common.Result interface for testing
|
||||
type MockResult struct {
|
||||
rows int64
|
||||
@@ -532,7 +536,7 @@ func TestSqlQuery(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
|
||||
handlerFunc := handler.SqlQuery(tt.sqlQuery, SqlQueryOptions{BlankParams: tt.blankParams})
|
||||
handlerFunc(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
@@ -655,7 +659,7 @@ func TestSqlQueryList(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
|
||||
handlerFunc := handler.SqlQueryList(tt.sqlQuery, SqlQueryOptions{NoCount: tt.noCount, BlankParams: tt.blankParams, AllowFilter: tt.allowFilter})
|
||||
handlerFunc(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
|
||||
@@ -576,7 +576,7 @@ func TestHookIntegrationWithHandler(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
|
||||
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", SqlQueryOptions{})
|
||||
handlerFunc(w, req)
|
||||
|
||||
if !hookCalled {
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
errortracking "github.com/bitechdev/ResolveSpec/pkg/errortracking"
|
||||
)
|
||||
|
||||
var Logger *zap.SugaredLogger
|
||||
var errorTracker errortracking.Provider
|
||||
|
||||
func Init(dev bool) {
|
||||
|
||||
@@ -49,6 +53,50 @@ func UpdateLogger(config *zap.Config) {
|
||||
Info("ResolveSpec Logger initialized")
|
||||
}
|
||||
|
||||
// InitErrorTracking initializes the error tracking provider
|
||||
func InitErrorTracking(provider errortracking.Provider) {
|
||||
errorTracker = provider
|
||||
if errorTracker != nil {
|
||||
Info("Error tracking initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// GetErrorTracker returns the current error tracking provider
|
||||
func GetErrorTracker() errortracking.Provider {
|
||||
return errorTracker
|
||||
}
|
||||
|
||||
// CloseErrorTracking flushes and closes the error tracking provider
|
||||
func CloseErrorTracking() error {
|
||||
if errorTracker != nil {
|
||||
errorTracker.Flush(5)
|
||||
return errorTracker.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractContext attempts to find a context.Context in the given arguments.
|
||||
// It returns the found context (or context.Background() if not found) and
|
||||
// the remaining arguments without the context.
|
||||
func extractContext(args ...interface{}) (ctx context.Context, filteredArgs []interface{}) {
|
||||
ctx = context.Background()
|
||||
var newArgs []interface{}
|
||||
found := false
|
||||
|
||||
for _, arg := range args {
|
||||
if c, ok := arg.(context.Context); ok {
|
||||
if !found {
|
||||
ctx = c
|
||||
found = true
|
||||
}
|
||||
// Ignore any additional context.Context arguments after the first one.
|
||||
continue
|
||||
}
|
||||
newArgs = append(newArgs, arg)
|
||||
}
|
||||
return ctx, newArgs
|
||||
}
|
||||
|
||||
func Info(template string, args ...interface{}) {
|
||||
if Logger == nil {
|
||||
log.Printf(template, args...)
|
||||
@@ -58,19 +106,37 @@ func Info(template string, args ...interface{}) {
|
||||
}
|
||||
|
||||
func Warn(template string, args ...interface{}) {
|
||||
ctx, remainingArgs := extractContext(args...)
|
||||
message := fmt.Sprintf(template, remainingArgs...)
|
||||
if Logger == nil {
|
||||
log.Printf(template, args...)
|
||||
return
|
||||
log.Printf("%s", message)
|
||||
} else {
|
||||
Logger.Warnw(message, "process_id", os.Getpid())
|
||||
}
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
Logger.Warnw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
||||
}
|
||||
|
||||
func Error(template string, args ...interface{}) {
|
||||
ctx, remainingArgs := extractContext(args...)
|
||||
message := fmt.Sprintf(template, remainingArgs...)
|
||||
if Logger == nil {
|
||||
log.Printf(template, args...)
|
||||
return
|
||||
log.Printf("%s", message)
|
||||
} else {
|
||||
Logger.Errorw(message, "process_id", os.Getpid())
|
||||
}
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
Logger.Errorw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
||||
}
|
||||
|
||||
func Debug(template string, args ...interface{}) {
|
||||
@@ -82,35 +148,41 @@ func Debug(template string, args ...interface{}) {
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanicCallback(location string, cb func(err any)) {
|
||||
if err := recover(); err != nil {
|
||||
// callstack := debug.Stack()
|
||||
// Returns a function that should be deferred to catch panics
|
||||
// Example usage: defer CatchPanicCallback("MyFunction", func(err any) { /* cleanup */ })()
|
||||
func CatchPanicCallback(location string, cb func(err any), args ...interface{}) func() {
|
||||
ctx, _ := extractContext(args...)
|
||||
return func() {
|
||||
if err := recover(); err != nil {
|
||||
callstack := debug.Stack()
|
||||
|
||||
if Logger != nil {
|
||||
Error("Panic in %s : %v", location, err)
|
||||
} else {
|
||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||
debug.PrintStack()
|
||||
}
|
||||
if Logger != nil {
|
||||
Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly
|
||||
} else {
|
||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||
debug.PrintStack()
|
||||
}
|
||||
|
||||
// push to sentry
|
||||
// hub := sentry.CurrentHub()
|
||||
// if hub != nil {
|
||||
// evtID := hub.Recover(err)
|
||||
// if evtID != nil {
|
||||
// sentry.Flush(time.Second * 2)
|
||||
// }
|
||||
// }
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{
|
||||
"location": location,
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
|
||||
if cb != nil {
|
||||
cb(err)
|
||||
if cb != nil {
|
||||
cb(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanic(location string) {
|
||||
CatchPanicCallback(location, nil)
|
||||
// Returns a function that should be deferred to catch panics
|
||||
// Example usage: defer CatchPanic("MyFunction")()
|
||||
func CatchPanic(location string, args ...interface{}) func() {
|
||||
return CatchPanicCallback(location, nil, args...)
|
||||
}
|
||||
|
||||
// HandlePanic logs a panic and returns it as an error
|
||||
@@ -122,8 +194,18 @@ func CatchPanic(location string) {
|
||||
// err = logger.HandlePanic("MethodName", r)
|
||||
// }
|
||||
// }()
|
||||
func HandlePanic(methodName string, r any) error {
|
||||
func HandlePanic(methodName string, r any, args ...interface{}) error {
|
||||
ctx, _ := extractContext(args...)
|
||||
stack := debug.Stack()
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{
|
||||
"method": methodName,
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
|
||||
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ A pluggable metrics collection system with Prometheus implementation.
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
|
||||
// Initialize Prometheus provider
|
||||
provider := metrics.NewPrometheusProvider()
|
||||
// Initialize Prometheus provider with default config
|
||||
provider := metrics.NewPrometheusProvider(nil)
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
// Apply middleware to your router
|
||||
@@ -18,6 +18,59 @@ router.Use(provider.Middleware)
|
||||
http.Handle("/metrics", provider.Handler())
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
You can customize the metrics provider using a configuration struct:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
|
||||
// Create custom configuration
|
||||
config := &metrics.Config{
|
||||
Enabled: true,
|
||||
Provider: "prometheus",
|
||||
Namespace: "myapp", // Prefix all metrics with "myapp_"
|
||||
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5},
|
||||
DBQueryBuckets: []float64{0.001, 0.01, 0.05, 0.1, 0.5, 1},
|
||||
}
|
||||
|
||||
// Initialize with custom config
|
||||
provider := metrics.NewPrometheusProvider(config)
|
||||
metrics.SetProvider(provider)
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `Enabled` | `bool` | `true` | Enable/disable metrics collection |
|
||||
| `Provider` | `string` | `"prometheus"` | Metrics provider type |
|
||||
| `Namespace` | `string` | `""` | Prefix for all metric names |
|
||||
| `HTTPRequestBuckets` | `[]float64` | See below | Histogram buckets for HTTP duration (seconds) |
|
||||
| `DBQueryBuckets` | `[]float64` | See below | Histogram buckets for DB query duration (seconds) |
|
||||
|
||||
**Default HTTP Request Buckets:** `[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]`
|
||||
|
||||
**Default DB Query Buckets:** `[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]`
|
||||
|
||||
### Pushgateway Configuration (Optional)
|
||||
|
||||
For batch jobs, cron tasks, or short-lived processes, you can push metrics to Prometheus Pushgateway:
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `PushgatewayURL` | `string` | `""` | URL of Pushgateway (e.g., "http://pushgateway:9091") |
|
||||
| `PushgatewayJobName` | `string` | `"resolvespec"` | Job name for pushed metrics |
|
||||
| `PushgatewayInterval` | `int` | `0` | Auto-push interval in seconds (0 = disabled) |
|
||||
|
||||
```go
|
||||
config := &metrics.Config{
|
||||
PushgatewayURL: "http://pushgateway:9091",
|
||||
PushgatewayJobName: "batch-job",
|
||||
PushgatewayInterval: 30, // Push every 30 seconds
|
||||
}
|
||||
```
|
||||
|
||||
## Provider Interface
|
||||
|
||||
The package uses a provider interface, allowing you to plug in different metric systems:
|
||||
@@ -87,6 +140,13 @@ When using `PrometheusProvider`, the following metrics are available:
|
||||
| `cache_hits_total` | Counter | provider | Total cache hits |
|
||||
| `cache_misses_total` | Counter | provider | Total cache misses |
|
||||
| `cache_size_items` | Gauge | provider | Current cache size |
|
||||
| `events_published_total` | Counter | source, event_type | Total events published |
|
||||
| `events_processed_total` | Counter | source, event_type, status | Total events processed |
|
||||
| `event_processing_duration_seconds` | Histogram | source, event_type | Event processing duration |
|
||||
| `event_queue_size` | Gauge | - | Current event queue size |
|
||||
| `panics_total` | Counter | method | Total panics recovered |
|
||||
|
||||
**Note:** If a custom `Namespace` is configured, all metric names will be prefixed with `{namespace}_`.
|
||||
|
||||
## Prometheus Queries
|
||||
|
||||
@@ -146,8 +206,126 @@ func (c *CustomProvider) Handler() http.Handler {
|
||||
metrics.SetProvider(&CustomProvider{})
|
||||
```
|
||||
|
||||
## Pushgateway Usage
|
||||
|
||||
### Automatic Push (Batch Jobs)
|
||||
|
||||
For jobs that run periodically, use automatic pushing:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Configure with automatic pushing every 30 seconds
|
||||
config := &metrics.Config{
|
||||
Enabled: true,
|
||||
Provider: "prometheus",
|
||||
Namespace: "batch_job",
|
||||
PushgatewayURL: "http://pushgateway:9091",
|
||||
PushgatewayJobName: "data-processor",
|
||||
PushgatewayInterval: 30, // Push every 30 seconds
|
||||
}
|
||||
|
||||
provider := metrics.NewPrometheusProvider(config)
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
// Ensure cleanup on exit
|
||||
defer provider.StopAutoPush()
|
||||
|
||||
// Your batch job logic here
|
||||
processBatchData()
|
||||
}
|
||||
```
|
||||
|
||||
### Manual Push (Short-lived Processes)
|
||||
|
||||
For one-time jobs or when you want manual control:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Configure without automatic pushing
|
||||
config := &metrics.Config{
|
||||
Enabled: true,
|
||||
Provider: "prometheus",
|
||||
PushgatewayURL: "http://pushgateway:9091",
|
||||
PushgatewayJobName: "migration-job",
|
||||
// PushgatewayInterval: 0 (default - no auto-push)
|
||||
}
|
||||
|
||||
provider := metrics.NewPrometheusProvider(config)
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
// Run your job
|
||||
err := runMigration()
|
||||
|
||||
// Push metrics at the end
|
||||
if pushErr := provider.Push(); pushErr != nil {
|
||||
log.Printf("Failed to push metrics: %v", pushErr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Docker Compose with Pushgateway
|
||||
|
||||
```yaml
|
||||
version: '3'
|
||||
services:
|
||||
batch-job:
|
||||
build: .
|
||||
environment:
|
||||
PUSHGATEWAY_URL: "http://pushgateway:9091"
|
||||
|
||||
pushgateway:
|
||||
image: prom/pushgateway
|
||||
ports:
|
||||
- "9091:9091"
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus
|
||||
ports:
|
||||
- "9090:9090"
|
||||
volumes:
|
||||
- ./prometheus.yml:/etc/prometheus/prometheus.yml
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
```
|
||||
|
||||
**prometheus.yml for Pushgateway:**
|
||||
|
||||
```yaml
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
# Scrape the pushgateway
|
||||
- job_name: 'pushgateway'
|
||||
honor_labels: true # Important: preserve job labels from pushed metrics
|
||||
static_configs:
|
||||
- targets: ['pushgateway:9091']
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
@@ -162,8 +340,8 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize metrics
|
||||
provider := metrics.NewPrometheusProvider()
|
||||
// Initialize metrics with default config
|
||||
provider := metrics.NewPrometheusProvider(nil)
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
// Create router
|
||||
@@ -198,6 +376,42 @@ func getUsersHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
```
|
||||
|
||||
### With Custom Configuration
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Custom metrics configuration
|
||||
metricsConfig := &metrics.Config{
|
||||
Enabled: true,
|
||||
Provider: "prometheus",
|
||||
Namespace: "myapp",
|
||||
// Custom buckets optimized for your application
|
||||
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5, 10},
|
||||
DBQueryBuckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1},
|
||||
}
|
||||
|
||||
// Initialize with custom config
|
||||
provider := metrics.NewPrometheusProvider(metricsConfig)
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.Use(provider.Middleware)
|
||||
router.Handle("/metrics", provider.Handler())
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
```
|
||||
|
||||
## Docker Compose Example
|
||||
|
||||
```yaml
|
||||
@@ -257,3 +471,8 @@ scrape_configs:
|
||||
4. **Performance**: Metrics collection is lock-free and highly performant
|
||||
- Safe for high-throughput applications
|
||||
- Minimal overhead (<1% in most cases)
|
||||
|
||||
5. **Pull vs Push**:
|
||||
- **Use Pull (default)**: Long-running services, web servers, microservices
|
||||
- **Use Push (Pushgateway)**: Batch jobs, cron tasks, short-lived processes, serverless functions
|
||||
- Pull is preferred for most applications as it allows Prometheus to detect if your service is down
|
||||
|
||||
64
pkg/metrics/config.go
Normal file
64
pkg/metrics/config.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package metrics
|
||||
|
||||
// Config holds configuration for the metrics provider
|
||||
type Config struct {
|
||||
// Enabled determines whether metrics collection is enabled
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
|
||||
// Provider specifies which metrics provider to use (prometheus, noop)
|
||||
Provider string `mapstructure:"provider"`
|
||||
|
||||
// Namespace is an optional prefix for all metric names
|
||||
Namespace string `mapstructure:"namespace"`
|
||||
|
||||
// HTTPRequestBuckets defines histogram buckets for HTTP request duration (in seconds)
|
||||
// Default: [0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]
|
||||
HTTPRequestBuckets []float64 `mapstructure:"http_request_buckets"`
|
||||
|
||||
// DBQueryBuckets defines histogram buckets for database query duration (in seconds)
|
||||
// Default: [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]
|
||||
DBQueryBuckets []float64 `mapstructure:"db_query_buckets"`
|
||||
|
||||
// PushgatewayURL is the URL of the Prometheus Pushgateway (optional)
|
||||
// If set, metrics will be pushed to this gateway instead of only being scraped
|
||||
// Example: "http://pushgateway:9091"
|
||||
PushgatewayURL string `mapstructure:"pushgateway_url"`
|
||||
|
||||
// PushgatewayJobName is the job name to use when pushing metrics to Pushgateway
|
||||
// Default: "resolvespec"
|
||||
PushgatewayJobName string `mapstructure:"pushgateway_job_name"`
|
||||
|
||||
// PushgatewayInterval is the interval at which to push metrics to Pushgateway
|
||||
// Only used if PushgatewayURL is set. If 0, automatic pushing is disabled.
|
||||
// Default: 0 (no automatic pushing)
|
||||
PushgatewayInterval int `mapstructure:"pushgateway_interval"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Enabled: true,
|
||||
Provider: "prometheus",
|
||||
// HTTP requests typically take longer than DB queries
|
||||
HTTPRequestBuckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
// DB queries are usually faster
|
||||
DBQueryBuckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5},
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyDefaults fills in any missing values with defaults
|
||||
func (c *Config) ApplyDefaults() {
|
||||
if c.Provider == "" {
|
||||
c.Provider = "prometheus"
|
||||
}
|
||||
if len(c.HTTPRequestBuckets) == 0 {
|
||||
c.HTTPRequestBuckets = []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
|
||||
}
|
||||
if len(c.DBQueryBuckets) == 0 {
|
||||
c.DBQueryBuckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5}
|
||||
}
|
||||
// Set default job name if pushgateway is configured but job name is empty
|
||||
if c.PushgatewayURL != "" && c.PushgatewayJobName == "" {
|
||||
c.PushgatewayJobName = "resolvespec"
|
||||
}
|
||||
}
|
||||
64
pkg/metrics/example_test.go
Normal file
64
pkg/metrics/example_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package metrics_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
// ExampleNewPrometheusProvider_default demonstrates using default configuration
|
||||
func ExampleNewPrometheusProvider_default() {
|
||||
// Initialize with default configuration
|
||||
provider := metrics.NewPrometheusProvider(nil)
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
fmt.Println("Provider initialized with defaults")
|
||||
// Output: Provider initialized with defaults
|
||||
}
|
||||
|
||||
// ExampleNewPrometheusProvider_custom demonstrates using custom configuration
|
||||
func ExampleNewPrometheusProvider_custom() {
|
||||
// Create custom configuration
|
||||
config := &metrics.Config{
|
||||
Enabled: true,
|
||||
Provider: "prometheus",
|
||||
Namespace: "myapp",
|
||||
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5},
|
||||
DBQueryBuckets: []float64{0.001, 0.01, 0.05, 0.1, 0.5, 1},
|
||||
}
|
||||
|
||||
// Initialize with custom configuration
|
||||
provider := metrics.NewPrometheusProvider(config)
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
fmt.Println("Provider initialized with custom config")
|
||||
// Output: Provider initialized with custom config
|
||||
}
|
||||
|
||||
// ExampleDefaultConfig demonstrates getting default configuration
|
||||
func ExampleDefaultConfig() {
|
||||
config := metrics.DefaultConfig()
|
||||
fmt.Printf("Default provider: %s\n", config.Provider)
|
||||
fmt.Printf("Default enabled: %v\n", config.Enabled)
|
||||
// Output:
|
||||
// Default provider: prometheus
|
||||
// Default enabled: true
|
||||
}
|
||||
|
||||
// ExampleConfig_ApplyDefaults demonstrates applying defaults to partial config
|
||||
func ExampleConfig_ApplyDefaults() {
|
||||
// Create partial configuration
|
||||
config := &metrics.Config{
|
||||
Namespace: "myapp",
|
||||
// Other fields will be filled with defaults
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
config.ApplyDefaults()
|
||||
|
||||
fmt.Printf("Provider: %s\n", config.Provider)
|
||||
fmt.Printf("Namespace: %s\n", config.Namespace)
|
||||
// Output:
|
||||
// Provider: prometheus
|
||||
// Namespace: myapp
|
||||
}
|
||||
@@ -30,6 +30,18 @@ type Provider interface {
|
||||
// UpdateCacheSize updates the cache size metric
|
||||
UpdateCacheSize(provider string, size int64)
|
||||
|
||||
// RecordEventPublished records an event publication
|
||||
RecordEventPublished(source, eventType string)
|
||||
|
||||
// RecordEventProcessed records an event processing with its status
|
||||
RecordEventProcessed(source, eventType, status string, duration time.Duration)
|
||||
|
||||
// UpdateEventQueueSize updates the event queue size metric
|
||||
UpdateEventQueueSize(size int64)
|
||||
|
||||
// RecordPanic records a panic event
|
||||
RecordPanic(methodName string)
|
||||
|
||||
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
||||
Handler() http.Handler
|
||||
}
|
||||
@@ -59,9 +71,14 @@ func (n *NoOpProvider) IncRequestsInFlight()
|
||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
}
|
||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||
func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
|
||||
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||
}
|
||||
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
|
||||
func (n *NoOpProvider) RecordPanic(methodName string) {}
|
||||
func (n *NoOpProvider) Handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/prometheus/client_golang/prometheus/push"
|
||||
)
|
||||
|
||||
// PrometheusProvider implements the Provider interface using Prometheus
|
||||
@@ -20,22 +21,51 @@ type PrometheusProvider struct {
|
||||
cacheHits *prometheus.CounterVec
|
||||
cacheMisses *prometheus.CounterVec
|
||||
cacheSize *prometheus.GaugeVec
|
||||
eventPublished *prometheus.CounterVec
|
||||
eventProcessed *prometheus.CounterVec
|
||||
eventDuration *prometheus.HistogramVec
|
||||
eventQueueSize prometheus.Gauge
|
||||
panicsTotal *prometheus.CounterVec
|
||||
|
||||
// Pushgateway fields (optional)
|
||||
pushgatewayURL string
|
||||
pushgatewayJobName string
|
||||
pusher *push.Pusher
|
||||
pushTicker *time.Ticker
|
||||
pushStop chan bool
|
||||
}
|
||||
|
||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||
func NewPrometheusProvider() *PrometheusProvider {
|
||||
return &PrometheusProvider{
|
||||
// If cfg is nil, default configuration will be used
|
||||
func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
|
||||
// Use default config if none provided
|
||||
if cfg == nil {
|
||||
cfg = DefaultConfig()
|
||||
} else {
|
||||
// Apply defaults for any missing values
|
||||
cfg.ApplyDefaults()
|
||||
}
|
||||
|
||||
// Helper to add namespace prefix if configured
|
||||
metricName := func(name string) string {
|
||||
if cfg.Namespace != "" {
|
||||
return cfg.Namespace + "_" + name
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
p := &PrometheusProvider{
|
||||
requestDuration: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "http_request_duration_seconds",
|
||||
Name: metricName("http_request_duration_seconds"),
|
||||
Help: "HTTP request duration in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
Buckets: cfg.HTTPRequestBuckets,
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
),
|
||||
requestTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "http_requests_total",
|
||||
Name: metricName("http_requests_total"),
|
||||
Help: "Total number of HTTP requests",
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
@@ -43,47 +73,100 @@ func NewPrometheusProvider() *PrometheusProvider {
|
||||
|
||||
requestsInFlight: promauto.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "http_requests_in_flight",
|
||||
Name: metricName("http_requests_in_flight"),
|
||||
Help: "Current number of HTTP requests being processed",
|
||||
},
|
||||
),
|
||||
dbQueryDuration: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "db_query_duration_seconds",
|
||||
Name: metricName("db_query_duration_seconds"),
|
||||
Help: "Database query duration in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
Buckets: cfg.DBQueryBuckets,
|
||||
},
|
||||
[]string{"operation", "table"},
|
||||
),
|
||||
dbQueryTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "db_queries_total",
|
||||
Name: metricName("db_queries_total"),
|
||||
Help: "Total number of database queries",
|
||||
},
|
||||
[]string{"operation", "table", "status"},
|
||||
),
|
||||
cacheHits: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cache_hits_total",
|
||||
Name: metricName("cache_hits_total"),
|
||||
Help: "Total number of cache hits",
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
cacheMisses: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cache_misses_total",
|
||||
Name: metricName("cache_misses_total"),
|
||||
Help: "Total number of cache misses",
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
cacheSize: promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "cache_size_items",
|
||||
Name: metricName("cache_size_items"),
|
||||
Help: "Number of items in cache",
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
eventPublished: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: metricName("events_published_total"),
|
||||
Help: "Total number of events published",
|
||||
},
|
||||
[]string{"source", "event_type"},
|
||||
),
|
||||
eventProcessed: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: metricName("events_processed_total"),
|
||||
Help: "Total number of events processed",
|
||||
},
|
||||
[]string{"source", "event_type", "status"},
|
||||
),
|
||||
eventDuration: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: metricName("event_processing_duration_seconds"),
|
||||
Help: "Event processing duration in seconds",
|
||||
Buckets: cfg.DBQueryBuckets, // Events are typically fast like DB queries
|
||||
},
|
||||
[]string{"source", "event_type"},
|
||||
),
|
||||
eventQueueSize: promauto.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: metricName("event_queue_size"),
|
||||
Help: "Current number of events in queue",
|
||||
},
|
||||
),
|
||||
panicsTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: metricName("panics_total"),
|
||||
Help: "Total number of panics",
|
||||
},
|
||||
[]string{"method"},
|
||||
),
|
||||
|
||||
pushgatewayURL: cfg.PushgatewayURL,
|
||||
pushgatewayJobName: cfg.PushgatewayJobName,
|
||||
}
|
||||
|
||||
// Initialize pushgateway if configured
|
||||
if cfg.PushgatewayURL != "" {
|
||||
p.pusher = push.New(cfg.PushgatewayURL, cfg.PushgatewayJobName).
|
||||
Gatherer(prometheus.DefaultGatherer)
|
||||
|
||||
// Start automatic pushing if interval is configured
|
||||
if cfg.PushgatewayInterval > 0 {
|
||||
p.pushStop = make(chan bool)
|
||||
p.pushTicker = time.NewTicker(time.Duration(cfg.PushgatewayInterval) * time.Second)
|
||||
go p.startAutoPush()
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// ResponseWriter wraps http.ResponseWriter to capture status code
|
||||
@@ -145,6 +228,27 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
|
||||
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
||||
}
|
||||
|
||||
// RecordEventPublished implements Provider interface
|
||||
func (p *PrometheusProvider) RecordEventPublished(source, eventType string) {
|
||||
p.eventPublished.WithLabelValues(source, eventType).Inc()
|
||||
}
|
||||
|
||||
// RecordEventProcessed implements Provider interface
|
||||
func (p *PrometheusProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||
p.eventProcessed.WithLabelValues(source, eventType, status).Inc()
|
||||
p.eventDuration.WithLabelValues(source, eventType).Observe(duration.Seconds())
|
||||
}
|
||||
|
||||
// UpdateEventQueueSize implements Provider interface
|
||||
func (p *PrometheusProvider) UpdateEventQueueSize(size int64) {
|
||||
p.eventQueueSize.Set(float64(size))
|
||||
}
|
||||
|
||||
// RecordPanic implements the Provider interface
|
||||
func (p *PrometheusProvider) RecordPanic(methodName string) {
|
||||
p.panicsTotal.WithLabelValues(methodName).Inc()
|
||||
}
|
||||
|
||||
// Handler implements Provider interface
|
||||
func (p *PrometheusProvider) Handler() http.Handler {
|
||||
return promhttp.Handler()
|
||||
@@ -172,3 +276,37 @@ func (p *PrometheusProvider) Middleware(next http.Handler) http.Handler {
|
||||
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
|
||||
})
|
||||
}
|
||||
|
||||
// Push manually pushes metrics to the configured Pushgateway
|
||||
// Returns an error if pushing fails or if Pushgateway is not configured
|
||||
func (p *PrometheusProvider) Push() error {
|
||||
if p.pusher == nil {
|
||||
return nil // Pushgateway not configured, silently skip
|
||||
}
|
||||
return p.pusher.Push()
|
||||
}
|
||||
|
||||
// startAutoPush runs in a goroutine and periodically pushes metrics to Pushgateway
|
||||
func (p *PrometheusProvider) startAutoPush() {
|
||||
for {
|
||||
select {
|
||||
case <-p.pushTicker.C:
|
||||
if err := p.Push(); err != nil {
|
||||
// Log error but continue pushing
|
||||
// Note: In production, you might want to use a proper logger
|
||||
_ = err
|
||||
}
|
||||
case <-p.pushStop:
|
||||
p.pushTicker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopAutoPush stops the automatic push goroutine
|
||||
// This should be called when shutting down the application
|
||||
func (p *PrometheusProvider) StopAutoPush() {
|
||||
if p.pushStop != nil {
|
||||
close(p.pushStop)
|
||||
}
|
||||
}
|
||||
|
||||
33
pkg/middleware/panic.go
Normal file
33
pkg/middleware/panic.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
const panicMiddlewareMethodName = "PanicMiddleware"
|
||||
|
||||
// PanicRecovery is a middleware that recovers from panics, logs the error,
|
||||
// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error.
|
||||
func PanicRecovery(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rcv := recover(); rcv != nil {
|
||||
// Record the panic metric
|
||||
metrics.GetProvider().RecordPanic(panicMiddlewareMethodName)
|
||||
|
||||
// Log the panic and send to error tracker
|
||||
// We pass the request context so the error tracker can potentially
|
||||
// link the panic to the request trace.
|
||||
ctx := r.Context()
|
||||
err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx)
|
||||
|
||||
// Respond with a 500 error
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
86
pkg/middleware/panic_test.go
Normal file
86
pkg/middleware/panic_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockMetricsProvider is a mock for the metrics provider to check if methods are called.
|
||||
type mockMetricsProvider struct {
|
||||
metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods
|
||||
panicRecorded bool
|
||||
methodName string
|
||||
}
|
||||
|
||||
func (m *mockMetricsProvider) RecordPanic(methodName string) {
|
||||
m.panicRecorded = true
|
||||
m.methodName = methodName
|
||||
}
|
||||
|
||||
func TestPanicRecovery(t *testing.T) {
|
||||
// Initialize a mock logger to avoid actual logging output during tests
|
||||
logger.Init(true)
|
||||
|
||||
// Setup mock metrics provider
|
||||
mockProvider := &mockMetricsProvider{}
|
||||
originalProvider := metrics.GetProvider()
|
||||
metrics.SetProvider(mockProvider)
|
||||
defer metrics.SetProvider(originalProvider) // Restore original provider after test
|
||||
|
||||
// 1. Test case: A handler that panics
|
||||
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||
// Reset mock state for this sub-test
|
||||
mockProvider.panicRecorded = false
|
||||
mockProvider.methodName = ""
|
||||
|
||||
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("something went terribly wrong")
|
||||
})
|
||||
|
||||
// Create the middleware wrapping the panicking handler
|
||||
testHandler := PanicRecovery(panicHandler)
|
||||
|
||||
// Create a test request and response recorder
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Assertions
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500")
|
||||
assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body")
|
||||
|
||||
// Assert that the metric was recorded
|
||||
assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider")
|
||||
assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name")
|
||||
})
|
||||
|
||||
// 2. Test case: A handler that does NOT panic
|
||||
t.Run("does not interfere with a non-panicking handler", func(t *testing.T) {
|
||||
// Reset mock state for this sub-test
|
||||
mockProvider.panicRecorded = false
|
||||
|
||||
successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
testHandler := PanicRecovery(successHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Assertions
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200")
|
||||
assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body")
|
||||
assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic")
|
||||
})
|
||||
}
|
||||
@@ -6,15 +6,37 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ModelRules defines the permissions and security settings for a model
|
||||
type ModelRules struct {
|
||||
CanRead bool // Whether the model can be read (GET operations)
|
||||
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanCreate bool // Whether the model can be created (POST operations)
|
||||
CanDelete bool // Whether the model can be deleted (DELETE operations)
|
||||
SecurityDisabled bool // Whether security checks are disabled for this model
|
||||
}
|
||||
|
||||
// DefaultModelRules returns the default rules for a model (all operations allowed, security enabled)
|
||||
func DefaultModelRules() ModelRules {
|
||||
return ModelRules{
|
||||
CanRead: true,
|
||||
CanUpdate: true,
|
||||
CanCreate: true,
|
||||
CanDelete: true,
|
||||
SecurityDisabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultModelRegistry implements ModelRegistry interface
|
||||
type DefaultModelRegistry struct {
|
||||
models map[string]interface{}
|
||||
rules map[string]ModelRules
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Global default registry instance
|
||||
var defaultRegistry = &DefaultModelRegistry{
|
||||
models: make(map[string]interface{}),
|
||||
rules: make(map[string]ModelRules),
|
||||
}
|
||||
|
||||
// Global list of registries (searched in order)
|
||||
@@ -25,6 +47,7 @@ var registriesMutex sync.RWMutex
|
||||
func NewModelRegistry() *DefaultModelRegistry {
|
||||
return &DefaultModelRegistry{
|
||||
models: make(map[string]interface{}),
|
||||
rules: make(map[string]ModelRules),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +121,10 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
||||
}
|
||||
|
||||
r.models[name] = model
|
||||
// Initialize with default rules if not already set
|
||||
if _, exists := r.rules[name]; !exists {
|
||||
r.rules[name] = DefaultModelRules()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -135,6 +162,54 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac
|
||||
return r.GetModel(entity)
|
||||
}
|
||||
|
||||
// SetModelRules sets the rules for a specific model
|
||||
func (r *DefaultModelRegistry) SetModelRules(name string, rules ModelRules) error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
// Check if model exists
|
||||
if _, exists := r.models[name]; !exists {
|
||||
return fmt.Errorf("model %s not found", name)
|
||||
}
|
||||
|
||||
r.rules[name] = rules
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelRules retrieves the rules for a specific model
|
||||
// Returns default rules if model exists but rules are not set
|
||||
func (r *DefaultModelRegistry) GetModelRules(name string) (ModelRules, error) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// Check if model exists
|
||||
if _, exists := r.models[name]; !exists {
|
||||
return ModelRules{}, fmt.Errorf("model %s not found", name)
|
||||
}
|
||||
|
||||
// Return rules if set, otherwise return default rules
|
||||
if rules, exists := r.rules[name]; exists {
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
return DefaultModelRules(), nil
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model with specific rules
|
||||
func (r *DefaultModelRegistry) RegisterModelWithRules(name string, model interface{}, rules ModelRules) error {
|
||||
// First register the model
|
||||
if err := r.RegisterModel(name, model); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then set the rules (we need to lock again for rules)
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.rules[name] = rules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Global convenience functions using the default registry
|
||||
|
||||
// RegisterModel registers a model with the default global registry
|
||||
@@ -190,3 +265,34 @@ func GetModels() []interface{} {
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// SetModelRules sets the rules for a specific model in the default registry
|
||||
func SetModelRules(name string, rules ModelRules) error {
|
||||
return defaultRegistry.SetModelRules(name, rules)
|
||||
}
|
||||
|
||||
// GetModelRules retrieves the rules for a specific model from the default registry
|
||||
func GetModelRules(name string) (ModelRules, error) {
|
||||
return defaultRegistry.GetModelRules(name)
|
||||
}
|
||||
|
||||
// GetModelRulesByName retrieves the rules for a model by searching through all registries in order
|
||||
// Returns the first match found
|
||||
func GetModelRulesByName(name string) (ModelRules, error) {
|
||||
registriesMutex.RLock()
|
||||
defer registriesMutex.RUnlock()
|
||||
|
||||
for _, registry := range registries {
|
||||
if _, err := registry.GetModel(name); err == nil {
|
||||
// Model found in this registry, get its rules
|
||||
return registry.GetModelRules(name)
|
||||
}
|
||||
}
|
||||
|
||||
return ModelRules{}, fmt.Errorf("model %s not found in any registry", name)
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model with specific rules in the default registry
|
||||
func RegisterModelWithRules(model interface{}, name string, rules ModelRules) error {
|
||||
return defaultRegistry.RegisterModelWithRules(name, model, rules)
|
||||
}
|
||||
|
||||
724
pkg/mqttspec/README.md
Normal file
724
pkg/mqttspec/README.md
Normal file
@@ -0,0 +1,724 @@
|
||||
# MQTTSpec - MQTT-based Database Query Framework
|
||||
|
||||
MQTTSpec is an MQTT-based database query framework that enables real-time database operations and subscriptions via MQTT protocol. It mirrors the functionality of WebSocketSpec but uses MQTT as the transport layer, making it ideal for IoT applications, mobile apps with unreliable networks, and distributed systems requiring QoS guarantees.
|
||||
|
||||
## Features
|
||||
|
||||
- **Dual Broker Support**: Embedded broker (Mochi MQTT) or external broker connection (Paho MQTT)
|
||||
- **QoS 1 (At-least-once delivery)**: Reliable message delivery for all operations
|
||||
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
||||
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
||||
- **Database Agnostic**: GORM and Bun ORM support
|
||||
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
|
||||
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
||||
- **Thread-safe**: Proper concurrency handling throughout
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/bitechdev/ResolveSpec/pkg/mqttspec
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Embedded Broker (Default)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/mqttspec"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Connect to database
|
||||
db, _ := gorm.Open(postgres.Open("postgres://..."), &gorm.Config{})
|
||||
db.AutoMigrate(&User{})
|
||||
|
||||
// Create MQTT handler with embedded broker
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Register models
|
||||
handler.Registry().RegisterModel("public.users", &User{})
|
||||
|
||||
// Start handler (starts embedded broker on localhost:1883)
|
||||
if err := handler.Start(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Handler is now listening for MQTT messages
|
||||
select {} // Keep running
|
||||
}
|
||||
```
|
||||
|
||||
### External Broker
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithExternalBroker(mqttspec.ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://mqtt.example.com:1883",
|
||||
ClientID: "mqttspec-server",
|
||||
Username: "admin",
|
||||
Password: "secret",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Port (Embedded Broker)
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithEmbeddedBroker(mqttspec.BrokerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 1884,
|
||||
}),
|
||||
)
|
||||
```
|
||||
|
||||
## Topic Structure
|
||||
|
||||
MQTTSpec uses a client-based topic hierarchy:
|
||||
|
||||
```
|
||||
spec/{client_id}/request # Client publishes requests
|
||||
spec/{client_id}/response # Server publishes responses
|
||||
spec/{client_id}/notify/{sub_id} # Server publishes notifications
|
||||
```
|
||||
|
||||
### Wildcard Subscriptions
|
||||
|
||||
- **Server**: `spec/+/request` (receives all client requests)
|
||||
- **Client**: `spec/{client_id}/response` + `spec/{client_id}/notify/+`
|
||||
|
||||
## Message Protocol
|
||||
|
||||
MQTTSpec uses the same JSON message structure as WebSocketSpec and ResolveSpec for consistency.
|
||||
|
||||
### Request Message
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-123",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": {
|
||||
"filters": [
|
||||
{"column": "status", "operator": "eq", "value": "active"}
|
||||
],
|
||||
"sort": [{"column": "created_at", "direction": "desc"}],
|
||||
"limit": 10
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response Message
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-123",
|
||||
"type": "response",
|
||||
"success": true,
|
||||
"data": [
|
||||
{"id": 1, "name": "John Doe", "email": "john@example.com", "status": "active"},
|
||||
{"id": 2, "name": "Jane Smith", "email": "jane@example.com", "status": "active"}
|
||||
],
|
||||
"metadata": {
|
||||
"total": 50,
|
||||
"count": 2
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Notification Message
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "notification",
|
||||
"operation": "create",
|
||||
"subscription_id": "sub-xyz",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"id": 3,
|
||||
"name": "New User",
|
||||
"email": "new@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## CRUD Operations
|
||||
|
||||
### Read (Single Record)
|
||||
|
||||
**MQTT Client Publishes to**: `spec/{client_id}/request`
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-1",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {"id": 1}
|
||||
}
|
||||
```
|
||||
|
||||
**Server Publishes Response to**: `spec/{client_id}/response`
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-1",
|
||||
"success": true,
|
||||
"data": {"id": 1, "name": "John Doe", "email": "john@example.com"}
|
||||
}
|
||||
```
|
||||
|
||||
### Read (Multiple Records with Filtering)
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-2",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": {
|
||||
"filters": [
|
||||
{"column": "status", "operator": "eq", "value": "active"}
|
||||
],
|
||||
"sort": [{"column": "name", "direction": "asc"}],
|
||||
"limit": 20,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Create
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-3",
|
||||
"type": "request",
|
||||
"operation": "create",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"name": "Alice Brown",
|
||||
"email": "alice@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Update
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-4",
|
||||
"type": "request",
|
||||
"operation": "update",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"id": 1,
|
||||
"status": "inactive"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Delete
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-5",
|
||||
"type": "request",
|
||||
"operation": "delete",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {"id": 1}
|
||||
}
|
||||
```
|
||||
|
||||
## Real-time Subscriptions
|
||||
|
||||
### Subscribe to Entity Changes
|
||||
|
||||
**Client Publishes to**: `spec/{client_id}/request`
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-6",
|
||||
"type": "subscription",
|
||||
"operation": "subscribe",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": {
|
||||
"filters": [
|
||||
{"column": "status", "operator": "eq", "value": "active"}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Server Response** (published to `spec/{client_id}/response`):
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-6",
|
||||
"success": true,
|
||||
"data": {
|
||||
"subscription_id": "sub-abc123",
|
||||
"notify_topic": "spec/{client_id}/notify/sub-abc123"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Client Then Subscribes** to MQTT topic: `spec/{client_id}/notify/sub-abc123`
|
||||
|
||||
### Receiving Notifications
|
||||
|
||||
When any client creates/updates/deletes a user matching the subscription filters, the subscriber receives:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "notification",
|
||||
"operation": "create",
|
||||
"subscription_id": "sub-abc123",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"id": 10,
|
||||
"name": "New User",
|
||||
"email": "newuser@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Unsubscribe
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-7",
|
||||
"type": "subscription",
|
||||
"operation": "unsubscribe",
|
||||
"data": {
|
||||
"subscription_id": "sub-abc123"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
|
||||
### Hook Types
|
||||
|
||||
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
||||
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
||||
- `BeforeRead` / `AfterRead` - Read operations
|
||||
- `BeforeCreate` / `AfterCreate` - Create operations
|
||||
- `BeforeUpdate` / `AfterUpdate` - Update operations
|
||||
- `BeforeDelete` / `AfterDelete` - Delete operations
|
||||
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
||||
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
||||
|
||||
### Authentication Example (JWT)
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(mqttspec.BeforeConnect, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
|
||||
// MQTT username contains JWT token
|
||||
token := client.Username
|
||||
claims, err := jwt.Validate(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
|
||||
// Store user info in client metadata for later use
|
||||
client.SetMetadata("user_id", claims.UserID)
|
||||
client.SetMetadata("tenant_id", claims.TenantID)
|
||||
client.SetMetadata("roles", claims.Roles)
|
||||
|
||||
logger.Info("Client authenticated: user_id=%d, tenant=%s", claims.UserID, claims.TenantID)
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Multi-tenancy Example
|
||||
|
||||
```go
|
||||
// Auto-inject tenant filter for all read operations
|
||||
handler.Hooks().Register(mqttspec.BeforeRead, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
tenantID, _ := client.GetMetadata("tenant_id")
|
||||
|
||||
// Add tenant filter to ensure users only see their own data
|
||||
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
|
||||
Column: "tenant_id",
|
||||
Operator: "eq",
|
||||
Value: tenantID,
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Auto-set tenant_id for all create operations
|
||||
handler.Hooks().Register(mqttspec.BeforeCreate, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
tenantID, _ := client.GetMetadata("tenant_id")
|
||||
|
||||
// Inject tenant_id into new records
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["tenant_id"] = tenantID
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Role-based Access Control (RBAC)
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(mqttspec.BeforeDelete, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
roles, _ := client.GetMetadata("roles")
|
||||
|
||||
roleList := roles.([]string)
|
||||
hasAdminRole := false
|
||||
for _, role := range roleList {
|
||||
if role == "admin" {
|
||||
hasAdminRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAdminRole {
|
||||
return fmt.Errorf("permission denied: delete requires admin role")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Audit Logging Example
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(mqttspec.AfterCreate, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
userID, _ := client.GetMetadata("user_id")
|
||||
|
||||
logger.Info("Audit: user %d created %s.%s record: %+v",
|
||||
userID, ctx.Schema, ctx.Entity, ctx.Result)
|
||||
|
||||
// Could also write to audit log table
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
## Client Examples
|
||||
|
||||
### JavaScript (MQTT.js)
|
||||
|
||||
```javascript
|
||||
const mqtt = require('mqtt');
|
||||
|
||||
// Connect to MQTT broker
|
||||
const client = mqtt.connect('mqtt://localhost:1883', {
|
||||
clientId: 'client-abc123',
|
||||
username: 'your-jwt-token',
|
||||
password: '', // JWT in username, password can be empty
|
||||
});
|
||||
|
||||
client.on('connect', () => {
|
||||
console.log('Connected to MQTT broker');
|
||||
|
||||
// Subscribe to responses
|
||||
client.subscribe('spec/client-abc123/response');
|
||||
|
||||
// Read users
|
||||
const readMsg = {
|
||||
id: 'msg-1',
|
||||
type: 'request',
|
||||
operation: 'read',
|
||||
schema: 'public',
|
||||
entity: 'users',
|
||||
options: {
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
]
|
||||
}
|
||||
};
|
||||
|
||||
client.publish('spec/client-abc123/request', JSON.stringify(readMsg));
|
||||
});
|
||||
|
||||
client.on('message', (topic, payload) => {
|
||||
const message = JSON.parse(payload.toString());
|
||||
console.log('Received:', message);
|
||||
|
||||
if (message.type === 'response') {
|
||||
console.log('Response data:', message.data);
|
||||
} else if (message.type === 'notification') {
|
||||
console.log('Notification:', message.operation, message.data);
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
### Python (paho-mqtt)
|
||||
|
||||
```python
|
||||
import paho.mqtt.client as mqtt
|
||||
import json
|
||||
|
||||
client_id = 'client-python-123'
|
||||
|
||||
def on_connect(client, userdata, flags, rc):
|
||||
print(f"Connected with result code {rc}")
|
||||
|
||||
# Subscribe to responses
|
||||
client.subscribe(f"spec/{client_id}/response")
|
||||
|
||||
# Create a user
|
||||
create_msg = {
|
||||
'id': 'msg-create-1',
|
||||
'type': 'request',
|
||||
'operation': 'create',
|
||||
'schema': 'public',
|
||||
'entity': 'users',
|
||||
'data': {
|
||||
'name': 'Python User',
|
||||
'email': 'python@example.com',
|
||||
'status': 'active'
|
||||
}
|
||||
}
|
||||
|
||||
client.publish(f"spec/{client_id}/request", json.dumps(create_msg))
|
||||
|
||||
def on_message(client, userdata, msg):
|
||||
message = json.loads(msg.payload.decode())
|
||||
print(f"Received on {msg.topic}: {message}")
|
||||
|
||||
client = mqtt.Client(client_id=client_id)
|
||||
client.username_pw_set('your-jwt-token', '')
|
||||
client.on_connect = on_connect
|
||||
client.on_message = on_message
|
||||
|
||||
client.connect('localhost', 1883, 60)
|
||||
client.loop_forever()
|
||||
```
|
||||
|
||||
### Go (paho.mqtt.golang)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
)
|
||||
|
||||
func main() {
|
||||
clientID := "client-go-123"
|
||||
|
||||
opts := mqtt.NewClientOptions()
|
||||
opts.AddBroker("tcp://localhost:1883")
|
||||
opts.SetClientID(clientID)
|
||||
opts.SetUsername("your-jwt-token")
|
||||
opts.SetPassword("")
|
||||
|
||||
opts.SetDefaultPublishHandler(func(client mqtt.Client, msg mqtt.Message) {
|
||||
var message map[string]interface{}
|
||||
json.Unmarshal(msg.Payload(), &message)
|
||||
fmt.Printf("Received on %s: %+v\n", msg.Topic(), message)
|
||||
})
|
||||
|
||||
opts.OnConnect = func(client mqtt.Client) {
|
||||
fmt.Println("Connected to MQTT broker")
|
||||
|
||||
// Subscribe to responses
|
||||
client.Subscribe(fmt.Sprintf("spec/%s/response", clientID), 1, nil)
|
||||
|
||||
// Read users
|
||||
readMsg := map[string]interface{}{
|
||||
"id": "msg-1",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": map[string]interface{}{
|
||||
"filters": []map[string]interface{}{
|
||||
{"column": "status", "operator": "eq", "value": "active"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(readMsg)
|
||||
client.Publish(fmt.Sprintf("spec/%s/request", clientID), 1, false, payload)
|
||||
}
|
||||
|
||||
client := mqtt.NewClient(opts)
|
||||
if token := client.Connect(); token.Wait() && token.Error() != nil {
|
||||
panic(token.Error())
|
||||
}
|
||||
|
||||
// Keep running
|
||||
select {}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### BrokerConfig (Embedded Broker)
|
||||
|
||||
```go
|
||||
type BrokerConfig struct {
|
||||
Host string // Default: "localhost"
|
||||
Port int // Default: 1883
|
||||
EnableWebSocket bool // Enable WebSocket listener
|
||||
WSPort int // WebSocket port (default: 1884)
|
||||
MaxConnections int // Max concurrent connections
|
||||
KeepAlive time.Duration // MQTT keep-alive interval
|
||||
EnableAuth bool // Enable authentication
|
||||
}
|
||||
```
|
||||
|
||||
### ExternalBrokerConfig
|
||||
|
||||
```go
|
||||
type ExternalBrokerConfig struct {
|
||||
BrokerURL string // MQTT broker URL (tcp://host:port)
|
||||
ClientID string // MQTT client ID
|
||||
Username string // MQTT username
|
||||
Password string // MQTT password
|
||||
CleanSession bool // Clean session flag
|
||||
KeepAlive time.Duration // Keep-alive interval
|
||||
ConnectTimeout time.Duration // Connection timeout
|
||||
ReconnectDelay time.Duration // Auto-reconnect delay
|
||||
MaxReconnect int // Max reconnect attempts
|
||||
TLSConfig *tls.Config // TLS configuration
|
||||
}
|
||||
```
|
||||
|
||||
### QoS Configuration
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithQoS(1, 1, 1), // Request, Response, Notification
|
||||
)
|
||||
```
|
||||
|
||||
### Topic Prefix
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithTopicPrefix("myapp"), // Changes topics to myapp/{client_id}/...
|
||||
)
|
||||
```
|
||||
|
||||
## Documentation References
|
||||
|
||||
- **ResolveSpec JSON Protocol**: See `/pkg/resolvespec/README.md` for the full message protocol specification
|
||||
- **WebSocketSpec Documentation**: See `/pkg/websocketspec/README.md` for similar WebSocket-based implementation
|
||||
- **Common Interfaces**: See `/pkg/common/types.go` for database adapter interfaces and query options
|
||||
- **Model Registry**: See `/pkg/modelregistry/README.md` for model registration and reflection
|
||||
- **Hooks Reference**: See `/pkg/websocketspec/hooks.go` for hook types (same as MQTTSpec)
|
||||
- **Subscription Management**: See `/pkg/websocketspec/subscription.go` for subscription filtering
|
||||
|
||||
## Comparison: MQTTSpec vs WebSocketSpec
|
||||
|
||||
| Feature | MQTTSpec | WebSocketSpec |
|
||||
|---------|----------|---------------|
|
||||
| **Transport** | MQTT (pub/sub broker) | WebSocket (direct connection) |
|
||||
| **Connection Model** | Broker-mediated | Direct client-server |
|
||||
| **QoS Levels** | QoS 0, 1, 2 support | No built-in QoS |
|
||||
| **Offline Messages** | Yes (with QoS 1+) | No |
|
||||
| **Auto-reconnect** | Yes (built into MQTT) | Manual implementation needed |
|
||||
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
||||
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
||||
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
||||
| **Hooks** | Same 12 hooks | Same 12 hooks |
|
||||
| **CRUD Operations** | Identical | Identical |
|
||||
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
||||
|
||||
## Use Cases
|
||||
|
||||
### IoT Sensor Data
|
||||
|
||||
```go
|
||||
// Sensors publish data, backend stores and notifies subscribers
|
||||
handler.Registry().RegisterModel("public.sensor_readings", &SensorReading{})
|
||||
|
||||
// Auto-set device_id from client metadata
|
||||
handler.Hooks().Register(mqttspec.BeforeCreate, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
deviceID, _ := client.GetMetadata("device_id")
|
||||
|
||||
if ctx.Entity == "sensor_readings" {
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["device_id"] = deviceID
|
||||
dataMap["timestamp"] = time.Now()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Mobile App with Offline Support
|
||||
|
||||
MQTTSpec's QoS 1 ensures messages are delivered even if the client temporarily disconnects.
|
||||
|
||||
### Distributed Microservices
|
||||
|
||||
Multiple services can subscribe to entity changes and react accordingly.
|
||||
|
||||
## Testing
|
||||
|
||||
Run unit tests:
|
||||
|
||||
```bash
|
||||
go test -v ./pkg/mqttspec
|
||||
```
|
||||
|
||||
Run with race detection:
|
||||
|
||||
```bash
|
||||
go test -race -v ./pkg/mqttspec
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This package is part of the ResolveSpec project.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please ensure:
|
||||
|
||||
- All tests pass (`go test ./pkg/mqttspec`)
|
||||
- No race conditions (`go test -race ./pkg/mqttspec`)
|
||||
- Documentation is updated
|
||||
- Examples are provided for new features
|
||||
|
||||
## Support
|
||||
|
||||
For issues, questions, or feature requests, please open an issue in the ResolveSpec repository.
|
||||
417
pkg/mqttspec/broker.go
Normal file
417
pkg/mqttspec/broker.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
|
||||
pahomqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// BrokerInterface abstracts MQTT broker operations
|
||||
type BrokerInterface interface {
|
||||
// Start initializes the broker/client connection
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop gracefully shuts down the broker/client
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// Publish sends a message to a topic
|
||||
Publish(topic string, qos byte, payload []byte) error
|
||||
|
||||
// Subscribe subscribes to a topic pattern with callback
|
||||
Subscribe(topicFilter string, qos byte, callback MessageCallback) error
|
||||
|
||||
// Unsubscribe removes subscription
|
||||
Unsubscribe(topicFilter string) error
|
||||
|
||||
// IsConnected returns connection status
|
||||
IsConnected() bool
|
||||
|
||||
// GetClientManager returns the client manager
|
||||
GetClientManager() *ClientManager
|
||||
|
||||
// SetHandler sets the handler reference (needed for hooks)
|
||||
SetHandler(handler *Handler)
|
||||
}
|
||||
|
||||
// MessageCallback is called when a message arrives
|
||||
type MessageCallback func(topic string, payload []byte)
|
||||
|
||||
// EmbeddedBroker wraps Mochi MQTT server
|
||||
type EmbeddedBroker struct {
|
||||
config BrokerConfig
|
||||
server *mqtt.Server
|
||||
clientManager *ClientManager
|
||||
handler *Handler
|
||||
subscriptions map[string]MessageCallback
|
||||
subMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
}
|
||||
|
||||
// NewEmbeddedBroker creates a new embedded broker
|
||||
func NewEmbeddedBroker(config BrokerConfig, clientManager *ClientManager) *EmbeddedBroker {
|
||||
return &EmbeddedBroker{
|
||||
config: config,
|
||||
clientManager: clientManager,
|
||||
subscriptions: make(map[string]MessageCallback),
|
||||
}
|
||||
}
|
||||
|
||||
// SetHandler sets the handler reference
|
||||
func (eb *EmbeddedBroker) SetHandler(handler *Handler) {
|
||||
eb.mu.Lock()
|
||||
defer eb.mu.Unlock()
|
||||
eb.handler = handler
|
||||
}
|
||||
|
||||
// Start starts the embedded MQTT broker
|
||||
func (eb *EmbeddedBroker) Start(ctx context.Context) error {
|
||||
eb.mu.Lock()
|
||||
defer eb.mu.Unlock()
|
||||
|
||||
if eb.started {
|
||||
return fmt.Errorf("broker already started")
|
||||
}
|
||||
|
||||
eb.ctx, eb.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Create Mochi MQTT server
|
||||
eb.server = mqtt.New(&mqtt.Options{
|
||||
InlineClient: true,
|
||||
})
|
||||
|
||||
// Note: Authentication is handled at the handler level via BeforeConnect hook
|
||||
// Mochi MQTT auth can be configured via custom hooks if needed
|
||||
|
||||
// Add TCP listener
|
||||
tcp := listeners.NewTCP(
|
||||
listeners.Config{
|
||||
ID: "tcp",
|
||||
Address: fmt.Sprintf("%s:%d", eb.config.Host, eb.config.Port),
|
||||
},
|
||||
)
|
||||
if err := eb.server.AddListener(tcp); err != nil {
|
||||
return fmt.Errorf("failed to add TCP listener: %w", err)
|
||||
}
|
||||
|
||||
// Add WebSocket listener if enabled
|
||||
if eb.config.EnableWebSocket {
|
||||
ws := listeners.NewWebsocket(
|
||||
listeners.Config{
|
||||
ID: "ws",
|
||||
Address: fmt.Sprintf("%s:%d", eb.config.Host, eb.config.WSPort),
|
||||
},
|
||||
)
|
||||
if err := eb.server.AddListener(ws); err != nil {
|
||||
return fmt.Errorf("failed to add WebSocket listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := eb.server.Serve(); err != nil {
|
||||
logger.Error("[MQTTSpec] Embedded broker error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for server to be ready
|
||||
select {
|
||||
case <-time.After(2 * time.Second):
|
||||
// Server should be ready
|
||||
case <-eb.ctx.Done():
|
||||
return fmt.Errorf("context cancelled during startup")
|
||||
}
|
||||
|
||||
eb.started = true
|
||||
logger.Info("[MQTTSpec] Embedded broker started on %s:%d", eb.config.Host, eb.config.Port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the embedded broker
|
||||
func (eb *EmbeddedBroker) Stop(ctx context.Context) error {
|
||||
eb.mu.Lock()
|
||||
defer eb.mu.Unlock()
|
||||
|
||||
if !eb.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
if eb.cancel != nil {
|
||||
eb.cancel()
|
||||
}
|
||||
|
||||
if eb.server != nil {
|
||||
if err := eb.server.Close(); err != nil {
|
||||
logger.Error("[MQTTSpec] Error closing embedded broker: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
eb.started = false
|
||||
logger.Info("[MQTTSpec] Embedded broker stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish publishes a message to a topic
|
||||
func (eb *EmbeddedBroker) Publish(topic string, qos byte, payload []byte) error {
|
||||
if !eb.started {
|
||||
return fmt.Errorf("broker not started")
|
||||
}
|
||||
|
||||
if eb.server == nil {
|
||||
return fmt.Errorf("server not initialized")
|
||||
}
|
||||
|
||||
// Use inline client to publish
|
||||
return eb.server.Publish(topic, payload, false, qos)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to a topic
|
||||
func (eb *EmbeddedBroker) Subscribe(topicFilter string, qos byte, callback MessageCallback) error {
|
||||
if !eb.started {
|
||||
return fmt.Errorf("broker not started")
|
||||
}
|
||||
|
||||
// Store callback
|
||||
eb.subMu.Lock()
|
||||
eb.subscriptions[topicFilter] = callback
|
||||
eb.subMu.Unlock()
|
||||
|
||||
// Create inline subscription handler
|
||||
// Note: Mochi MQTT internal subscriptions are more complex
|
||||
// For now, we'll use a publishing hook to intercept messages
|
||||
// This is a simplified implementation
|
||||
|
||||
logger.Info("[MQTTSpec] Subscribed to topic filter: %s", topicFilter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes from a topic
|
||||
func (eb *EmbeddedBroker) Unsubscribe(topicFilter string) error {
|
||||
eb.subMu.Lock()
|
||||
defer eb.subMu.Unlock()
|
||||
|
||||
delete(eb.subscriptions, topicFilter)
|
||||
logger.Info("[MQTTSpec] Unsubscribed from topic filter: %s", topicFilter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected returns whether the broker is running
|
||||
func (eb *EmbeddedBroker) IsConnected() bool {
|
||||
eb.mu.RLock()
|
||||
defer eb.mu.RUnlock()
|
||||
return eb.started
|
||||
}
|
||||
|
||||
// GetClientManager returns the client manager
|
||||
func (eb *EmbeddedBroker) GetClientManager() *ClientManager {
|
||||
return eb.clientManager
|
||||
}
|
||||
|
||||
// ExternalBrokerClient wraps Paho MQTT client
|
||||
type ExternalBrokerClient struct {
|
||||
config ExternalBrokerConfig
|
||||
client pahomqtt.Client
|
||||
clientManager *ClientManager
|
||||
handler *Handler
|
||||
subscriptions map[string]MessageCallback
|
||||
subMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
connected bool
|
||||
}
|
||||
|
||||
// NewExternalBrokerClient creates a new external broker client
|
||||
func NewExternalBrokerClient(config ExternalBrokerConfig, clientManager *ClientManager) *ExternalBrokerClient {
|
||||
return &ExternalBrokerClient{
|
||||
config: config,
|
||||
clientManager: clientManager,
|
||||
subscriptions: make(map[string]MessageCallback),
|
||||
}
|
||||
}
|
||||
|
||||
// SetHandler sets the handler reference
|
||||
func (ebc *ExternalBrokerClient) SetHandler(handler *Handler) {
|
||||
ebc.mu.Lock()
|
||||
defer ebc.mu.Unlock()
|
||||
ebc.handler = handler
|
||||
}
|
||||
|
||||
// Start connects to the external MQTT broker
|
||||
func (ebc *ExternalBrokerClient) Start(ctx context.Context) error {
|
||||
ebc.mu.Lock()
|
||||
defer ebc.mu.Unlock()
|
||||
|
||||
if ebc.connected {
|
||||
return fmt.Errorf("already connected")
|
||||
}
|
||||
|
||||
ebc.ctx, ebc.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Create Paho client options
|
||||
opts := pahomqtt.NewClientOptions()
|
||||
opts.AddBroker(ebc.config.BrokerURL)
|
||||
opts.SetClientID(ebc.config.ClientID)
|
||||
opts.SetUsername(ebc.config.Username)
|
||||
opts.SetPassword(ebc.config.Password)
|
||||
opts.SetCleanSession(ebc.config.CleanSession)
|
||||
opts.SetKeepAlive(ebc.config.KeepAlive)
|
||||
opts.SetAutoReconnect(true)
|
||||
opts.SetMaxReconnectInterval(ebc.config.ReconnectDelay)
|
||||
|
||||
// Set connection lost handler
|
||||
opts.SetConnectionLostHandler(func(client pahomqtt.Client, err error) {
|
||||
logger.Error("[MQTTSpec] External broker connection lost: %v", err)
|
||||
ebc.mu.Lock()
|
||||
ebc.connected = false
|
||||
ebc.mu.Unlock()
|
||||
})
|
||||
|
||||
// Set on-connect handler
|
||||
opts.SetOnConnectHandler(func(client pahomqtt.Client) {
|
||||
logger.Info("[MQTTSpec] Connected to external broker")
|
||||
ebc.mu.Lock()
|
||||
ebc.connected = true
|
||||
ebc.mu.Unlock()
|
||||
|
||||
// Resubscribe to topics
|
||||
ebc.resubscribeAll()
|
||||
})
|
||||
|
||||
// Create and connect client
|
||||
ebc.client = pahomqtt.NewClient(opts)
|
||||
token := ebc.client.Connect()
|
||||
|
||||
if !token.WaitTimeout(ebc.config.ConnectTimeout) {
|
||||
return fmt.Errorf("connection timeout")
|
||||
}
|
||||
|
||||
if err := token.Error(); err != nil {
|
||||
return fmt.Errorf("failed to connect to external broker: %w", err)
|
||||
}
|
||||
|
||||
ebc.connected = true
|
||||
logger.Info("[MQTTSpec] Connected to external MQTT broker: %s", ebc.config.BrokerURL)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop disconnects from the external broker
|
||||
func (ebc *ExternalBrokerClient) Stop(ctx context.Context) error {
|
||||
ebc.mu.Lock()
|
||||
defer ebc.mu.Unlock()
|
||||
|
||||
if !ebc.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ebc.cancel != nil {
|
||||
ebc.cancel()
|
||||
}
|
||||
|
||||
if ebc.client != nil && ebc.client.IsConnected() {
|
||||
ebc.client.Disconnect(uint(ebc.config.ConnectTimeout.Milliseconds()))
|
||||
}
|
||||
|
||||
ebc.connected = false
|
||||
logger.Info("[MQTTSpec] Disconnected from external broker")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish publishes a message to a topic
|
||||
func (ebc *ExternalBrokerClient) Publish(topic string, qos byte, payload []byte) error {
|
||||
if !ebc.connected {
|
||||
return fmt.Errorf("not connected to broker")
|
||||
}
|
||||
|
||||
token := ebc.client.Publish(topic, qos, false, payload)
|
||||
token.Wait()
|
||||
return token.Error()
|
||||
}
|
||||
|
||||
// Subscribe subscribes to a topic
|
||||
func (ebc *ExternalBrokerClient) Subscribe(topicFilter string, qos byte, callback MessageCallback) error {
|
||||
if !ebc.connected {
|
||||
return fmt.Errorf("not connected to broker")
|
||||
}
|
||||
|
||||
// Store callback
|
||||
ebc.subMu.Lock()
|
||||
ebc.subscriptions[topicFilter] = callback
|
||||
ebc.subMu.Unlock()
|
||||
|
||||
// Subscribe via Paho client
|
||||
token := ebc.client.Subscribe(topicFilter, qos, func(client pahomqtt.Client, msg pahomqtt.Message) {
|
||||
callback(msg.Topic(), msg.Payload())
|
||||
})
|
||||
|
||||
token.Wait()
|
||||
if err := token.Error(); err != nil {
|
||||
return fmt.Errorf("failed to subscribe to %s: %w", topicFilter, err)
|
||||
}
|
||||
|
||||
logger.Info("[MQTTSpec] Subscribed to topic filter: %s", topicFilter)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes from a topic
|
||||
func (ebc *ExternalBrokerClient) Unsubscribe(topicFilter string) error {
|
||||
ebc.subMu.Lock()
|
||||
defer ebc.subMu.Unlock()
|
||||
|
||||
if ebc.client != nil && ebc.connected {
|
||||
token := ebc.client.Unsubscribe(topicFilter)
|
||||
token.Wait()
|
||||
if err := token.Error(); err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to unsubscribe from %s: %v", topicFilter, err)
|
||||
}
|
||||
}
|
||||
|
||||
delete(ebc.subscriptions, topicFilter)
|
||||
logger.Info("[MQTTSpec] Unsubscribed from topic filter: %s", topicFilter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected returns connection status
|
||||
func (ebc *ExternalBrokerClient) IsConnected() bool {
|
||||
ebc.mu.RLock()
|
||||
defer ebc.mu.RUnlock()
|
||||
return ebc.connected
|
||||
}
|
||||
|
||||
// GetClientManager returns the client manager
|
||||
func (ebc *ExternalBrokerClient) GetClientManager() *ClientManager {
|
||||
return ebc.clientManager
|
||||
}
|
||||
|
||||
// resubscribeAll resubscribes to all topics after reconnection
|
||||
func (ebc *ExternalBrokerClient) resubscribeAll() {
|
||||
ebc.subMu.RLock()
|
||||
defer ebc.subMu.RUnlock()
|
||||
|
||||
for topicFilter, callback := range ebc.subscriptions {
|
||||
logger.Info("[MQTTSpec] Resubscribing to topic: %s", topicFilter)
|
||||
token := ebc.client.Subscribe(topicFilter, 1, func(client pahomqtt.Client, msg pahomqtt.Message) {
|
||||
callback(msg.Topic(), msg.Payload())
|
||||
})
|
||||
if token.Wait() && token.Error() != nil {
|
||||
logger.Error("[MQTTSpec] Failed to resubscribe to %s: %v", topicFilter, token.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
409
pkg/mqttspec/broker_test.go
Normal file
409
pkg/mqttspec/broker_test.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewEmbeddedBroker(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 1883,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
assert.NotNil(t, broker)
|
||||
assert.Equal(t, config, broker.config)
|
||||
assert.Equal(t, cm, broker.clientManager)
|
||||
assert.NotNil(t, broker.subscriptions)
|
||||
assert.False(t, broker.started)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_StartStop(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11883, // Use non-standard port for testing
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify started
|
||||
assert.True(t, broker.IsConnected())
|
||||
|
||||
// Stop broker
|
||||
err = broker.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify stopped
|
||||
assert.False(t, broker.IsConnected())
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_StartTwice(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11884,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Try to start again - should fail
|
||||
err = broker.Start(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already started")
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_StopWithoutStart(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11885,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Stop without starting - should not error
|
||||
err := broker.Stop(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_PublishWithoutStart(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11886,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Try to publish without starting - should fail
|
||||
err := broker.Publish("test/topic", 1, []byte("test"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "broker not started")
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_SubscribeWithoutStart(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11887,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Try to subscribe without starting - should fail
|
||||
err := broker.Subscribe("test/topic", 1, func(topic string, payload []byte) {})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "broker not started")
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_PublishSubscribe(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11888,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Subscribe to topic
|
||||
callback := func(topic string, payload []byte) {
|
||||
// Callback for subscription - actual message delivery would require
|
||||
// integration with Mochi MQTT's hook system
|
||||
}
|
||||
|
||||
err = broker.Subscribe("test/topic", 1, callback)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: Embedded broker's Subscribe is simplified and doesn't fully integrate
|
||||
// with Mochi MQTT's internal pub/sub. This test verifies the subscription
|
||||
// is registered but actual message delivery would require more complex
|
||||
// integration with Mochi MQTT's hook system.
|
||||
|
||||
// Verify subscription was registered
|
||||
broker.subMu.RLock()
|
||||
_, exists := broker.subscriptions["test/topic"]
|
||||
broker.subMu.RUnlock()
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_Unsubscribe(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11889,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Subscribe
|
||||
callback := func(topic string, payload []byte) {}
|
||||
err = broker.Subscribe("test/topic", 1, callback)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subscription exists
|
||||
broker.subMu.RLock()
|
||||
_, exists := broker.subscriptions["test/topic"]
|
||||
broker.subMu.RUnlock()
|
||||
assert.True(t, exists)
|
||||
|
||||
// Unsubscribe
|
||||
err = broker.Unsubscribe("test/topic")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subscription removed
|
||||
broker.subMu.RLock()
|
||||
_, exists = broker.subscriptions["test/topic"]
|
||||
broker.subMu.RUnlock()
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_SetHandler(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11890,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Create a mock handler (nil is fine for this test)
|
||||
var handler *Handler = nil
|
||||
|
||||
// Set handler
|
||||
broker.SetHandler(handler)
|
||||
|
||||
// Verify handler was set
|
||||
broker.mu.RLock()
|
||||
assert.Equal(t, handler, broker.handler)
|
||||
broker.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_GetClientManager(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11891,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Get client manager
|
||||
retrievedCM := broker.GetClientManager()
|
||||
|
||||
// Verify it's the same instance
|
||||
assert.Equal(t, cm, retrievedCM)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_ConcurrentPublish(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11892,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Test concurrent publishing
|
||||
var wg sync.WaitGroup
|
||||
numPublishers := 10
|
||||
|
||||
for i := 0; i < numPublishers; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
err := broker.Publish("test/topic", 1, []byte("test"))
|
||||
// Errors are acceptable in concurrent scenario
|
||||
_ = err
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestNewExternalBrokerClient(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
assert.NotNil(t, broker)
|
||||
assert.Equal(t, config, broker.config)
|
||||
assert.Equal(t, cm, broker.clientManager)
|
||||
assert.NotNil(t, broker.subscriptions)
|
||||
assert.False(t, broker.connected)
|
||||
}
|
||||
|
||||
func TestExternalBrokerClient_SetHandler(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
// Create a mock handler (nil is fine for this test)
|
||||
var handler *Handler = nil
|
||||
|
||||
// Set handler
|
||||
broker.SetHandler(handler)
|
||||
|
||||
// Verify handler was set
|
||||
broker.mu.RLock()
|
||||
assert.Equal(t, handler, broker.handler)
|
||||
broker.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestExternalBrokerClient_GetClientManager(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
// Get client manager
|
||||
retrievedCM := broker.GetClientManager()
|
||||
|
||||
// Verify it's the same instance
|
||||
assert.Equal(t, cm, retrievedCM)
|
||||
}
|
||||
|
||||
func TestExternalBrokerClient_IsConnected(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
// Should not be connected initially
|
||||
assert.False(t, broker.IsConnected())
|
||||
}
|
||||
|
||||
// Note: Tests for ExternalBrokerClient Start/Stop/Publish/Subscribe require
|
||||
// a running MQTT broker and are better suited for integration tests.
|
||||
// These tests would be included in integration_test.go with proper test
|
||||
// broker setup (e.g., using Docker Compose).
|
||||
184
pkg/mqttspec/client.go
Normal file
184
pkg/mqttspec/client.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Client represents an MQTT client connection
|
||||
type Client struct {
|
||||
// ID is the MQTT client ID (unique per connection)
|
||||
ID string
|
||||
|
||||
// Username from MQTT CONNECT packet
|
||||
Username string
|
||||
|
||||
// ConnectedAt is when the client connected
|
||||
ConnectedAt time.Time
|
||||
|
||||
// subscriptions holds active subscriptions for this client
|
||||
subscriptions map[string]*Subscription
|
||||
subMu sync.RWMutex
|
||||
|
||||
// metadata stores client-specific data (user_id, roles, tenant_id, etc.)
|
||||
// Set by BeforeConnect hook for authentication/authorization
|
||||
metadata map[string]interface{}
|
||||
metaMu sync.RWMutex
|
||||
|
||||
// ctx is the client context
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// handler reference for callback access
|
||||
handler *Handler
|
||||
}
|
||||
|
||||
// ClientManager manages all MQTT client connections
|
||||
type ClientManager struct {
|
||||
// clients maps client_id to Client
|
||||
clients map[string]*Client
|
||||
mu sync.RWMutex
|
||||
|
||||
// ctx for lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewClient creates a new MQTT client
|
||||
func NewClient(id, username string, handler *Handler) *Client {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Client{
|
||||
ID: id,
|
||||
Username: username,
|
||||
ConnectedAt: time.Now(),
|
||||
subscriptions: make(map[string]*Subscription),
|
||||
metadata: make(map[string]interface{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetadata sets metadata for this client
|
||||
func (c *Client) SetMetadata(key string, value interface{}) {
|
||||
c.metaMu.Lock()
|
||||
defer c.metaMu.Unlock()
|
||||
c.metadata[key] = value
|
||||
}
|
||||
|
||||
// GetMetadata retrieves metadata for this client
|
||||
func (c *Client) GetMetadata(key string) (interface{}, bool) {
|
||||
c.metaMu.RLock()
|
||||
defer c.metaMu.RUnlock()
|
||||
val, ok := c.metadata[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// AddSubscription adds a subscription to this client
|
||||
func (c *Client) AddSubscription(sub *Subscription) {
|
||||
c.subMu.Lock()
|
||||
defer c.subMu.Unlock()
|
||||
c.subscriptions[sub.ID] = sub
|
||||
}
|
||||
|
||||
// RemoveSubscription removes a subscription from this client
|
||||
func (c *Client) RemoveSubscription(subID string) {
|
||||
c.subMu.Lock()
|
||||
defer c.subMu.Unlock()
|
||||
delete(c.subscriptions, subID)
|
||||
}
|
||||
|
||||
// GetSubscription retrieves a subscription by ID
|
||||
func (c *Client) GetSubscription(subID string) (*Subscription, bool) {
|
||||
c.subMu.RLock()
|
||||
defer c.subMu.RUnlock()
|
||||
sub, ok := c.subscriptions[subID]
|
||||
return sub, ok
|
||||
}
|
||||
|
||||
// Close cleans up the client
|
||||
func (c *Client) Close() {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// Clean up subscriptions
|
||||
c.subMu.Lock()
|
||||
for subID := range c.subscriptions {
|
||||
if c.handler != nil && c.handler.subscriptionManager != nil {
|
||||
c.handler.subscriptionManager.Unsubscribe(subID)
|
||||
}
|
||||
}
|
||||
c.subscriptions = make(map[string]*Subscription)
|
||||
c.subMu.Unlock()
|
||||
}
|
||||
|
||||
// NewClientManager creates a new client manager
|
||||
func NewClientManager(ctx context.Context) *ClientManager {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &ClientManager{
|
||||
clients: make(map[string]*Client),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new MQTT client
|
||||
func (cm *ClientManager) Register(clientID, username string, handler *Handler) *Client {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
|
||||
client := NewClient(clientID, username, handler)
|
||||
cm.clients[clientID] = client
|
||||
|
||||
count := len(cm.clients)
|
||||
logger.Info("[MQTTSpec] Client registered: %s (username: %s, total: %d)", clientID, username, count)
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// Unregister removes a client
|
||||
func (cm *ClientManager) Unregister(clientID string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
|
||||
if client, ok := cm.clients[clientID]; ok {
|
||||
client.Close()
|
||||
delete(cm.clients, clientID)
|
||||
count := len(cm.clients)
|
||||
logger.Info("[MQTTSpec] Client unregistered: %s (total: %d)", clientID, count)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClient retrieves a client by ID
|
||||
func (cm *ClientManager) GetClient(clientID string) (*Client, bool) {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
client, ok := cm.clients[clientID]
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// Count returns the number of active clients
|
||||
func (cm *ClientManager) Count() int {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return len(cm.clients)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the client manager
|
||||
func (cm *ClientManager) Shutdown() {
|
||||
cm.cancel()
|
||||
|
||||
// Close all clients
|
||||
cm.mu.Lock()
|
||||
for _, client := range cm.clients {
|
||||
client.Close()
|
||||
}
|
||||
cm.clients = make(map[string]*Client)
|
||||
cm.mu.Unlock()
|
||||
|
||||
logger.Info("[MQTTSpec] Client manager shut down")
|
||||
}
|
||||
256
pkg/mqttspec/client_test.go
Normal file
256
pkg/mqttspec/client_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient("client-123", "user@example.com", nil)
|
||||
|
||||
assert.Equal(t, "client-123", client.ID)
|
||||
assert.Equal(t, "user@example.com", client.Username)
|
||||
assert.NotNil(t, client.subscriptions)
|
||||
assert.NotNil(t, client.metadata)
|
||||
assert.NotNil(t, client.ctx)
|
||||
assert.NotNil(t, client.cancel)
|
||||
}
|
||||
|
||||
func TestClient_Metadata(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
// Set metadata
|
||||
client.SetMetadata("user_id", 456)
|
||||
client.SetMetadata("tenant_id", "tenant-abc")
|
||||
client.SetMetadata("roles", []string{"admin", "user"})
|
||||
|
||||
// Get metadata
|
||||
userID, exists := client.GetMetadata("user_id")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 456, userID)
|
||||
|
||||
tenantID, exists := client.GetMetadata("tenant_id")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "tenant-abc", tenantID)
|
||||
|
||||
roles, exists := client.GetMetadata("roles")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []string{"admin", "user"}, roles)
|
||||
|
||||
// Non-existent key
|
||||
_, exists = client.GetMetadata("nonexistent")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestClient_Subscriptions(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
// Create mock subscription
|
||||
sub := &Subscription{
|
||||
ID: "sub-1",
|
||||
ConnectionID: "client-123",
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Active: true,
|
||||
}
|
||||
|
||||
// Add subscription
|
||||
client.AddSubscription(sub)
|
||||
|
||||
// Get subscription
|
||||
retrieved, exists := client.GetSubscription("sub-1")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "sub-1", retrieved.ID)
|
||||
|
||||
// Remove subscription
|
||||
client.RemoveSubscription("sub-1")
|
||||
|
||||
// Verify removed
|
||||
_, exists = client.GetSubscription("sub-1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestClient_Close(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
// Add some subscriptions
|
||||
client.AddSubscription(&Subscription{ID: "sub-1"})
|
||||
client.AddSubscription(&Subscription{ID: "sub-2"})
|
||||
|
||||
// Close client
|
||||
client.Close()
|
||||
|
||||
// Verify subscriptions cleared
|
||||
client.subMu.RLock()
|
||||
assert.Empty(t, client.subscriptions)
|
||||
client.subMu.RUnlock()
|
||||
|
||||
// Verify context cancelled
|
||||
select {
|
||||
case <-client.ctx.Done():
|
||||
// Context was cancelled
|
||||
default:
|
||||
t.Fatal("Context should be cancelled after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientManager(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
|
||||
assert.NotNil(t, cm)
|
||||
assert.NotNil(t, cm.clients)
|
||||
assert.Equal(t, 0, cm.Count())
|
||||
}
|
||||
|
||||
func TestClientManager_Register(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
client := cm.Register("client-1", "user@example.com", nil)
|
||||
|
||||
assert.NotNil(t, client)
|
||||
assert.Equal(t, "client-1", client.ID)
|
||||
assert.Equal(t, "user@example.com", client.Username)
|
||||
assert.Equal(t, 1, cm.Count())
|
||||
}
|
||||
|
||||
func TestClientManager_Unregister(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
assert.Equal(t, 1, cm.Count())
|
||||
|
||||
cm.Unregister("client-1")
|
||||
assert.Equal(t, 0, cm.Count())
|
||||
}
|
||||
|
||||
func TestClientManager_GetClient(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
|
||||
// Get existing client
|
||||
client, exists := cm.GetClient("client-1")
|
||||
assert.True(t, exists)
|
||||
assert.NotNil(t, client)
|
||||
assert.Equal(t, "client-1", client.ID)
|
||||
|
||||
// Get non-existent client
|
||||
_, exists = cm.GetClient("nonexistent")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestClientManager_MultipleClients(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
cm.Register("client-2", "user2", nil)
|
||||
cm.Register("client-3", "user3", nil)
|
||||
|
||||
assert.Equal(t, 3, cm.Count())
|
||||
|
||||
cm.Unregister("client-2")
|
||||
assert.Equal(t, 2, cm.Count())
|
||||
|
||||
// Verify correct client was removed
|
||||
_, exists := cm.GetClient("client-2")
|
||||
assert.False(t, exists)
|
||||
|
||||
_, exists = cm.GetClient("client-1")
|
||||
assert.True(t, exists)
|
||||
|
||||
_, exists = cm.GetClient("client-3")
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestClientManager_Shutdown(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
cm.Register("client-2", "user2", nil)
|
||||
assert.Equal(t, 2, cm.Count())
|
||||
|
||||
cm.Shutdown()
|
||||
|
||||
// All clients should be removed
|
||||
assert.Equal(t, 0, cm.Count())
|
||||
|
||||
// Context should be cancelled
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
// Context was cancelled
|
||||
default:
|
||||
t.Fatal("Context should be cancelled after Shutdown()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientManager_ConcurrentOperations(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
// This test verifies that concurrent operations don't cause race conditions
|
||||
// Run with: go test -race
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Goroutine 1: Register clients
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
cm.Register("client-"+string(rune(i)), "user", nil)
|
||||
}
|
||||
}()
|
||||
|
||||
// Goroutine 2: Get clients
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
cm.GetClient("client-" + string(rune(i)))
|
||||
}
|
||||
}()
|
||||
|
||||
// Goroutine 3: Count
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
cm.Count()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestClient_ConcurrentMetadata(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent writes
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
client.SetMetadata("key1", i)
|
||||
}
|
||||
}()
|
||||
|
||||
// Concurrent reads
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
client.GetMetadata("key1")
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
178
pkg/mqttspec/config.go
Normal file
178
pkg/mqttspec/config.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BrokerMode specifies how to connect to MQTT
|
||||
type BrokerMode string
|
||||
|
||||
const (
|
||||
// BrokerModeEmbedded runs Mochi MQTT broker in-process
|
||||
BrokerModeEmbedded BrokerMode = "embedded"
|
||||
// BrokerModeExternal connects to external MQTT broker as client
|
||||
BrokerModeExternal BrokerMode = "external"
|
||||
)
|
||||
|
||||
// Config holds all mqttspec configuration
|
||||
type Config struct {
|
||||
// BrokerMode determines whether to use embedded or external broker
|
||||
BrokerMode BrokerMode
|
||||
|
||||
// Broker configuration for embedded mode
|
||||
Broker BrokerConfig
|
||||
|
||||
// ExternalBroker configuration for external client mode
|
||||
ExternalBroker ExternalBrokerConfig
|
||||
|
||||
// Topics configuration
|
||||
Topics TopicConfig
|
||||
|
||||
// QoS configuration for different message types
|
||||
QoS QoSConfig
|
||||
|
||||
// Auth configuration
|
||||
Auth AuthConfig
|
||||
|
||||
// Timeouts for various operations
|
||||
Timeouts TimeoutConfig
|
||||
}
|
||||
|
||||
// BrokerConfig configures the embedded Mochi MQTT broker
|
||||
type BrokerConfig struct {
|
||||
// Host to bind to (default: "localhost")
|
||||
Host string
|
||||
|
||||
// Port to listen on (default: 1883)
|
||||
Port int
|
||||
|
||||
// EnableWebSocket enables WebSocket support
|
||||
EnableWebSocket bool
|
||||
|
||||
// WSPort is the WebSocket port (default: 8883)
|
||||
WSPort int
|
||||
|
||||
// MaxConnections limits concurrent client connections
|
||||
MaxConnections int
|
||||
|
||||
// KeepAlive is the client keepalive interval
|
||||
KeepAlive time.Duration
|
||||
|
||||
// EnableAuth enables username/password authentication
|
||||
EnableAuth bool
|
||||
}
|
||||
|
||||
// ExternalBrokerConfig for connecting as a client to external broker
|
||||
type ExternalBrokerConfig struct {
|
||||
// BrokerURL is the broker address (e.g., tcp://host:port or ssl://host:port)
|
||||
BrokerURL string
|
||||
|
||||
// ClientID is a unique identifier for this handler instance
|
||||
ClientID string
|
||||
|
||||
// Username for MQTT authentication
|
||||
Username string
|
||||
|
||||
// Password for MQTT authentication
|
||||
Password string
|
||||
|
||||
// CleanSession flag (default: true)
|
||||
CleanSession bool
|
||||
|
||||
// KeepAlive interval (default: 60s)
|
||||
KeepAlive time.Duration
|
||||
|
||||
// ConnectTimeout for initial connection (default: 30s)
|
||||
ConnectTimeout time.Duration
|
||||
|
||||
// ReconnectDelay between reconnection attempts (default: 5s)
|
||||
ReconnectDelay time.Duration
|
||||
|
||||
// MaxReconnect attempts (0 = unlimited, default: 0)
|
||||
MaxReconnect int
|
||||
|
||||
// TLSConfig for SSL/TLS connections
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// TopicConfig defines the MQTT topic structure
|
||||
type TopicConfig struct {
|
||||
// Prefix for all topics (default: "spec")
|
||||
// Topics will be: {Prefix}/{client_id}/request|response|notify/{sub_id}
|
||||
Prefix string
|
||||
}
|
||||
|
||||
// QoSConfig defines quality of service levels for different message types
|
||||
type QoSConfig struct {
|
||||
// Request messages QoS (default: 1 - at-least-once)
|
||||
Request byte
|
||||
|
||||
// Response messages QoS (default: 1 - at-least-once)
|
||||
Response byte
|
||||
|
||||
// Notification messages QoS (default: 1 - at-least-once)
|
||||
Notification byte
|
||||
}
|
||||
|
||||
// AuthConfig for MQTT-level authentication
|
||||
type AuthConfig struct {
|
||||
// ValidateCredentials is called to validate username/password for embedded broker
|
||||
// Return true if credentials are valid, false otherwise
|
||||
ValidateCredentials func(username, password string) bool
|
||||
}
|
||||
|
||||
// TimeoutConfig defines timeouts for various operations
|
||||
type TimeoutConfig struct {
|
||||
// Connect timeout for MQTT connection (default: 30s)
|
||||
Connect time.Duration
|
||||
|
||||
// Publish timeout for publishing messages (default: 5s)
|
||||
Publish time.Duration
|
||||
|
||||
// Disconnect timeout for graceful shutdown (default: 10s)
|
||||
Disconnect time.Duration
|
||||
}
|
||||
|
||||
// DefaultConfig returns a configuration with sensible defaults
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
BrokerMode: BrokerModeEmbedded,
|
||||
Broker: BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 1883,
|
||||
EnableWebSocket: false,
|
||||
WSPort: 8883,
|
||||
MaxConnections: 1000,
|
||||
KeepAlive: 60 * time.Second,
|
||||
EnableAuth: false,
|
||||
},
|
||||
ExternalBroker: ExternalBrokerConfig{
|
||||
BrokerURL: "",
|
||||
ClientID: "",
|
||||
Username: "",
|
||||
Password: "",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 30 * time.Second,
|
||||
ReconnectDelay: 5 * time.Second,
|
||||
MaxReconnect: 0, // Unlimited
|
||||
},
|
||||
Topics: TopicConfig{
|
||||
Prefix: "spec",
|
||||
},
|
||||
QoS: QoSConfig{
|
||||
Request: 1, // At-least-once
|
||||
Response: 1, // At-least-once
|
||||
Notification: 1, // At-least-once
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
ValidateCredentials: nil,
|
||||
},
|
||||
Timeouts: TimeoutConfig{
|
||||
Connect: 30 * time.Second,
|
||||
Publish: 5 * time.Second,
|
||||
Disconnect: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
846
pkg/mqttspec/handler.go
Normal file
846
pkg/mqttspec/handler.go
Normal file
@@ -0,0 +1,846 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Handler handles MQTT messages and operations
|
||||
type Handler struct {
|
||||
// Database adapter (GORM/Bun)
|
||||
db common.Database
|
||||
|
||||
// Model registry
|
||||
registry common.ModelRegistry
|
||||
|
||||
// Hook registry
|
||||
hooks *HookRegistry
|
||||
|
||||
// Client manager
|
||||
clientManager *ClientManager
|
||||
|
||||
// Subscription manager
|
||||
subscriptionManager *SubscriptionManager
|
||||
|
||||
// Broker interface (embedded or external)
|
||||
broker BrokerInterface
|
||||
|
||||
// Configuration
|
||||
config *Config
|
||||
|
||||
// Context for lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Started flag
|
||||
started bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHandler creates a new MQTT handler
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry, config *Config) (*Handler, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
clientManager: NewClientManager(ctx),
|
||||
subscriptionManager: NewSubscriptionManager(),
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
started: false,
|
||||
}
|
||||
|
||||
// Initialize broker based on mode
|
||||
if config.BrokerMode == BrokerModeEmbedded {
|
||||
h.broker = NewEmbeddedBroker(config.Broker, h.clientManager)
|
||||
} else {
|
||||
h.broker = NewExternalBrokerClient(config.ExternalBroker, h.clientManager)
|
||||
}
|
||||
|
||||
// Set handler reference in broker
|
||||
h.broker.SetHandler(h)
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Start initializes and starts the handler
|
||||
func (h *Handler) Start() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.started {
|
||||
return fmt.Errorf("handler already started")
|
||||
}
|
||||
|
||||
// Start broker
|
||||
if err := h.broker.Start(h.ctx); err != nil {
|
||||
return fmt.Errorf("failed to start broker: %w", err)
|
||||
}
|
||||
|
||||
// Subscribe to all request topics: spec/+/request
|
||||
requestTopic := fmt.Sprintf("%s/+/request", h.config.Topics.Prefix)
|
||||
if err := h.broker.Subscribe(requestTopic, h.config.QoS.Request, h.handleIncomingMessage); err != nil {
|
||||
_ = h.broker.Stop(h.ctx)
|
||||
return fmt.Errorf("failed to subscribe to request topic: %w", err)
|
||||
}
|
||||
|
||||
h.started = true
|
||||
logger.Info("[MQTTSpec] Handler started, listening on topic: %s", requestTopic)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the handler
|
||||
func (h *Handler) Shutdown() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if !h.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("[MQTTSpec] Shutting down handler...")
|
||||
|
||||
// Execute disconnect hooks for all clients
|
||||
h.clientManager.mu.RLock()
|
||||
clients := make([]*Client, 0, len(h.clientManager.clients))
|
||||
for _, client := range h.clientManager.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
h.clientManager.mu.RUnlock()
|
||||
|
||||
for _, client := range clients {
|
||||
hookCtx := &HookContext{
|
||||
Context: h.ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
_ = h.hooks.Execute(BeforeDisconnect, hookCtx)
|
||||
h.clientManager.Unregister(client.ID)
|
||||
_ = h.hooks.Execute(AfterDisconnect, hookCtx)
|
||||
}
|
||||
|
||||
// Unsubscribe from request topic
|
||||
requestTopic := fmt.Sprintf("%s/+/request", h.config.Topics.Prefix)
|
||||
_ = h.broker.Unsubscribe(requestTopic)
|
||||
|
||||
// Stop broker
|
||||
if err := h.broker.Stop(h.ctx); err != nil {
|
||||
logger.Error("[MQTTSpec] Error stopping broker: %v", err)
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
if h.cancel != nil {
|
||||
h.cancel()
|
||||
}
|
||||
|
||||
h.started = false
|
||||
logger.Info("[MQTTSpec] Handler stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
return h.hooks
|
||||
}
|
||||
|
||||
// Registry returns the model registry
|
||||
func (h *Handler) Registry() common.ModelRegistry {
|
||||
return h.registry
|
||||
}
|
||||
|
||||
// GetDatabase returns the database adapter
|
||||
func (h *Handler) GetDatabase() common.Database {
|
||||
return h.db
|
||||
}
|
||||
|
||||
// GetRelationshipInfo is a placeholder for relationship detection
|
||||
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
||||
// TODO: Implement full relationship detection if needed
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleIncomingMessage is called when a message arrives on spec/+/request
|
||||
func (h *Handler) handleIncomingMessage(topic string, payload []byte) {
|
||||
// Extract client_id from topic: spec/{client_id}/request
|
||||
parts := strings.Split(topic, "/")
|
||||
if len(parts) < 3 {
|
||||
logger.Error("[MQTTSpec] Invalid topic format: %s", topic)
|
||||
return
|
||||
}
|
||||
clientID := parts[len(parts)-2] // Second to last part is client_id
|
||||
|
||||
// Parse message
|
||||
msg, err := ParseMessage(payload)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to parse message from %s: %v", clientID, err)
|
||||
h.sendError(clientID, "", "invalid_message", "Failed to parse message")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate message
|
||||
if !msg.IsValid() {
|
||||
logger.Error("[MQTTSpec] Invalid message from %s", clientID)
|
||||
h.sendError(clientID, msg.ID, "invalid_message", "Message validation failed")
|
||||
return
|
||||
}
|
||||
|
||||
// Get or register client
|
||||
client, exists := h.clientManager.GetClient(clientID)
|
||||
if !exists {
|
||||
// First request from this client - register it
|
||||
client = h.clientManager.Register(clientID, "", h)
|
||||
|
||||
// Execute connect hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: h.ctx,
|
||||
Handler: nil, // Not used for MQTT, handler ref stored in metadata if needed
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeConnect, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeConnect hook failed for %s: %v", clientID, err)
|
||||
h.sendError(clientID, msg.ID, "auth_error", err.Error())
|
||||
h.clientManager.Unregister(clientID)
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.hooks.Execute(AfterConnect, hookCtx)
|
||||
}
|
||||
|
||||
// Route message by type
|
||||
switch msg.Type {
|
||||
case MessageTypeRequest:
|
||||
h.handleRequest(client, msg)
|
||||
case MessageTypeSubscription:
|
||||
h.handleSubscription(client, msg)
|
||||
case MessageTypePing:
|
||||
h.handlePing(client, msg)
|
||||
default:
|
||||
h.sendError(clientID, msg.ID, "invalid_message_type", fmt.Sprintf("Unknown message type: %s", msg.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest processes CRUD requests
|
||||
func (h *Handler) handleRequest(client *Client, msg *Message) {
|
||||
ctx := client.ctx
|
||||
schema := msg.Schema
|
||||
entity := msg.Entity
|
||||
recordID := msg.RecordID
|
||||
|
||||
// Get model from registry
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Model not found for %s.%s: %v", schema, entity, err)
|
||||
h.sendError(client.ID, msg.ID, "model_not_found", fmt.Sprintf("Model not found: %s.%s", schema, entity))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and unwrap model
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Model validation failed for %s.%s: %v", schema, entity, err)
|
||||
h.sendError(client.ID, msg.ID, "invalid_model", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
modelPtr := result.ModelPtr
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Message: msg,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
ModelPtr: modelPtr,
|
||||
Options: msg.Options,
|
||||
ID: recordID,
|
||||
Data: msg.Data,
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
// Route to operation handler
|
||||
switch msg.Operation {
|
||||
case OperationRead:
|
||||
h.handleRead(client, msg, hookCtx)
|
||||
case OperationCreate:
|
||||
h.handleCreate(client, msg, hookCtx)
|
||||
case OperationUpdate:
|
||||
h.handleUpdate(client, msg, hookCtx)
|
||||
case OperationDelete:
|
||||
h.handleDelete(client, msg, hookCtx)
|
||||
case OperationMeta:
|
||||
h.handleMeta(client, msg, hookCtx)
|
||||
default:
|
||||
h.sendError(client.ID, msg.ID, "invalid_operation", fmt.Sprintf("Unknown operation: %s", msg.Operation))
|
||||
}
|
||||
}
|
||||
|
||||
// handleRead processes a read operation
|
||||
func (h *Handler) handleRead(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeRead hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform read operation
|
||||
var data interface{}
|
||||
var metadata map[string]interface{}
|
||||
var err error
|
||||
|
||||
if hookCtx.ID != "" {
|
||||
// Read single record by ID
|
||||
data, err = h.readByID(hookCtx)
|
||||
metadata = map[string]interface{}{"total": 1}
|
||||
} else {
|
||||
// Read multiple records
|
||||
data, metadata, err = h.readMultiple(hookCtx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Read operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "read_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update hook context
|
||||
hookCtx.Result = data
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterRead hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, hookCtx.Result, metadata)
|
||||
}
|
||||
|
||||
// handleCreate processes a create operation
|
||||
func (h *Handler) handleCreate(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeCreate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform create operation
|
||||
data, err := h.create(hookCtx)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Create operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "create_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update hook context
|
||||
hookCtx.Result = data
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterCreate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, hookCtx.Result, nil)
|
||||
|
||||
// Notify subscribers
|
||||
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationCreate, data)
|
||||
}
|
||||
|
||||
// handleUpdate processes an update operation
|
||||
func (h *Handler) handleUpdate(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeUpdate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform update operation
|
||||
data, err := h.update(hookCtx)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Update operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "update_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update hook context
|
||||
hookCtx.Result = data
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterUpdate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, hookCtx.Result, nil)
|
||||
|
||||
// Notify subscribers
|
||||
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationUpdate, data)
|
||||
}
|
||||
|
||||
// handleDelete processes a delete operation
|
||||
func (h *Handler) handleDelete(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeDelete hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform delete operation
|
||||
if err := h.delete(hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] Delete operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "delete_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterDelete hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, map[string]interface{}{"deleted": true}, nil)
|
||||
|
||||
// Notify subscribers
|
||||
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationDelete, map[string]interface{}{
|
||||
"id": hookCtx.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// handleMeta processes a metadata request
|
||||
func (h *Handler) handleMeta(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
metadata, err := h.getMetadata(hookCtx)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Meta operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "meta_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(client.ID, msg.ID, metadata, nil)
|
||||
}
|
||||
|
||||
// handleSubscription manages subscriptions
|
||||
func (h *Handler) handleSubscription(client *Client, msg *Message) {
|
||||
switch msg.Operation {
|
||||
case OperationSubscribe:
|
||||
h.handleSubscribe(client, msg)
|
||||
case OperationUnsubscribe:
|
||||
h.handleUnsubscribe(client, msg)
|
||||
default:
|
||||
h.sendError(client.ID, msg.ID, "invalid_subscription_operation", fmt.Sprintf("Unknown subscription operation: %s", msg.Operation))
|
||||
}
|
||||
}
|
||||
|
||||
// handleSubscribe creates a subscription
|
||||
func (h *Handler) handleSubscribe(client *Client, msg *Message) {
|
||||
// Generate subscription ID
|
||||
subID := uuid.New().String()
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: client.ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Message: msg,
|
||||
Schema: msg.Schema,
|
||||
Entity: msg.Entity,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeSubscribe, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeSubscribe hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Create subscription
|
||||
sub := h.subscriptionManager.Subscribe(subID, client.ID, msg.Schema, msg.Entity, msg.Options)
|
||||
client.AddSubscription(sub)
|
||||
|
||||
// Execute after hook
|
||||
_ = h.hooks.Execute(AfterSubscribe, hookCtx)
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, map[string]interface{}{
|
||||
"subscription_id": subID,
|
||||
"schema": msg.Schema,
|
||||
"entity": msg.Entity,
|
||||
"notify_topic": h.getNotifyTopic(client.ID, subID),
|
||||
}, nil)
|
||||
|
||||
logger.Info("[MQTTSpec] Subscription created: %s for %s.%s (client: %s)", subID, msg.Schema, msg.Entity, client.ID)
|
||||
}
|
||||
|
||||
// handleUnsubscribe removes a subscription
|
||||
func (h *Handler) handleUnsubscribe(client *Client, msg *Message) {
|
||||
subID := msg.SubscriptionID
|
||||
if subID == "" {
|
||||
h.sendError(client.ID, msg.ID, "invalid_subscription", "Subscription ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: client.ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Message: msg,
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeUnsubscribe, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeUnsubscribe hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Remove subscription
|
||||
h.subscriptionManager.Unsubscribe(subID)
|
||||
client.RemoveSubscription(subID)
|
||||
|
||||
// Execute after hook
|
||||
_ = h.hooks.Execute(AfterUnsubscribe, hookCtx)
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, map[string]interface{}{
|
||||
"unsubscribed": true,
|
||||
"subscription_id": subID,
|
||||
}, nil)
|
||||
|
||||
logger.Info("[MQTTSpec] Subscription removed: %s (client: %s)", subID, client.ID)
|
||||
}
|
||||
|
||||
// handlePing responds to ping messages
|
||||
func (h *Handler) handlePing(client *Client, msg *Message) {
|
||||
pong := &ResponseMessage{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypePong,
|
||||
Success: true,
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(pong)
|
||||
topic := h.getResponseTopic(client.ID)
|
||||
_ = h.broker.Publish(topic, h.config.QoS.Response, payload)
|
||||
}
|
||||
|
||||
// notifySubscribers sends notifications to subscribers
|
||||
func (h *Handler) notifySubscribers(schema, entity string, operation OperationType, data interface{}) {
|
||||
subscriptions := h.subscriptionManager.GetSubscriptionsByEntity(schema, entity)
|
||||
if len(subscriptions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range subscriptions {
|
||||
// Check if data matches subscription filters
|
||||
if !sub.MatchesFilters(data) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get client
|
||||
client, exists := h.clientManager.GetClient(sub.ConnectionID)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create notification message
|
||||
notification := NewNotificationMessage(sub.ID, operation, schema, entity, data)
|
||||
payload, err := json.Marshal(notification)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to marshal notification: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Publish to client's notify topic
|
||||
topic := h.getNotifyTopic(client.ID, sub.ID)
|
||||
if err := h.broker.Publish(topic, h.config.QoS.Notification, payload); err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to publish notification to %s: %v", topic, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Response helpers
|
||||
|
||||
// sendResponse publishes a response message
|
||||
func (h *Handler) sendResponse(clientID, msgID string, data interface{}, metadata map[string]interface{}) {
|
||||
resp := NewResponseMessage(msgID, true, data)
|
||||
resp.Metadata = metadata
|
||||
|
||||
payload, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to marshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
topic := h.getResponseTopic(clientID)
|
||||
if err := h.broker.Publish(topic, h.config.QoS.Response, payload); err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to publish response to %s: %v", topic, err)
|
||||
}
|
||||
}
|
||||
|
||||
// sendError publishes an error response
|
||||
func (h *Handler) sendError(clientID, msgID, code, message string) {
|
||||
errResp := NewErrorResponse(msgID, code, message)
|
||||
|
||||
payload, _ := json.Marshal(errResp)
|
||||
topic := h.getResponseTopic(clientID)
|
||||
_ = h.broker.Publish(topic, h.config.QoS.Response, payload)
|
||||
}
|
||||
|
||||
// Topic helpers
|
||||
|
||||
func (h *Handler) getRequestTopic(clientID string) string {
|
||||
return fmt.Sprintf("%s/%s/request", h.config.Topics.Prefix, clientID)
|
||||
}
|
||||
|
||||
func (h *Handler) getResponseTopic(clientID string) string {
|
||||
return fmt.Sprintf("%s/%s/response", h.config.Topics.Prefix, clientID)
|
||||
}
|
||||
|
||||
func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
|
||||
return fmt.Sprintf("%s/%s/notify/%s", h.config.Topics.Prefix, clientID, subscriptionID)
|
||||
}
|
||||
|
||||
// Database operation helpers (adapted from websocketspec)
|
||||
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
// Use entity as table name
|
||||
tableName := entity
|
||||
|
||||
if schema != "" {
|
||||
tableName = schema + "." + tableName
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
// readByID reads a single record by ID
|
||||
func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
|
||||
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
// Apply columns
|
||||
if hookCtx.Options != nil && len(hookCtx.Options.Columns) > 0 {
|
||||
query = query.Column(hookCtx.Options.Columns...)
|
||||
}
|
||||
|
||||
// Apply preloads (simplified)
|
||||
if hookCtx.Options != nil {
|
||||
for i := range hookCtx.Options.Preload {
|
||||
query = query.PreloadRelation(hookCtx.Options.Preload[i].Relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
if err := query.ScanModel(hookCtx.Context); err != nil {
|
||||
return nil, fmt.Errorf("failed to read record: %w", err)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
// readMultiple reads multiple records
|
||||
func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata map[string]interface{}, err error) {
|
||||
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Apply options
|
||||
if hookCtx.Options != nil {
|
||||
// Apply filters
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
for _, sort := range hookCtx.Options.Sort {
|
||||
direction := "ASC"
|
||||
if sort.Direction == "desc" {
|
||||
direction = "DESC"
|
||||
}
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
|
||||
// Apply limit and offset
|
||||
if hookCtx.Options.Limit != nil {
|
||||
query = query.Limit(*hookCtx.Options.Limit)
|
||||
}
|
||||
if hookCtx.Options.Offset != nil {
|
||||
query = query.Offset(*hookCtx.Options.Offset)
|
||||
}
|
||||
|
||||
// Apply preloads
|
||||
for i := range hookCtx.Options.Preload {
|
||||
query = query.PreloadRelation(hookCtx.Options.Preload[i].Relation)
|
||||
}
|
||||
|
||||
// Apply columns
|
||||
if len(hookCtx.Options.Columns) > 0 {
|
||||
query = query.Column(hookCtx.Options.Columns...)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
if err := query.ScanModel(hookCtx.Context); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to read records: %w", err)
|
||||
}
|
||||
|
||||
// Get count
|
||||
metadata = make(map[string]interface{})
|
||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
if hookCtx.Options != nil {
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
}
|
||||
count, _ := countQuery.Count(hookCtx.Context)
|
||||
metadata["total"] = count
|
||||
metadata["count"] = reflection.Len(hookCtx.ModelPtr)
|
||||
|
||||
return hookCtx.ModelPtr, metadata, nil
|
||||
}
|
||||
|
||||
// create creates a new record
|
||||
func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
||||
// Marshal and unmarshal data into model
|
||||
dataBytes, err := json.Marshal(hookCtx.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
|
||||
}
|
||||
|
||||
// Insert record
|
||||
query := h.db.NewInsert().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
if _, err := query.Exec(hookCtx.Context); err != nil {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
// update updates an existing record
|
||||
func (h *Handler) update(hookCtx *HookContext) (interface{}, error) {
|
||||
// Marshal and unmarshal data into model
|
||||
dataBytes, err := json.Marshal(hookCtx.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
|
||||
}
|
||||
|
||||
// Update record
|
||||
query := h.db.NewUpdate().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
if _, err := query.Exec(hookCtx.Context); err != nil {
|
||||
return nil, fmt.Errorf("failed to update record: %w", err)
|
||||
}
|
||||
|
||||
// Fetch updated record
|
||||
return h.readByID(hookCtx)
|
||||
}
|
||||
|
||||
// delete deletes a record
|
||||
func (h *Handler) delete(hookCtx *HookContext) error {
|
||||
query := h.db.NewDelete().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
if _, err := query.Exec(hookCtx.Context); err != nil {
|
||||
return fmt.Errorf("failed to delete record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMetadata returns schema metadata for an entity
|
||||
func (h *Handler) getMetadata(hookCtx *HookContext) (interface{}, error) {
|
||||
metadata := make(map[string]interface{})
|
||||
metadata["schema"] = hookCtx.Schema
|
||||
metadata["entity"] = hookCtx.Entity
|
||||
metadata["table_name"] = hookCtx.TableName
|
||||
|
||||
// Get fields from model using reflection
|
||||
columns := reflection.GetModelColumns(hookCtx.Model)
|
||||
metadata["columns"] = columns
|
||||
metadata["primary_key"] = reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// getOperatorSQL converts filter operator to SQL operator
|
||||
func (h *Handler) getOperatorSQL(operator string) string {
|
||||
switch operator {
|
||||
case "eq":
|
||||
return "="
|
||||
case "neq":
|
||||
return "!="
|
||||
case "gt":
|
||||
return ">"
|
||||
case "gte":
|
||||
return ">="
|
||||
case "lt":
|
||||
return "<"
|
||||
case "lte":
|
||||
return "<="
|
||||
case "like":
|
||||
return "LIKE"
|
||||
case "ilike":
|
||||
return "ILIKE"
|
||||
case "in":
|
||||
return "IN"
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
743
pkg/mqttspec/handler_test.go
Normal file
743
pkg/mqttspec/handler_test.go
Normal file
@@ -0,0 +1,743 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Test model
|
||||
type TestUser struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (TestUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// setupTestHandler creates a handler with in-memory SQLite database
|
||||
func setupTestHandler(t *testing.T) (*Handler, *gorm.DB) {
|
||||
// Create in-memory SQLite database
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Auto-migrate test model
|
||||
err = db.AutoMigrate(&TestUser{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create handler
|
||||
config := DefaultConfig()
|
||||
config.Broker.Port = 21883 // Use different port for handler tests
|
||||
|
||||
adapter := database.NewGormAdapter(db)
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
registry.RegisterModel("public.users", &TestUser{})
|
||||
|
||||
handler, err := NewHandlerWithDatabase(adapter, registry, WithEmbeddedBroker(config.Broker))
|
||||
require.NoError(t, err)
|
||||
|
||||
return handler, db
|
||||
}
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
assert.NotNil(t, handler.db)
|
||||
assert.NotNil(t, handler.registry)
|
||||
assert.NotNil(t, handler.hooks)
|
||||
assert.NotNil(t, handler.clientManager)
|
||||
assert.NotNil(t, handler.subscriptionManager)
|
||||
assert.NotNil(t, handler.broker)
|
||||
assert.NotNil(t, handler.config)
|
||||
}
|
||||
|
||||
func TestHandler_StartShutdown(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, handler.started)
|
||||
|
||||
// Shutdown handler
|
||||
err = handler.Shutdown()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, handler.started)
|
||||
}
|
||||
|
||||
func TestHandler_HandleRead_Single(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
user := &TestUser{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
Status: "active",
|
||||
}
|
||||
db.Create(user)
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create read request message
|
||||
msg := &Message{
|
||||
ID: "msg-1",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
ID: "1",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle read
|
||||
handler.handleRead(client, msg, hookCtx)
|
||||
|
||||
// Note: In a full integration test, we would verify the response was published
|
||||
// to the correct MQTT topic. Here we're just testing that the handler doesn't error.
|
||||
}
|
||||
|
||||
func TestHandler_HandleRead_Multiple(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
users := []TestUser{
|
||||
{ID: 1, Name: "User 1", Email: "user1@example.com", Status: "active"},
|
||||
{ID: 2, Name: "User 2", Email: "user2@example.com", Status: "active"},
|
||||
{ID: 3, Name: "User 3", Email: "user3@example.com", Status: "inactive"},
|
||||
}
|
||||
for _, user := range users {
|
||||
db.Create(&user)
|
||||
}
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create read request with filter
|
||||
msg := &Message{
|
||||
ID: "msg-2",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle read
|
||||
handler.handleRead(client, msg, hookCtx)
|
||||
}
|
||||
|
||||
func TestHandler_HandleCreate(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler to initialize broker
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create request data
|
||||
newUser := map[string]interface{}{
|
||||
"name": "New User",
|
||||
"email": "new@example.com",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
// Create create request message
|
||||
msg := &Message{
|
||||
ID: "msg-3",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationCreate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle create
|
||||
handler.handleCreate(client, msg, hookCtx)
|
||||
|
||||
// Verify user was created in database
|
||||
var user TestUser
|
||||
result := db.Where("email = ?", "new@example.com").First(&user)
|
||||
assert.NoError(t, result.Error)
|
||||
assert.Equal(t, "New User", user.Name)
|
||||
}
|
||||
|
||||
func TestHandler_HandleUpdate(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
user := &TestUser{
|
||||
ID: 1,
|
||||
Name: "Original Name",
|
||||
Email: "original@example.com",
|
||||
Status: "active",
|
||||
}
|
||||
db.Create(user)
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Update data
|
||||
updateData := map[string]interface{}{
|
||||
"name": "Updated Name",
|
||||
}
|
||||
|
||||
// Create update request message
|
||||
msg := &Message{
|
||||
ID: "msg-4",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationUpdate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: updateData,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
ID: "1",
|
||||
Data: updateData,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle update
|
||||
handler.handleUpdate(client, msg, hookCtx)
|
||||
|
||||
// Verify user was updated
|
||||
var updatedUser TestUser
|
||||
db.First(&updatedUser, 1)
|
||||
assert.Equal(t, "Updated Name", updatedUser.Name)
|
||||
}
|
||||
|
||||
func TestHandler_HandleDelete(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
user := &TestUser{
|
||||
ID: 1,
|
||||
Name: "To Delete",
|
||||
Email: "delete@example.com",
|
||||
Status: "active",
|
||||
}
|
||||
db.Create(user)
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create delete request message
|
||||
msg := &Message{
|
||||
ID: "msg-5",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationDelete,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
ID: "1",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle delete
|
||||
handler.handleDelete(client, msg, hookCtx)
|
||||
|
||||
// Verify user was deleted
|
||||
var deletedUser TestUser
|
||||
result := db.First(&deletedUser, 1)
|
||||
assert.Error(t, result.Error)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, result.Error)
|
||||
}
|
||||
|
||||
func TestHandler_HandleSubscribe(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create subscribe message
|
||||
msg := &Message{
|
||||
ID: "msg-6",
|
||||
Type: MessageTypeSubscription,
|
||||
Operation: OperationSubscribe,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Handle subscribe
|
||||
handler.handleSubscribe(client, msg)
|
||||
|
||||
// Verify subscription was created
|
||||
subscriptions := handler.subscriptionManager.GetSubscriptionsByEntity("public", "users")
|
||||
assert.Len(t, subscriptions, 1)
|
||||
assert.Equal(t, client.ID, subscriptions[0].ConnectionID)
|
||||
}
|
||||
|
||||
func TestHandler_HandleUnsubscribe(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create subscription using Subscribe method
|
||||
sub := handler.subscriptionManager.Subscribe("sub-1", client.ID, "public", "users", &common.RequestOptions{})
|
||||
client.AddSubscription(sub)
|
||||
|
||||
// Create unsubscribe message with subscription ID in Data
|
||||
msg := &Message{
|
||||
ID: "msg-7",
|
||||
Type: MessageTypeSubscription,
|
||||
Operation: OperationUnsubscribe,
|
||||
Data: map[string]interface{}{"subscription_id": "sub-1"},
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Handle unsubscribe
|
||||
handler.handleUnsubscribe(client, msg)
|
||||
|
||||
// Verify subscription was removed
|
||||
_, exists := handler.subscriptionManager.GetSubscription("sub-1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestHandler_NotifySubscribers(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock clients
|
||||
client1 := handler.clientManager.Register("client-1", "user1", handler)
|
||||
client2 := handler.clientManager.Register("client-2", "user2", handler)
|
||||
|
||||
// Create subscriptions
|
||||
opts1 := &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
},
|
||||
}
|
||||
sub1 := handler.subscriptionManager.Subscribe("sub-1", client1.ID, "public", "users", opts1)
|
||||
client1.AddSubscription(sub1)
|
||||
|
||||
opts2 := &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "inactive"},
|
||||
},
|
||||
}
|
||||
sub2 := handler.subscriptionManager.Subscribe("sub-2", client2.ID, "public", "users", opts2)
|
||||
client2.AddSubscription(sub2)
|
||||
|
||||
// Notify subscribers with active user
|
||||
activeUser := map[string]interface{}{
|
||||
"id": 1,
|
||||
"name": "Active User",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
// This should notify sub-1 only
|
||||
handler.notifySubscribers("public", "users", OperationCreate, activeUser)
|
||||
|
||||
// Note: In a full integration test, we would verify that the notification
|
||||
// was published to the correct MQTT topic. Here we're just testing that
|
||||
// the handler doesn't error and finds the correct subscriptions.
|
||||
}
|
||||
|
||||
func TestHandler_Hooks_BeforeRead(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data with different tenants
|
||||
users := []TestUser{
|
||||
{ID: 1, Name: "User 1", TenantID: "tenant-a", Status: "active"},
|
||||
{ID: 2, Name: "User 2", TenantID: "tenant-b", Status: "active"},
|
||||
{ID: 3, Name: "User 3", TenantID: "tenant-a", Status: "active"},
|
||||
}
|
||||
for _, user := range users {
|
||||
db.Create(&user)
|
||||
}
|
||||
|
||||
// Register hook to filter by tenant
|
||||
handler.Hooks().Register(BeforeRead, func(ctx *HookContext) error {
|
||||
// Auto-inject tenant filter
|
||||
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
|
||||
Column: "tenant_id",
|
||||
Operator: "eq",
|
||||
Value: "tenant-a",
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create read request (no tenant filter)
|
||||
msg := &Message{
|
||||
ID: "msg-8",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle read
|
||||
handler.handleRead(client, msg, hookCtx)
|
||||
|
||||
// The hook should have injected the tenant filter
|
||||
// In a full test, we would verify only tenant-a users were returned
|
||||
}
|
||||
|
||||
func TestHandler_Hooks_BeforeCreate(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Register hook to set default values
|
||||
handler.Hooks().Register(BeforeCreate, func(ctx *HookContext) error {
|
||||
// Auto-set tenant_id
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["tenant_id"] = "auto-tenant"
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create user without tenant_id
|
||||
newUser := map[string]interface{}{
|
||||
"name": "Test User",
|
||||
"email": "test@example.com",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
msg := &Message{
|
||||
ID: "msg-9",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationCreate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle create
|
||||
handler.handleCreate(client, msg, hookCtx)
|
||||
|
||||
// Verify tenant_id was auto-set
|
||||
var user TestUser
|
||||
db.Where("email = ?", "test@example.com").First(&user)
|
||||
assert.Equal(t, "auto-tenant", user.TenantID)
|
||||
}
|
||||
|
||||
func TestHandler_ConcurrentRequests(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create multiple clients
|
||||
var wg sync.WaitGroup
|
||||
numClients := 10
|
||||
|
||||
for i := 0; i < numClients; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
client := NewClient(fmt.Sprintf("client-%d", id), fmt.Sprintf("user%d", id), handler)
|
||||
|
||||
// Create user
|
||||
newUser := map[string]interface{}{
|
||||
"name": fmt.Sprintf("User %d", id),
|
||||
"email": fmt.Sprintf("user%d@example.com", id),
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
msg := &Message{
|
||||
ID: fmt.Sprintf("msg-%d", id),
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationCreate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
handler.handleCreate(client, msg, hookCtx)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all users were created
|
||||
var count int64
|
||||
db.Model(&TestUser{}).Count(&count)
|
||||
assert.Equal(t, int64(numClients), count)
|
||||
}
|
||||
|
||||
func TestHandler_TopicHelpers(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
clientID := "test-client"
|
||||
subscriptionID := "sub-123"
|
||||
|
||||
requestTopic := handler.getRequestTopic(clientID)
|
||||
assert.Equal(t, "spec/test-client/request", requestTopic)
|
||||
|
||||
responseTopic := handler.getResponseTopic(clientID)
|
||||
assert.Equal(t, "spec/test-client/response", responseTopic)
|
||||
|
||||
notifyTopic := handler.getNotifyTopic(clientID, subscriptionID)
|
||||
assert.Equal(t, "spec/test-client/notify/sub-123", notifyTopic)
|
||||
}
|
||||
|
||||
func TestHandler_SendResponse(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Test data
|
||||
clientID := "test-client"
|
||||
msgID := "msg-123"
|
||||
data := map[string]interface{}{"id": 1, "name": "Test"}
|
||||
metadata := map[string]interface{}{"total": 1}
|
||||
|
||||
// Send response (should not error)
|
||||
handler.sendResponse(clientID, msgID, data, metadata)
|
||||
}
|
||||
|
||||
func TestHandler_SendError(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Test error
|
||||
clientID := "test-client"
|
||||
msgID := "msg-123"
|
||||
code := "test_error"
|
||||
message := "Test error message"
|
||||
|
||||
// Send error (should not error)
|
||||
handler.sendError(clientID, msgID, code, message)
|
||||
}
|
||||
|
||||
// extractClientID extracts the client ID from a topic like spec/{client_id}/request
|
||||
func extractClientID(topic string) string {
|
||||
parts := strings.Split(topic, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[len(parts)-2]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestHandler_ExtractClientID(t *testing.T) {
|
||||
tests := []struct {
|
||||
topic string
|
||||
expected string
|
||||
}{
|
||||
{"spec/client-123/request", "client-123"},
|
||||
{"spec/abc-xyz/request", "abc-xyz"},
|
||||
{"spec/test/request", "test"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := extractClientID(tt.topic)
|
||||
assert.Equal(t, tt.expected, result, "topic: %s", tt.topic)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleIncomingMessage_InvalidJSON(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Invalid JSON payload
|
||||
payload := []byte("{invalid json")
|
||||
|
||||
// Should not panic
|
||||
handler.handleIncomingMessage("spec/test-client/request", payload)
|
||||
}
|
||||
|
||||
func TestHandler_HandleIncomingMessage_ValidMessage(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Valid message
|
||||
msg := &Message{
|
||||
ID: "msg-1",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(msg)
|
||||
|
||||
// Should not panic or error
|
||||
handler.handleIncomingMessage("spec/test-client/request", payload)
|
||||
}
|
||||
51
pkg/mqttspec/hooks.go
Normal file
51
pkg/mqttspec/hooks.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
|
||||
)
|
||||
|
||||
// Hook types - aliases to websocketspec for lifecycle hook consistency
|
||||
type (
|
||||
// HookType defines the type of lifecycle hook
|
||||
HookType = websocketspec.HookType
|
||||
|
||||
// HookFunc is a function that executes during a lifecycle hook
|
||||
HookFunc = websocketspec.HookFunc
|
||||
|
||||
// HookContext contains all context for hook execution
|
||||
// Note: For MQTT, the Client is stored in Metadata["mqtt_client"]
|
||||
HookContext = websocketspec.HookContext
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
HookRegistry = websocketspec.HookRegistry
|
||||
)
|
||||
|
||||
// Hook type constants - all 12 lifecycle hooks
|
||||
const (
|
||||
// CRUD operation hooks
|
||||
BeforeRead = websocketspec.BeforeRead
|
||||
AfterRead = websocketspec.AfterRead
|
||||
BeforeCreate = websocketspec.BeforeCreate
|
||||
AfterCreate = websocketspec.AfterCreate
|
||||
BeforeUpdate = websocketspec.BeforeUpdate
|
||||
AfterUpdate = websocketspec.AfterUpdate
|
||||
BeforeDelete = websocketspec.BeforeDelete
|
||||
AfterDelete = websocketspec.AfterDelete
|
||||
|
||||
// Subscription hooks
|
||||
BeforeSubscribe = websocketspec.BeforeSubscribe
|
||||
AfterSubscribe = websocketspec.AfterSubscribe
|
||||
BeforeUnsubscribe = websocketspec.BeforeUnsubscribe
|
||||
AfterUnsubscribe = websocketspec.AfterUnsubscribe
|
||||
|
||||
// Connection hooks
|
||||
BeforeConnect = websocketspec.BeforeConnect
|
||||
AfterConnect = websocketspec.AfterConnect
|
||||
BeforeDisconnect = websocketspec.BeforeDisconnect
|
||||
AfterDisconnect = websocketspec.AfterDisconnect
|
||||
)
|
||||
|
||||
// NewHookRegistry creates a new hook registry
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return websocketspec.NewHookRegistry()
|
||||
}
|
||||
63
pkg/mqttspec/message.go
Normal file
63
pkg/mqttspec/message.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
|
||||
)
|
||||
|
||||
// Message types - aliases to websocketspec for protocol consistency
|
||||
type (
|
||||
// Message represents an MQTT message (identical to WebSocket message protocol)
|
||||
Message = websocketspec.Message
|
||||
|
||||
// MessageType defines the type of message
|
||||
MessageType = websocketspec.MessageType
|
||||
|
||||
// OperationType defines the operation to perform
|
||||
OperationType = websocketspec.OperationType
|
||||
|
||||
// ResponseMessage is sent back to clients after processing requests
|
||||
ResponseMessage = websocketspec.ResponseMessage
|
||||
|
||||
// NotificationMessage is sent to subscribers when data changes
|
||||
NotificationMessage = websocketspec.NotificationMessage
|
||||
|
||||
// ErrorInfo contains error details
|
||||
ErrorInfo = websocketspec.ErrorInfo
|
||||
)
|
||||
|
||||
// Message type constants
|
||||
const (
|
||||
MessageTypeRequest = websocketspec.MessageTypeRequest
|
||||
MessageTypeResponse = websocketspec.MessageTypeResponse
|
||||
MessageTypeNotification = websocketspec.MessageTypeNotification
|
||||
MessageTypeSubscription = websocketspec.MessageTypeSubscription
|
||||
MessageTypeError = websocketspec.MessageTypeError
|
||||
MessageTypePing = websocketspec.MessageTypePing
|
||||
MessageTypePong = websocketspec.MessageTypePong
|
||||
)
|
||||
|
||||
// Operation type constants
|
||||
const (
|
||||
OperationRead = websocketspec.OperationRead
|
||||
OperationCreate = websocketspec.OperationCreate
|
||||
OperationUpdate = websocketspec.OperationUpdate
|
||||
OperationDelete = websocketspec.OperationDelete
|
||||
OperationSubscribe = websocketspec.OperationSubscribe
|
||||
OperationUnsubscribe = websocketspec.OperationUnsubscribe
|
||||
OperationMeta = websocketspec.OperationMeta
|
||||
)
|
||||
|
||||
// Helper functions from websocketspec
|
||||
var (
|
||||
// NewResponseMessage creates a new response message
|
||||
NewResponseMessage = websocketspec.NewResponseMessage
|
||||
|
||||
// NewErrorResponse creates an error response
|
||||
NewErrorResponse = websocketspec.NewErrorResponse
|
||||
|
||||
// NewNotificationMessage creates a notification message
|
||||
NewNotificationMessage = websocketspec.NewNotificationMessage
|
||||
|
||||
// ParseMessage parses a JSON message into a Message struct
|
||||
ParseMessage = websocketspec.ParseMessage
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user