mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-02-05 01:04:25 +00:00
Compare commits
87 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50d0caabc2 | ||
|
|
5269ae4de2 | ||
|
|
646620ed83 | ||
| 7600a6d1fb | |||
| 2e7b3e7abd | |||
| fdf9e118c5 | |||
| e11e6a8bf7 | |||
| 261f98eb29 | |||
| 0b8d11361c | |||
|
|
e70bab92d7 | ||
|
|
fc8f44e3e8 | ||
|
|
584bb9813d | ||
|
|
17239d1611 | ||
|
|
defe27549b | ||
|
|
f7725340a6 | ||
|
|
07016d1b73 | ||
|
|
09f2256899 | ||
|
|
c12c045db1 | ||
|
|
24a7ef7284 | ||
|
|
b87841a51c | ||
|
|
289cd74485 | ||
|
|
c75842ebb0 | ||
|
|
7879272dda | ||
|
|
292306b608 | ||
|
|
a980201d21 | ||
|
|
276854768e | ||
|
|
cf6a81e805 | ||
|
|
0ac207d80f | ||
|
|
b7a67a6974 | ||
|
|
cb20a354fc | ||
|
|
37c85361ba | ||
|
|
a7e640a6a1 | ||
|
|
bf7125efc3 | ||
|
|
e220ab3d34 | ||
|
|
6a0297713a | ||
|
|
6ea200bb2b | ||
|
|
987244019c | ||
|
|
62a8e56f1b | ||
|
|
d8df1bdac2 | ||
|
|
c0c669bd3d | ||
| 0cc3635466 | |||
| c2d86c9880 | |||
| 70bf0a4be1 | |||
| 4964d89158 | |||
| 96b098f912 | |||
| 5bba99efe3 | |||
| 8504b6d13d | |||
| ada4db6465 | |||
| 2017465cb8 | |||
| d33747c2d3 | |||
| c864aa4d90 | |||
| 250fcf686c | |||
| 47cfc4b3da | |||
| 0e8ae75daf | |||
| ce092d1c62 | |||
| 871dd2e374 | |||
|
|
ebd03d10ad | ||
|
|
4ee6ef0955 | ||
|
|
6f05f15ff6 | ||
|
|
443a672fcb | ||
|
|
c2fcc5aaff | ||
|
|
6664a4e2d2 | ||
| 037bd4c05e | |||
|
|
e77468a239 | ||
|
|
82d84435f2 | ||
|
|
b99b08430e | ||
|
|
fae9a082bd | ||
|
|
191822b91c | ||
|
|
a6a17d019f | ||
|
|
a7cc42044b | ||
|
|
8cdc353029 | ||
|
|
6528e94297 | ||
|
|
f711bf38d2 | ||
|
|
44356d8750 | ||
|
|
caf85cf558 | ||
|
|
2e1547ec65 | ||
|
|
49cdc6f17b | ||
|
|
0bd653820c | ||
|
|
9209193157 | ||
|
|
b8c44c5a99 | ||
|
|
28fd88fff1 | ||
|
|
be38341383 | ||
|
|
fab744b878 | ||
|
|
5ad2bd3a78 | ||
|
|
333fe158e9 | ||
|
|
2a2d351ad4 | ||
|
|
e918c49b84 |
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -17,11 +17,13 @@ jobs:
|
|||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
|
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
|
||||||
- name: Generate coverage report
|
- name: Generate coverage report
|
||||||
|
continue-on-error: true
|
||||||
run: |
|
run: |
|
||||||
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
|
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
|
||||||
go tool cover -html=coverage.out -o coverage.html
|
go tool cover -html=coverage.out -o coverage.html
|
||||||
- name: Upload coverage
|
- name: Upload coverage
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v5
|
||||||
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
name: coverage-report
|
name: coverage-report
|
||||||
path: coverage.html
|
path: coverage.html
|
||||||
@@ -55,27 +57,34 @@ jobs:
|
|||||||
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
|
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
|
||||||
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
|
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
|
||||||
- name: Run resolvespec integration tests
|
- name: Run resolvespec integration tests
|
||||||
|
continue-on-error: true
|
||||||
env:
|
env:
|
||||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
||||||
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
|
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
|
||||||
- name: Run restheadspec integration tests
|
- name: Run restheadspec integration tests
|
||||||
|
continue-on-error: true
|
||||||
env:
|
env:
|
||||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
|
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
|
||||||
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
|
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
|
||||||
- name: Generate integration coverage
|
- name: Generate integration coverage
|
||||||
|
continue-on-error: true
|
||||||
env:
|
env:
|
||||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
||||||
run: |
|
run: |
|
||||||
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
|
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
|
||||||
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
|
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
|
||||||
|
|
||||||
- name: Upload resolvespec integration coverage
|
- name: Upload resolvespec integration coverage
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v5
|
||||||
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
name: resolvespec-integration-coverage-report
|
name: resolvespec-integration-coverage-report
|
||||||
path: coverage-resolvespec-integration.html
|
path: coverage-resolvespec-integration.html
|
||||||
|
|
||||||
- name: Upload restheadspec integration coverage
|
- name: Upload restheadspec integration coverage
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v5
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
with:
|
with:
|
||||||
name: integration-coverage-restheadspec-report
|
name: integration-coverage-restheadspec-report
|
||||||
path: coverage-restheadspec-integration
|
path: coverage-restheadspec-integration
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -25,4 +25,5 @@ go.work.sum
|
|||||||
.env
|
.env
|
||||||
bin/
|
bin/
|
||||||
test.db
|
test.db
|
||||||
testserver
|
/testserver
|
||||||
|
tests/data/
|
||||||
6
.vscode/settings.json
vendored
6
.vscode/settings.json
vendored
@@ -52,5 +52,9 @@
|
|||||||
"upgrade_dependency": true,
|
"upgrade_dependency": true,
|
||||||
"vendor": true
|
"vendor": true
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"conventionalCommits.scopes": [
|
||||||
|
"spectypes",
|
||||||
|
"dbmanager"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
86
LICENSE
86
LICENSE
@@ -1,21 +1,73 @@
|
|||||||
MIT License
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
Copyright (c) 2025
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
1. Definitions.
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2025 wdevs
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|||||||
18
Makefile
18
Makefile
@@ -13,15 +13,23 @@ test-integration:
|
|||||||
# Run all tests (unit + integration)
|
# Run all tests (unit + integration)
|
||||||
test: test-unit test-integration
|
test: test-unit test-integration
|
||||||
|
|
||||||
release-version: ## Create and push a release with specific version (use: make release-version VERSION=v1.2.3)
|
release-version: ## Create and push a release with specific version (use: make release-version VERSION=v1.2.3 or make release-version to auto-increment)
|
||||||
@if [ -z "$(VERSION)" ]; then \
|
@if [ -z "$(VERSION)" ]; then \
|
||||||
echo "Error: VERSION is required. Usage: make release-version VERSION=v1.2.3"; \
|
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0"); \
|
||||||
exit 1; \
|
echo "No VERSION specified. Last version: $$latest_tag"; \
|
||||||
fi
|
version_num=$$(echo "$$latest_tag" | sed 's/^v//'); \
|
||||||
@version="$(VERSION)"; \
|
major=$$(echo "$$version_num" | cut -d. -f1); \
|
||||||
|
minor=$$(echo "$$version_num" | cut -d. -f2); \
|
||||||
|
patch=$$(echo "$$version_num" | cut -d. -f3); \
|
||||||
|
new_patch=$$((patch + 1)); \
|
||||||
|
version="v$$major.$$minor.$$new_patch"; \
|
||||||
|
echo "Auto-incrementing to: $$version"; \
|
||||||
|
else \
|
||||||
|
version="$(VERSION)"; \
|
||||||
if ! echo "$$version" | grep -q "^v"; then \
|
if ! echo "$$version" | grep -q "^v"; then \
|
||||||
version="v$$version"; \
|
version="v$$version"; \
|
||||||
fi; \
|
fi; \
|
||||||
|
fi; \
|
||||||
echo "Creating release: $$version"; \
|
echo "Creating release: $$version"; \
|
||||||
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo ""); \
|
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo ""); \
|
||||||
if [ -z "$$latest_tag" ]; then \
|
if [ -z "$$latest_tag" ]; then \
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
@@ -15,7 +17,6 @@ import (
|
|||||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/glebarez/sqlite"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
gormlog "gorm.io/gorm/logger"
|
gormlog "gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
@@ -38,14 +39,15 @@ func main() {
|
|||||||
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
|
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
|
||||||
}
|
}
|
||||||
logger.Info("ResolveSpec test server starting")
|
logger.Info("ResolveSpec test server starting")
|
||||||
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
|
|
||||||
|
|
||||||
// Initialize database
|
// Initialize database manager
|
||||||
db, err := initDB(cfg)
|
ctx := context.Background()
|
||||||
|
dbMgr, db, err := initDB(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to initialize database: %+v", err)
|
logger.Error("Failed to initialize database: %+v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
defer dbMgr.Close()
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
r := mux.NewRouter()
|
r := mux.NewRouter()
|
||||||
@@ -70,54 +72,37 @@ func main() {
|
|||||||
// Create server manager
|
// Create server manager
|
||||||
mgr := server.NewManager()
|
mgr := server.NewManager()
|
||||||
|
|
||||||
// Parse host and port from addr
|
// Get default server configuration
|
||||||
host := ""
|
defaultServerCfg, err := cfg.Servers.GetDefault()
|
||||||
port := 8080
|
|
||||||
if cfg.Server.Addr != "" {
|
|
||||||
// Parse addr (format: ":8080" or "localhost:8080")
|
|
||||||
if cfg.Server.Addr[0] == ':' {
|
|
||||||
// Just port
|
|
||||||
_, err := fmt.Sscanf(cfg.Server.Addr, ":%d", &port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
logger.Error("Failed to get default server config: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Host and port
|
|
||||||
_, err := fmt.Sscanf(cfg.Server.Addr, "%[^:]:%d", &host, &port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add server instance
|
// Apply global defaults
|
||||||
_, err = mgr.Add(server.Config{
|
defaultServerCfg.ApplyGlobalDefaults(cfg.Servers)
|
||||||
Name: "api",
|
|
||||||
Host: host,
|
// Convert to server.Config and add instance
|
||||||
Port: port,
|
serverCfg := server.FromConfigInstanceToServerConfig(defaultServerCfg, r)
|
||||||
Handler: r,
|
|
||||||
ShutdownTimeout: cfg.Server.ShutdownTimeout,
|
logger.Info("Configuration loaded - Server '%s' will listen on %s:%d",
|
||||||
DrainTimeout: cfg.Server.DrainTimeout,
|
serverCfg.Name, serverCfg.Host, serverCfg.Port)
|
||||||
ReadTimeout: cfg.Server.ReadTimeout,
|
|
||||||
WriteTimeout: cfg.Server.WriteTimeout,
|
_, err = mgr.Add(serverCfg)
|
||||||
IdleTimeout: cfg.Server.IdleTimeout,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to add server: %v", err)
|
logger.Error("Failed to add server: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start server with graceful shutdown
|
// Start server with graceful shutdown
|
||||||
logger.Info("Starting server on %s", cfg.Server.Addr)
|
logger.Info("Starting server '%s' on %s:%d", serverCfg.Name, serverCfg.Host, serverCfg.Port)
|
||||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||||
logger.Error("Server failed: %v", err)
|
logger.Error("Server failed: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDB(cfg *config.Config) (*gorm.DB, error) {
|
func initDB(ctx context.Context, cfg *config.Config) (dbmanager.Manager, *gorm.DB, error) {
|
||||||
// Configure GORM logger based on config
|
// Configure GORM logger based on config
|
||||||
logLevel := gormlog.Info
|
logLevel := gormlog.Info
|
||||||
if !cfg.Logger.Dev {
|
if !cfg.Logger.Dev {
|
||||||
@@ -135,25 +120,41 @@ func initDB(cfg *config.Config) (*gorm.DB, error) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
// Use database URL from config if available, otherwise use default SQLite
|
// Create database manager from config
|
||||||
dbURL := cfg.Database.URL
|
mgr, err := dbmanager.NewManager(dbmanager.FromConfig(cfg.DBManager))
|
||||||
if dbURL == "" {
|
if err != nil {
|
||||||
dbURL = "test.db"
|
return nil, nil, fmt.Errorf("failed to create database manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create SQLite database
|
// Connect all databases
|
||||||
db, err := gorm.Open(sqlite.Open(dbURL), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
|
if err := mgr.Connect(ctx); err != nil {
|
||||||
if err != nil {
|
return nil, nil, fmt.Errorf("failed to connect databases: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get default connection
|
||||||
|
conn, err := mgr.GetDefault()
|
||||||
|
if err != nil {
|
||||||
|
mgr.Close()
|
||||||
|
return nil, nil, fmt.Errorf("failed to get default connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get GORM database
|
||||||
|
gormDB, err := conn.GORM()
|
||||||
|
if err != nil {
|
||||||
|
mgr.Close()
|
||||||
|
return nil, nil, fmt.Errorf("failed to get GORM database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update GORM logger
|
||||||
|
gormDB.Logger = newLogger
|
||||||
|
|
||||||
modelList := testmodels.GetTestModels()
|
modelList := testmodels.GetTestModels()
|
||||||
|
|
||||||
// Auto migrate schemas
|
// Auto migrate schemas
|
||||||
err = db.AutoMigrate(modelList...)
|
if err := gormDB.AutoMigrate(modelList...); err != nil {
|
||||||
if err != nil {
|
mgr.Close()
|
||||||
return nil, err
|
return nil, nil, fmt.Errorf("failed to auto migrate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db, nil
|
return mgr, gormDB, nil
|
||||||
}
|
}
|
||||||
|
|||||||
24
config.yaml
24
config.yaml
@@ -37,5 +37,25 @@ cors:
|
|||||||
tracing:
|
tracing:
|
||||||
enabled: false
|
enabled: false
|
||||||
|
|
||||||
database:
|
# Database Manager Configuration
|
||||||
url: "" # Empty means use default SQLite (test.db)
|
dbmanager:
|
||||||
|
default_connection: "primary"
|
||||||
|
max_open_conns: 25
|
||||||
|
max_idle_conns: 5
|
||||||
|
conn_max_lifetime: 30m
|
||||||
|
conn_max_idle_time: 5m
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 1s
|
||||||
|
health_check_interval: 30s
|
||||||
|
enable_auto_reconnect: true
|
||||||
|
|
||||||
|
connections:
|
||||||
|
primary:
|
||||||
|
name: "primary"
|
||||||
|
type: "sqlite"
|
||||||
|
filepath: "test.db"
|
||||||
|
default_orm: "gorm"
|
||||||
|
enable_logging: true
|
||||||
|
enable_metrics: false
|
||||||
|
connect_timeout: 10s
|
||||||
|
query_timeout: 30s
|
||||||
|
|||||||
63
go.mod
63
go.mod
@@ -13,32 +13,38 @@ require (
|
|||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/jackc/pgx/v5 v5.6.0
|
github.com/jackc/pgx/v5 v5.8.0
|
||||||
github.com/klauspost/compress v1.18.0
|
github.com/klauspost/compress v1.18.2
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.33
|
||||||
|
github.com/microsoft/go-mssqldb v1.9.5
|
||||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||||
github.com/nats-io/nats.go v1.48.0
|
github.com/nats-io/nats.go v1.48.0
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/redis/go-redis/v9 v9.17.1
|
github.com/redis/go-redis/v9 v9.17.2
|
||||||
github.com/spf13/viper v1.21.0
|
github.com/spf13/viper v1.21.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/testcontainers/testcontainers-go v0.40.0
|
github.com/testcontainers/testcontainers-go v0.40.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/uptrace/bun v1.2.16
|
github.com/uptrace/bun v1.2.16
|
||||||
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.16
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16
|
||||||
github.com/uptrace/bunrouter v1.0.23
|
github.com/uptrace/bunrouter v1.0.23
|
||||||
|
go.mongodb.org/mongo-driver v1.17.6
|
||||||
go.opentelemetry.io/otel v1.38.0
|
go.opentelemetry.io/otel v1.38.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
|
||||||
go.opentelemetry.io/otel/sdk v1.38.0
|
go.opentelemetry.io/otel/sdk v1.38.0
|
||||||
go.opentelemetry.io/otel/trace v1.38.0
|
go.opentelemetry.io/otel/trace v1.38.0
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.1
|
||||||
golang.org/x/crypto v0.43.0
|
golang.org/x/crypto v0.46.0
|
||||||
golang.org/x/time v0.14.0
|
golang.org/x/time v0.14.0
|
||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
gorm.io/driver/sqlite v1.6.0
|
gorm.io/driver/sqlite v1.6.0
|
||||||
gorm.io/gorm v1.30.0
|
gorm.io/driver/sqlserver v1.6.3
|
||||||
|
gorm.io/gorm v1.31.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -64,11 +70,14 @@ require (
|
|||||||
github.com/ebitengine/purego v0.8.4 // indirect
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
github.com/glebarez/go-sqlite v1.22.0 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||||
|
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||||
|
github.com/golang/snappy v1.0.0 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
@@ -78,7 +87,6 @@ require (
|
|||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
github.com/magiconair/properties v1.8.10 // indirect
|
github.com/magiconair/properties v1.8.10 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
|
||||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
github.com/moby/go-archive v0.1.0 // indirect
|
github.com/moby/go-archive v0.1.0 // indirect
|
||||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||||
@@ -86,6 +94,7 @@ require (
|
|||||||
github.com/moby/sys/user v0.4.0 // indirect
|
github.com/moby/sys/user v0.4.0 // indirect
|
||||||
github.com/moby/sys/userns v0.1.0 // indirect
|
github.com/moby/sys/userns v0.1.0 // indirect
|
||||||
github.com/moby/term v0.5.0 // indirect
|
github.com/moby/term v0.5.0 // indirect
|
||||||
|
github.com/montanaflynn/stats v0.7.1 // indirect
|
||||||
github.com/morikuni/aec v1.0.0 // indirect
|
github.com/morikuni/aec v1.0.0 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/nats-io/nkeys v0.4.11 // indirect
|
github.com/nats-io/nkeys v0.4.11 // indirect
|
||||||
@@ -98,49 +107,55 @@ require (
|
|||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/prometheus/client_model v0.6.2 // indirect
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
github.com/prometheus/common v0.66.1 // indirect
|
github.com/prometheus/common v0.67.4 // indirect
|
||||||
github.com/prometheus/procfs v0.16.1 // indirect
|
github.com/prometheus/procfs v0.19.2 // indirect
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/rs/xid v1.4.0 // indirect
|
github.com/rs/xid v1.4.0 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||||
|
github.com/shopspring/decimal v1.4.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
|
||||||
github.com/spf13/afero v1.15.0 // indirect
|
github.com/spf13/afero v1.15.0 // indirect
|
||||||
github.com/spf13/cast v1.10.0 // indirect
|
github.com/spf13/cast v1.10.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.10 // indirect
|
github.com/spf13/pflag v1.0.10 // indirect
|
||||||
github.com/stretchr/objx v0.5.2 // indirect
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.2.0 // indirect
|
||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
|
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||||
|
github.com/xdg-go/scram v1.2.0 // indirect
|
||||||
|
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||||
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||||
go.uber.org/multierr v1.10.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
||||||
golang.org/x/net v0.45.0 // indirect
|
golang.org/x/mod v0.31.0 // indirect
|
||||||
golang.org/x/sync v0.18.0 // indirect
|
golang.org/x/net v0.48.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/oauth2 v0.34.0 // indirect
|
||||||
golang.org/x/text v0.30.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
|
golang.org/x/sys v0.39.0 // indirect
|
||||||
|
golang.org/x/text v0.32.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||||
google.golang.org/grpc v1.75.0 // indirect
|
google.golang.org/grpc v1.75.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.8 // indirect
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
modernc.org/libc v1.67.0 // indirect
|
modernc.org/libc v1.67.4 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
modernc.org/sqlite v1.40.1 // indirect
|
modernc.org/sqlite v1.42.2 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
||||||
|
|||||||
296
go.sum
296
go.sum
@@ -2,8 +2,32 @@ dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
|||||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1/go.mod h1:uE9zaUfEQT/nbQjVi2IblCG9iaLtZsuYZ8ne+PuQ02M=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1/go.mod h1:GpPjLhVR9dnUoJMyHWSPy71xY9/lcmpzIPZXmF0FCVY=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||||
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||||
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||||
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
@@ -32,6 +56,7 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
|
|||||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||||
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
@@ -41,6 +66,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
|
||||||
|
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||||
@@ -61,8 +88,8 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S
|
|||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
||||||
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
||||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
|
||||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
|
||||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||||
@@ -76,31 +103,54 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
|||||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||||
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||||
|
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||||
|
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
|
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||||
|
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||||
|
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||||
|
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
||||||
|
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||||
|
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||||
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
||||||
|
github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
|
||||||
|
github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo=
|
||||||
|
github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
|
||||||
|
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
|
||||||
|
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||||
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
|
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
|
||||||
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
@@ -108,10 +158,13 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
|||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||||
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
@@ -122,8 +175,11 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
|
|||||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
|
||||||
|
github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0=
|
||||||
|
github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q=
|
||||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||||
@@ -142,6 +198,10 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
|||||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||||
github.com/mochi-mqtt/server/v2 v2.7.9 h1:y0g4vrSLAag7T07l2oCzOa/+nKVLoazKEWAArwqBNYI=
|
github.com/mochi-mqtt/server/v2 v2.7.9 h1:y0g4vrSLAag7T07l2oCzOa/+nKVLoazKEWAArwqBNYI=
|
||||||
github.com/mochi-mqtt/server/v2 v2.7.9/go.mod h1:lZD3j35AVNqJL5cezlnSkuG05c0FCHSsfAKSPBOSbqc=
|
github.com/mochi-mqtt/server/v2 v2.7.9/go.mod h1:lZD3j35AVNqJL5cezlnSkuG05c0FCHSsfAKSPBOSbqc=
|
||||||
|
github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8=
|
||||||
|
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||||
|
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
||||||
|
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||||
@@ -162,6 +222,10 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||||
|
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||||
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
@@ -172,28 +236,30 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
|
|||||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
|
||||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
|
||||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||||
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
|
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||||
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
|
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||||
|
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||||
|
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
|
||||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
|
||||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||||
@@ -203,10 +269,18 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
|
|||||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
@@ -216,10 +290,12 @@ github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3
|
|||||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||||
|
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
@@ -228,6 +304,10 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F
|
|||||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||||
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.16 h1:rKv0cKPNBviXadB/+2Y/UedA/c1JnwGzUWZkdN5FdSQ=
|
||||||
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.16/go.mod h1:J5U7tGKWDsx2Q7MwDZF2417jCdpD6yD/ZMFJcCR80bk=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
|
||||||
@@ -240,8 +320,19 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
|
|||||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||||
github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
|
github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
|
||||||
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
||||||
|
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||||
|
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||||
|
github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
||||||
|
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||||
|
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||||
|
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||||
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||||
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||||
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
|
go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss=
|
||||||
|
go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||||
@@ -266,41 +357,137 @@ go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOV
|
|||||||
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
|
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
|
||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||||
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
|
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||||
|
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||||
|
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||||
|
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||||
|
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||||
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
|
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
|
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
|
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
|
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
|
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
|
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||||
|
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
|
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||||
|
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||||
|
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||||
|
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||||
|
golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||||
|
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||||
|
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||||
|
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||||
|
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||||
|
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||||
|
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||||
|
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||||
|
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||||
|
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||||
|
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||||
|
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||||
|
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||||
|
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
|
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
|
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||||
|
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
|
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||||
|
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||||
|
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||||
|
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
||||||
|
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||||
|
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||||
|
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||||
|
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||||
|
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||||
|
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
|
||||||
|
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||||
|
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||||
|
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||||
|
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
|
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||||
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
|
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
|
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
|
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
|
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
|
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
|
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||||
|
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||||
|
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||||
|
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||||
|
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||||
|
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||||
|
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||||
|
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||||
|
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
@@ -310,11 +497,15 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:
|
|||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
|
||||||
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||||
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
@@ -322,8 +513,11 @@ gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
|||||||
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
gorm.io/driver/sqlserver v1.6.3 h1:UR+nWCuphPnq7UxnL57PSrlYjuvs+sf1N59GgFX7uAI=
|
||||||
|
gorm.io/driver/sqlserver v1.6.3/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U=
|
||||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||||
|
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||||
|
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
@@ -338,8 +532,8 @@ modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
|||||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
|
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
|
||||||
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
@@ -348,8 +542,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
|||||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
modernc.org/sqlite v1.42.2 h1:7hkZUNJvJFN2PgfUdjni9Kbvd4ef4mNLOu0B9FGxM74=
|
||||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
modernc.org/sqlite v1.42.2/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
|||||||
@@ -208,16 +208,10 @@ type BunSelectQuery struct {
|
|||||||
schema string // Separated schema name
|
schema string // Separated schema name
|
||||||
tableName string // Just the table name, without schema
|
tableName string // Just the table name, without schema
|
||||||
tableAlias string
|
tableAlias string
|
||||||
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
|
||||||
inJoinContext bool // Track if we're in a JOIN relation context
|
inJoinContext bool // Track if we're in a JOIN relation context
|
||||||
joinTableAlias string // Alias to use for JOIN conditions
|
joinTableAlias string // Alias to use for JOIN conditions
|
||||||
}
|
skipAutoDetect bool // Skip auto-detection to prevent circular calls
|
||||||
|
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
|
||||||
// deferredPreload represents a preload that will be executed as a separate query
|
|
||||||
// to avoid PostgreSQL identifier length limits
|
|
||||||
type deferredPreload struct {
|
|
||||||
relation string
|
|
||||||
apply []func(common.SelectQuery) common.SelectQuery
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
@@ -486,51 +480,29 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
|
|
||||||
// // when combined with typical column names
|
|
||||||
// func shortenAliasForPostgres(relationPath string) (string, bool) {
|
|
||||||
// // Convert relation path to the alias format Bun uses: dots become double underscores
|
|
||||||
// // Also convert to lowercase and use snake_case as Bun does
|
|
||||||
// parts := strings.Split(relationPath, ".")
|
|
||||||
// alias := strings.ToLower(strings.Join(parts, "__"))
|
|
||||||
|
|
||||||
// // PostgreSQL truncates identifiers to 63 chars
|
|
||||||
// // If the alias + typical column name would exceed this, we need to shorten
|
|
||||||
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
|
|
||||||
// const maxAliasLength = 30
|
|
||||||
|
|
||||||
// if len(alias) > maxAliasLength {
|
|
||||||
// // Create a shortened alias using a hash of the original
|
|
||||||
// hash := md5.Sum([]byte(alias))
|
|
||||||
// hashStr := hex.EncodeToString(hash[:])[:8]
|
|
||||||
|
|
||||||
// // Keep first few chars of original for readability + hash
|
|
||||||
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
|
|
||||||
// if prefixLen > len(alias) {
|
|
||||||
// prefixLen = len(alias)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// shortened := alias[:prefixLen] + "_" + hashStr
|
|
||||||
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
|
|
||||||
// alias, len(alias), shortened, len(shortened))
|
|
||||||
// return shortened, true
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return alias, false
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
|
|
||||||
// // Bun creates aliases like: relationChain__columnName
|
|
||||||
// func estimateColumnAliasLength(relationPath string, columnName string) int {
|
|
||||||
// relationParts := strings.Split(relationPath, ".")
|
|
||||||
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
|
||||||
// // Bun adds "__" between alias and column name
|
|
||||||
// return len(aliasChain) + 2 + len(columnName)
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
// Check if this relation will likely cause alias truncation FIRST
|
||||||
|
// PostgreSQL has a 63-character limit on identifiers
|
||||||
|
willTruncate := checkAliasLength(relation)
|
||||||
|
|
||||||
|
if willTruncate {
|
||||||
|
logger.Warn("Preload relation '%s' would generate aliases exceeding PostgreSQL's 63-char limit", relation)
|
||||||
|
logger.Info("Using custom preload implementation with separate queries for relation '%s'", relation)
|
||||||
|
|
||||||
|
// Store this relation for custom post-processing after the main query
|
||||||
|
// We'll load it manually with separate queries to avoid JOIN aliases
|
||||||
|
if b.customPreloads == nil {
|
||||||
|
b.customPreloads = make(map[string][]func(common.SelectQuery) common.SelectQuery)
|
||||||
|
}
|
||||||
|
b.customPreloads[relation] = apply
|
||||||
|
|
||||||
|
// Return without calling Bun's Relation() - we'll handle it ourselves
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// Auto-detect relationship type and choose optimal loading strategy
|
// Auto-detect relationship type and choose optimal loading strategy
|
||||||
// Get the model from the query if available
|
// Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
|
||||||
|
if !b.skipAutoDetect {
|
||||||
model := b.query.GetModel()
|
model := b.query.GetModel()
|
||||||
if model != nil && model.Value() != nil {
|
if model != nil && model.Value() != nil {
|
||||||
relType := reflection.GetRelationType(model.Value(), relation)
|
relType := reflection.GetRelationType(model.Value(), relation)
|
||||||
@@ -538,8 +510,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
// Log the detected relationship type
|
// Log the detected relationship type
|
||||||
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||||
|
|
||||||
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
|
||||||
if relType.ShouldUseJoin() {
|
if relType.ShouldUseJoin() {
|
||||||
|
// If this is a belongs-to or has-one relation that won't exceed limits, use JOIN for better performance
|
||||||
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||||
return b.JoinRelation(relation, apply...)
|
return b.JoinRelation(relation, apply...)
|
||||||
}
|
}
|
||||||
@@ -549,50 +521,11 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this relation chain would create problematic long aliases
|
|
||||||
relationParts := strings.Split(relation, ".")
|
|
||||||
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
|
||||||
|
|
||||||
// PostgreSQL's identifier limit is 63 characters
|
|
||||||
const postgresIdentifierLimit = 63
|
|
||||||
const safeAliasLimit = 35 // Leave room for column names
|
|
||||||
|
|
||||||
// If the alias chain is too long, defer this preload to be executed as a separate query
|
|
||||||
if len(aliasChain) > safeAliasLimit {
|
|
||||||
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
|
|
||||||
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
|
|
||||||
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
|
|
||||||
|
|
||||||
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
|
|
||||||
// This avoids the long concatenated alias
|
|
||||||
if len(relationParts) > 1 {
|
|
||||||
// Load first level normally: "Parent"
|
|
||||||
firstLevel := relationParts[0]
|
|
||||||
remainingPath := strings.Join(relationParts[1:], ".")
|
|
||||||
|
|
||||||
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
|
|
||||||
firstLevel, remainingPath)
|
|
||||||
|
|
||||||
// Apply the first level preload normally
|
|
||||||
b.query = b.query.Relation(firstLevel)
|
|
||||||
|
|
||||||
// Store the remaining nested preload to be executed after the main query
|
|
||||||
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
|
|
||||||
relation: relation,
|
|
||||||
apply: apply,
|
|
||||||
})
|
|
||||||
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single level but still too long - just warn and continue
|
// Use Bun's native Relation() for preloading
|
||||||
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
|
// Note: For relations that would cause truncation, skipAutoDetect is set to true
|
||||||
"Consider renaming the field to avoid potential issues.",
|
// to prevent our auto-detection from adding JOIN optimization
|
||||||
relation, len(aliasChain))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normal preload handling
|
|
||||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -625,14 +558,9 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
// Extract table alias if model implements TableAliasProvider
|
// Extract table alias if model implements TableAliasProvider
|
||||||
if provider, ok := modelValue.(common.TableAliasProvider); ok {
|
if provider, ok := modelValue.(common.TableAliasProvider); ok {
|
||||||
wrapper.tableAlias = provider.TableAlias()
|
wrapper.tableAlias = provider.TableAlias()
|
||||||
// Apply the alias to the Bun query so conditions can reference it
|
|
||||||
if wrapper.tableAlias != "" {
|
|
||||||
// Note: Bun's Relation() already sets up the table, but we can add
|
|
||||||
// the alias explicitly if needed
|
|
||||||
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
|
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Start with the interface value (not pointer)
|
// Start with the interface value (not pointer)
|
||||||
current := common.SelectQuery(wrapper)
|
current := common.SelectQuery(wrapper)
|
||||||
@@ -640,7 +568,6 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
// Apply each function in sequence
|
// Apply each function in sequence
|
||||||
for _, fn := range apply {
|
for _, fn := range apply {
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
// Pass ¤t (pointer to interface variable), fn modifies and returns new interface value
|
|
||||||
modified := fn(current)
|
modified := fn(current)
|
||||||
current = modified
|
current = modified
|
||||||
}
|
}
|
||||||
@@ -656,6 +583,502 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkIfRelationAlreadyLoaded checks if a relation is already populated on parent records
|
||||||
|
// Returns the collection of related records if already loaded
|
||||||
|
func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (reflect.Value, bool) {
|
||||||
|
if parents.Len() == 0 {
|
||||||
|
return reflect.Value{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the first parent to check the relation field
|
||||||
|
firstParent := parents.Index(0)
|
||||||
|
if firstParent.Kind() == reflect.Ptr {
|
||||||
|
firstParent = firstParent.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the relation field
|
||||||
|
relationField := firstParent.FieldByName(relationName)
|
||||||
|
if !relationField.IsValid() {
|
||||||
|
return reflect.Value{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a slice (has-many)
|
||||||
|
if relationField.Kind() == reflect.Slice {
|
||||||
|
// Check if any parent has a non-empty slice
|
||||||
|
for i := 0; i < parents.Len(); i++ {
|
||||||
|
parent := parents.Index(i)
|
||||||
|
if parent.Kind() == reflect.Ptr {
|
||||||
|
parent = parent.Elem()
|
||||||
|
}
|
||||||
|
field := parent.FieldByName(relationName)
|
||||||
|
if field.IsValid() && !field.IsNil() && field.Len() > 0 {
|
||||||
|
// Already loaded! Collect all related records from all parents
|
||||||
|
allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len())
|
||||||
|
for j := 0; j < parents.Len(); j++ {
|
||||||
|
p := parents.Index(j)
|
||||||
|
if p.Kind() == reflect.Ptr {
|
||||||
|
p = p.Elem()
|
||||||
|
}
|
||||||
|
f := p.FieldByName(relationName)
|
||||||
|
if f.IsValid() && !f.IsNil() {
|
||||||
|
for k := 0; k < f.Len(); k++ {
|
||||||
|
allRelated = reflect.Append(allRelated, f.Index(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return allRelated, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if relationField.Kind() == reflect.Ptr {
|
||||||
|
// Check if it's a pointer (has-one/belongs-to)
|
||||||
|
if !relationField.IsNil() {
|
||||||
|
// Already loaded! Collect all related records from all parents
|
||||||
|
relatedType := relationField.Type()
|
||||||
|
allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len())
|
||||||
|
for j := 0; j < parents.Len(); j++ {
|
||||||
|
p := parents.Index(j)
|
||||||
|
if p.Kind() == reflect.Ptr {
|
||||||
|
p = p.Elem()
|
||||||
|
}
|
||||||
|
f := p.FieldByName(relationName)
|
||||||
|
if f.IsValid() && !f.IsNil() {
|
||||||
|
allRelated = reflect.Append(allRelated, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return allRelated, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.Value{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadCustomPreloads loads relations that would cause alias truncation using separate queries
|
||||||
|
func (b *BunSelectQuery) loadCustomPreloads(ctx context.Context) error {
|
||||||
|
model := b.query.GetModel()
|
||||||
|
if model == nil || model.Value() == nil {
|
||||||
|
return fmt.Errorf("no model to load preloads for")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the actual data from the model
|
||||||
|
modelValue := reflect.ValueOf(model.Value())
|
||||||
|
if modelValue.Kind() == reflect.Ptr {
|
||||||
|
modelValue = modelValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// We only handle slices of records for now
|
||||||
|
if modelValue.Kind() != reflect.Slice {
|
||||||
|
logger.Warn("Custom preloads only support slice models currently, got: %v", modelValue.Kind())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelValue.Len() == 0 {
|
||||||
|
logger.Debug("No records to load preloads for")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each custom preload relation
|
||||||
|
for relation, applyFuncs := range b.customPreloads {
|
||||||
|
logger.Info("Loading custom preload for relation: %s", relation)
|
||||||
|
|
||||||
|
// Parse the relation path (e.g., "MTL.MAL.DEF" -> ["MTL", "MAL", "DEF"])
|
||||||
|
relationParts := strings.Split(relation, ".")
|
||||||
|
|
||||||
|
// Start with the parent records
|
||||||
|
currentRecords := modelValue
|
||||||
|
|
||||||
|
// Load each level of the relation
|
||||||
|
for i, relationPart := range relationParts {
|
||||||
|
isLastPart := i == len(relationParts)-1
|
||||||
|
|
||||||
|
logger.Debug("Loading relation part [%d/%d]: %s", i+1, len(relationParts), relationPart)
|
||||||
|
|
||||||
|
// Check if this level is already loaded by Bun (avoid duplicates)
|
||||||
|
existingRecords, alreadyLoaded := checkIfRelationAlreadyLoaded(currentRecords, relationPart)
|
||||||
|
if alreadyLoaded && existingRecords.IsValid() && existingRecords.Len() > 0 {
|
||||||
|
logger.Info("Relation '%s' already loaded by Bun, using existing %d records", relationPart, existingRecords.Len())
|
||||||
|
currentRecords = existingRecords
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load this level and get the loaded records for the next level
|
||||||
|
loadedRecords, err := b.loadRelationLevel(ctx, currentRecords, relationPart, isLastPart, applyFuncs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load relation %s (part %s): %w", relation, relationPart, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For nested relations, use the loaded records as parents for the next level
|
||||||
|
if !isLastPart && loadedRecords.IsValid() && loadedRecords.Len() > 0 {
|
||||||
|
logger.Debug("Collected %d records for next level", loadedRecords.Len())
|
||||||
|
currentRecords = loadedRecords
|
||||||
|
} else if !isLastPart {
|
||||||
|
logger.Debug("No records loaded at level %s, stopping nested preload", relationPart)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRelationLevel loads a single level of a relation for a set of parent records
|
||||||
|
// Returns the loaded records (for use as parents in nested preloads) and any error
|
||||||
|
func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords reflect.Value, relationName string, isLast bool, applyFuncs []func(common.SelectQuery) common.SelectQuery) (reflect.Value, error) {
|
||||||
|
if parentRecords.Len() == 0 {
|
||||||
|
return reflect.Value{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the first record to inspect the struct type
|
||||||
|
firstRecord := parentRecords.Index(0)
|
||||||
|
if firstRecord.Kind() == reflect.Ptr {
|
||||||
|
firstRecord = firstRecord.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstRecord.Kind() != reflect.Struct {
|
||||||
|
return reflect.Value{}, fmt.Errorf("expected struct, got %v", firstRecord.Kind())
|
||||||
|
}
|
||||||
|
|
||||||
|
parentType := firstRecord.Type()
|
||||||
|
|
||||||
|
// Find the relation field in the struct
|
||||||
|
structField, found := parentType.FieldByName(relationName)
|
||||||
|
if !found {
|
||||||
|
return reflect.Value{}, fmt.Errorf("relation field %s not found in struct %s", relationName, parentType.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the bun tag to get relation info
|
||||||
|
bunTag := structField.Tag.Get("bun")
|
||||||
|
logger.Debug("Relation %s bun tag: %s", relationName, bunTag)
|
||||||
|
|
||||||
|
relInfo, err := parseRelationTag(bunTag)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, fmt.Errorf("failed to parse relation tag for %s: %w", relationName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Parsed relation: type=%s, join=%s", relInfo.relType, relInfo.joinCondition)
|
||||||
|
|
||||||
|
// Extract foreign key values from parent records
|
||||||
|
fkValues, err := extractForeignKeyValues(parentRecords, relInfo.localKey)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, fmt.Errorf("failed to extract FK values: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fkValues) == 0 {
|
||||||
|
logger.Debug("No foreign key values to load for relation %s", relationName)
|
||||||
|
return reflect.Value{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Loading %d related records for %s (FK values: %v)", len(fkValues), relationName, fkValues)
|
||||||
|
|
||||||
|
// Get the related model type
|
||||||
|
relatedType := structField.Type
|
||||||
|
isSlice := relatedType.Kind() == reflect.Slice
|
||||||
|
if isSlice {
|
||||||
|
relatedType = relatedType.Elem()
|
||||||
|
}
|
||||||
|
if relatedType.Kind() == reflect.Ptr {
|
||||||
|
relatedType = relatedType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a slice to hold the results
|
||||||
|
resultsSlice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(relatedType)), 0, len(fkValues))
|
||||||
|
resultsPtr := reflect.New(resultsSlice.Type())
|
||||||
|
resultsPtr.Elem().Set(resultsSlice)
|
||||||
|
|
||||||
|
// Build and execute the query
|
||||||
|
query := b.db.NewSelect().Model(resultsPtr.Interface())
|
||||||
|
|
||||||
|
// Apply WHERE clause: foreign_key IN (values...)
|
||||||
|
query = query.Where(fmt.Sprintf("%s IN (?)", relInfo.foreignKey), bun.In(fkValues))
|
||||||
|
|
||||||
|
// Apply user's functions (if any)
|
||||||
|
if isLast && len(applyFuncs) > 0 {
|
||||||
|
wrapper := &BunSelectQuery{query: query, db: b.db}
|
||||||
|
for _, fn := range applyFuncs {
|
||||||
|
if fn != nil {
|
||||||
|
wrapper = fn(wrapper).(*BunSelectQuery)
|
||||||
|
query = wrapper.query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the query
|
||||||
|
err = query.Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, fmt.Errorf("failed to load related records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loadedRecords := resultsPtr.Elem()
|
||||||
|
logger.Info("Loaded %d related records for relation %s", loadedRecords.Len(), relationName)
|
||||||
|
|
||||||
|
// Associate loaded records back to parent records
|
||||||
|
err = associateRelatedRecords(parentRecords, loadedRecords, relationName, relInfo, isSlice)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the loaded records for use in nested preloads
|
||||||
|
return loadedRecords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// relationInfo holds parsed relation metadata
|
||||||
|
type relationInfo struct {
|
||||||
|
relType string // has-one, has-many, belongs-to
|
||||||
|
localKey string // Key in parent table
|
||||||
|
foreignKey string // Key in related table
|
||||||
|
joinCondition string // Full join condition
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRelationTag parses the bun:"rel:..." tag
|
||||||
|
func parseRelationTag(tag string) (*relationInfo, error) {
|
||||||
|
info := &relationInfo{}
|
||||||
|
|
||||||
|
// Parse tag like: rel:has-one,join:rid_mastertaskitem=rid_mastertaskitem
|
||||||
|
parts := strings.Split(tag, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "rel:") {
|
||||||
|
info.relType = strings.TrimPrefix(part, "rel:")
|
||||||
|
} else if strings.HasPrefix(part, "join:") {
|
||||||
|
info.joinCondition = strings.TrimPrefix(part, "join:")
|
||||||
|
// Parse join: local_key=foreign_key
|
||||||
|
joinParts := strings.Split(info.joinCondition, "=")
|
||||||
|
if len(joinParts) == 2 {
|
||||||
|
info.localKey = strings.TrimSpace(joinParts[0])
|
||||||
|
info.foreignKey = strings.TrimSpace(joinParts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.relType == "" || info.localKey == "" || info.foreignKey == "" {
|
||||||
|
return nil, fmt.Errorf("incomplete relation tag: %s", tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractForeignKeyValues collects FK values from parent records
|
||||||
|
func extractForeignKeyValues(records reflect.Value, fkFieldName string) ([]interface{}, error) {
|
||||||
|
values := make([]interface{}, 0, records.Len())
|
||||||
|
seenValues := make(map[interface{}]bool)
|
||||||
|
|
||||||
|
for i := 0; i < records.Len(); i++ {
|
||||||
|
record := records.Index(i)
|
||||||
|
if record.Kind() == reflect.Ptr {
|
||||||
|
record = record.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the FK field - try both exact name and capitalized version
|
||||||
|
fkField := record.FieldByName(fkFieldName)
|
||||||
|
if !fkField.IsValid() {
|
||||||
|
// Try capitalized version
|
||||||
|
fkField = record.FieldByName(strings.ToUpper(fkFieldName[:1]) + fkFieldName[1:])
|
||||||
|
}
|
||||||
|
if !fkField.IsValid() {
|
||||||
|
// Try finding by json tag
|
||||||
|
for j := 0; j < record.NumField(); j++ {
|
||||||
|
field := record.Type().Field(j)
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.HasPrefix(jsonTag, fkFieldName) || strings.Contains(bunTag, fkFieldName) {
|
||||||
|
fkField = record.Field(j)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fkField.IsValid() {
|
||||||
|
continue // Skip records without FK
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the value
|
||||||
|
var value interface{}
|
||||||
|
if fkField.CanInterface() {
|
||||||
|
value = fkField.Interface()
|
||||||
|
|
||||||
|
// Handle SqlNull types
|
||||||
|
if nullType, ok := value.(interface{ IsNull() bool }); ok {
|
||||||
|
if nullType.IsNull() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle types with Int64() method
|
||||||
|
if int64er, ok := value.(interface{ Int64() int64 }); ok {
|
||||||
|
value = int64er.Int64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deduplicate
|
||||||
|
if !seenValues[value] {
|
||||||
|
values = append(values, value)
|
||||||
|
seenValues[value] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return values, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// associateRelatedRecords associates loaded records back to parents
|
||||||
|
func associateRelatedRecords(parents, related reflect.Value, fieldName string, relInfo *relationInfo, isSlice bool) error {
|
||||||
|
logger.Debug("Associating %d related records to %d parents for field '%s'", related.Len(), parents.Len(), fieldName)
|
||||||
|
|
||||||
|
// Build a map: foreignKey -> related record(s)
|
||||||
|
relatedMap := make(map[interface{}][]reflect.Value)
|
||||||
|
|
||||||
|
for i := 0; i < related.Len(); i++ {
|
||||||
|
relRecord := related.Index(i)
|
||||||
|
relRecordElem := relRecord
|
||||||
|
if relRecordElem.Kind() == reflect.Ptr {
|
||||||
|
relRecordElem = relRecordElem.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the foreign key value from the related record - try multiple variations
|
||||||
|
fkField := findFieldByName(relRecordElem, relInfo.foreignKey)
|
||||||
|
if !fkField.IsValid() {
|
||||||
|
logger.Warn("Could not find FK field '%s' in related record type %s", relInfo.foreignKey, relRecordElem.Type().Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fkValue := extractFieldValue(fkField)
|
||||||
|
if fkValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
relatedMap[fkValue] = append(relatedMap[fkValue], related.Index(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Built related map with %d unique FK values", len(relatedMap))
|
||||||
|
|
||||||
|
// Associate with parents
|
||||||
|
associatedCount := 0
|
||||||
|
for i := 0; i < parents.Len(); i++ {
|
||||||
|
parentPtr := parents.Index(i)
|
||||||
|
parent := parentPtr
|
||||||
|
if parent.Kind() == reflect.Ptr {
|
||||||
|
parent = parent.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the local key value from parent
|
||||||
|
localField := findFieldByName(parent, relInfo.localKey)
|
||||||
|
if !localField.IsValid() {
|
||||||
|
logger.Warn("Could not find local key field '%s' in parent type %s", relInfo.localKey, parent.Type().Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
localValue := extractFieldValue(localField)
|
||||||
|
if localValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find matching related records
|
||||||
|
matches := relatedMap[localValue]
|
||||||
|
if len(matches) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the relation field - IMPORTANT: use the pointer, not the elem
|
||||||
|
relationField := parent.FieldByName(fieldName)
|
||||||
|
if !relationField.IsValid() {
|
||||||
|
logger.Warn("Relation field '%s' not found in parent type %s", fieldName, parent.Type().Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !relationField.CanSet() {
|
||||||
|
logger.Warn("Relation field '%s' cannot be set (unexported?)", fieldName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if isSlice {
|
||||||
|
// For has-many: replace entire slice (don't append to avoid duplicates)
|
||||||
|
newSlice := reflect.MakeSlice(relationField.Type(), 0, len(matches))
|
||||||
|
for _, match := range matches {
|
||||||
|
newSlice = reflect.Append(newSlice, match)
|
||||||
|
}
|
||||||
|
relationField.Set(newSlice)
|
||||||
|
associatedCount += len(matches)
|
||||||
|
logger.Debug("Set has-many field '%s' with %d records for parent %d", fieldName, len(matches), i)
|
||||||
|
} else {
|
||||||
|
// For has-one/belongs-to: only set if not already set (avoid duplicates)
|
||||||
|
if relationField.IsNil() {
|
||||||
|
relationField.Set(matches[0])
|
||||||
|
associatedCount++
|
||||||
|
logger.Debug("Set has-one field '%s' for parent %d", fieldName, i)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Skipping has-one field '%s' for parent %d (already set)", fieldName, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Associated %d related records to %d parents for field '%s'", associatedCount, parents.Len(), fieldName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findFieldByName finds a struct field by name, trying multiple variations
|
||||||
|
func findFieldByName(v reflect.Value, name string) reflect.Value {
|
||||||
|
// Try exact name
|
||||||
|
field := v.FieldByName(name)
|
||||||
|
if field.IsValid() {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try with capital first letter
|
||||||
|
if len(name) > 0 {
|
||||||
|
capital := strings.ToUpper(name[0:1]) + name[1:]
|
||||||
|
field = v.FieldByName(capital)
|
||||||
|
if field.IsValid() {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try searching by json or bun tag
|
||||||
|
t := v.Type()
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
f := t.Field(i)
|
||||||
|
jsonTag := f.Tag.Get("json")
|
||||||
|
bunTag := f.Tag.Get("bun")
|
||||||
|
|
||||||
|
// Check json tag
|
||||||
|
if strings.HasPrefix(jsonTag, name+",") || jsonTag == name {
|
||||||
|
return v.Field(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check bun tag for column name
|
||||||
|
if strings.Contains(bunTag, name+",") || strings.Contains(bunTag, name+":") {
|
||||||
|
return v.Field(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.Value{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFieldValue extracts the value from a field, handling SqlNull types
|
||||||
|
func extractFieldValue(field reflect.Value) interface{} {
|
||||||
|
if !field.CanInterface() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value := field.Interface()
|
||||||
|
|
||||||
|
// Handle SqlNull types
|
||||||
|
if nullType, ok := value.(interface{ IsNull() bool }); ok {
|
||||||
|
if nullType.IsNull() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle types with Int64() method
|
||||||
|
if int64er, ok := value.(interface{ Int64() int64 }); ok {
|
||||||
|
return int64er.Int64()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle types with String() method for comparison
|
||||||
|
if stringer, ok := value.(interface{ String() string }); ok {
|
||||||
|
return stringer.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
// JoinRelation uses a LEFT JOIN instead of a separate query
|
// JoinRelation uses a LEFT JOIN instead of a separate query
|
||||||
// This is more efficient for many-to-one or one-to-one relationships
|
// This is more efficient for many-to-one or one-to-one relationships
|
||||||
@@ -683,6 +1106,10 @@ func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.Sele
|
|||||||
|
|
||||||
// Use PreloadRelation with the wrapped functions
|
// Use PreloadRelation with the wrapped functions
|
||||||
// Bun's Relation() will use JOIN for belongs-to and has-one relations
|
// Bun's Relation() will use JOIN for belongs-to and has-one relations
|
||||||
|
// CRITICAL: Set skipAutoDetect flag to prevent circular call
|
||||||
|
// (PreloadRelation would detect belongs-to and call JoinRelation again)
|
||||||
|
b.skipAutoDetect = true
|
||||||
|
defer func() { b.skipAutoDetect = false }()
|
||||||
return b.PreloadRelation(relation, wrappedApply...)
|
return b.PreloadRelation(relation, wrappedApply...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -726,7 +1153,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
|||||||
return fmt.Errorf("destination cannot be nil")
|
return fmt.Errorf("destination cannot be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the main query first
|
|
||||||
err = b.query.Scan(ctx, dest)
|
err = b.query.Scan(ctx, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
@@ -735,15 +1161,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute any deferred preloads
|
|
||||||
if len(b.deferredPreloads) > 0 {
|
|
||||||
err = b.executeDeferredPreloads(ctx, dest)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
|
||||||
// Don't fail the whole query, just log the warning
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -793,7 +1210,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the main query first
|
|
||||||
err = b.query.Scan(ctx)
|
err = b.query.Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
@@ -802,128 +1218,18 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute any deferred preloads
|
// After main query, load custom preloads using separate queries
|
||||||
if len(b.deferredPreloads) > 0 {
|
if len(b.customPreloads) > 0 {
|
||||||
model := b.query.GetModel()
|
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
|
||||||
err = b.executeDeferredPreloads(ctx, model.Value())
|
if err := b.loadCustomPreloads(ctx); err != nil {
|
||||||
if err != nil {
|
logger.Error("Failed to load custom preloads: %v", err)
|
||||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
return err
|
||||||
// Don't fail the whole query, just log the warning
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
|
|
||||||
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
|
|
||||||
if len(b.deferredPreloads) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dp := range b.deferredPreloads {
|
|
||||||
err := b.executeSingleDeferredPreload(ctx, dest, dp)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeSingleDeferredPreload executes a single deferred preload
|
|
||||||
// For a relation like "Parent.Child", it:
|
|
||||||
// 1. Finds all loaded Parent records in dest
|
|
||||||
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
|
|
||||||
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
|
|
||||||
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
|
|
||||||
relationParts := strings.Split(dp.relation, ".")
|
|
||||||
if len(relationParts) < 2 {
|
|
||||||
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The parent relation that was already loaded
|
|
||||||
parentRelation := relationParts[0]
|
|
||||||
// The child relation we need to load
|
|
||||||
childRelation := strings.Join(relationParts[1:], ".")
|
|
||||||
|
|
||||||
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
|
|
||||||
|
|
||||||
// Use reflection to access the parent relation field(s) in the loaded records
|
|
||||||
// Then load the child relation for those parent records
|
|
||||||
destValue := reflect.ValueOf(dest)
|
|
||||||
if destValue.Kind() == reflect.Ptr {
|
|
||||||
destValue = destValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle both slice and single record
|
|
||||||
if destValue.Kind() == reflect.Slice {
|
|
||||||
// Iterate through each record in the slice
|
|
||||||
for i := 0; i < destValue.Len(); i++ {
|
|
||||||
record := destValue.Index(i)
|
|
||||||
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
|
|
||||||
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
|
|
||||||
// Continue with other records
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Single record
|
|
||||||
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
|
|
||||||
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadChildRelationForRecord loads a child relation for a single parent record
|
|
||||||
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
|
|
||||||
// Ensure we're working with the actual struct value, not a pointer
|
|
||||||
if record.Kind() == reflect.Ptr {
|
|
||||||
record = record.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the parent relation field
|
|
||||||
parentField := record.FieldByName(parentRelation)
|
|
||||||
if !parentField.IsValid() {
|
|
||||||
// Parent relation field doesn't exist
|
|
||||||
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the parent field is nil (for pointer fields)
|
|
||||||
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
|
|
||||||
// Parent relation not loaded or nil, skip
|
|
||||||
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the interface value to pass to Bun
|
|
||||||
parentValue := parentField.Interface()
|
|
||||||
|
|
||||||
// Load the child relation on the parent record
|
|
||||||
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
|
|
||||||
return b.db.NewSelect().
|
|
||||||
Model(parentValue).
|
|
||||||
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
|
||||||
// Apply any custom query modifications
|
|
||||||
if len(apply) > 0 {
|
|
||||||
wrapper := &BunSelectQuery{query: sq, db: b.db}
|
|
||||||
current := common.SelectQuery(wrapper)
|
|
||||||
for _, fn := range apply {
|
|
||||||
if fn != nil {
|
|
||||||
current = fn(current)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if finalBun, ok := current.(*BunSelectQuery); ok {
|
|
||||||
return finalBun.query
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sq
|
|
||||||
}).
|
|
||||||
Scan(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
|
|||||||
@@ -1,9 +1,63 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/uptrace/bun/dialect/mssqldialect"
|
||||||
|
"github.com/uptrace/bun/dialect/pgdialect"
|
||||||
|
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/driver/sqlserver"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PostgreSQL identifier length limit (63 bytes + null terminator = 64 bytes total)
|
||||||
|
const postgresIdentifierLimit = 63
|
||||||
|
|
||||||
|
// checkAliasLength checks if a preload relation path will generate aliases that exceed PostgreSQL's limit
|
||||||
|
// Returns true if the alias is likely to be truncated
|
||||||
|
func checkAliasLength(relation string) bool {
|
||||||
|
// Bun generates aliases like: parentalias__childalias__columnname
|
||||||
|
// For nested preloads, it uses the pattern: relation1__relation2__relation3__columnname
|
||||||
|
parts := strings.Split(relation, ".")
|
||||||
|
if len(parts) <= 1 {
|
||||||
|
return false // Single level relations are fine
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the actual alias prefix length that Bun will generate
|
||||||
|
// Bun uses double underscores (__) between each relation level
|
||||||
|
// and converts the relation names to lowercase with underscores
|
||||||
|
aliasPrefix := strings.ToLower(strings.Join(parts, "__"))
|
||||||
|
aliasPrefixLen := len(aliasPrefix)
|
||||||
|
|
||||||
|
// We need to add 2 more underscores for the column name separator plus column name length
|
||||||
|
// Column names in the error were things like "rid_mastertype_hubtype" (23 chars)
|
||||||
|
// To be safe, assume the longest column name could be around 35 chars
|
||||||
|
maxColumnNameLen := 35
|
||||||
|
estimatedMaxLen := aliasPrefixLen + 2 + maxColumnNameLen
|
||||||
|
|
||||||
|
// Check if this would exceed PostgreSQL's identifier limit
|
||||||
|
if estimatedMaxLen > postgresIdentifierLimit {
|
||||||
|
logger.Warn("Preload relation '%s' will generate aliases up to %d chars (prefix: %d + column: %d), exceeding PostgreSQL's %d char limit",
|
||||||
|
relation, estimatedMaxLen, aliasPrefixLen, maxColumnNameLen, postgresIdentifierLimit)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also check if just the prefix is getting close (within 15 chars of limit)
|
||||||
|
// This gives room for column names
|
||||||
|
if aliasPrefixLen > (postgresIdentifierLimit - 15) {
|
||||||
|
logger.Warn("Preload relation '%s' has alias prefix of %d chars, which may cause truncation with longer column names (limit: %d)",
|
||||||
|
relation, aliasPrefixLen, postgresIdentifierLimit)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
// For example: "public.users" -> ("public", "users")
|
// For example: "public.users" -> ("public", "users")
|
||||||
//
|
//
|
||||||
@@ -14,3 +68,39 @@ func parseTableName(fullTableName string) (schema, table string) {
|
|||||||
}
|
}
|
||||||
return "", fullTableName
|
return "", fullTableName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPostgresDialect returns a Bun PostgreSQL dialect
|
||||||
|
func GetPostgresDialect() *pgdialect.Dialect {
|
||||||
|
return pgdialect.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSQLiteDialect returns a Bun SQLite dialect
|
||||||
|
func GetSQLiteDialect() *sqlitedialect.Dialect {
|
||||||
|
return sqlitedialect.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMSSQLDialect returns a Bun MSSQL dialect
|
||||||
|
func GetMSSQLDialect() *mssqldialect.Dialect {
|
||||||
|
return mssqldialect.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPostgresDialector returns a GORM PostgreSQL dialector
|
||||||
|
func GetPostgresDialector(db *sql.DB) gorm.Dialector {
|
||||||
|
return postgres.New(postgres.Config{
|
||||||
|
Conn: db,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSQLiteDialector returns a GORM SQLite dialector
|
||||||
|
func GetSQLiteDialector(db *sql.DB) gorm.Dialector {
|
||||||
|
return sqlite.Dialector{
|
||||||
|
Conn: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMSSQLDialector returns a GORM MSSQL dialector
|
||||||
|
func GetMSSQLDialector(db *sql.DB) gorm.Dialector {
|
||||||
|
return sqlserver.New(sqlserver.Config{
|
||||||
|
Conn: db,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package common
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CORSConfig holds CORS configuration
|
// CORSConfig holds CORS configuration
|
||||||
@@ -15,8 +17,30 @@ type CORSConfig struct {
|
|||||||
|
|
||||||
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
|
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
|
||||||
func DefaultCORSConfig() CORSConfig {
|
func DefaultCORSConfig() CORSConfig {
|
||||||
|
configManager := config.GetConfigManager()
|
||||||
|
cfg, _ := configManager.GetConfig()
|
||||||
|
hosts := make([]string, 0)
|
||||||
|
// hosts = append(hosts, "*")
|
||||||
|
|
||||||
|
_, _, ipsList := config.GetIPs()
|
||||||
|
|
||||||
|
for i := range cfg.Servers.Instances {
|
||||||
|
server := cfg.Servers.Instances[i]
|
||||||
|
if server.Port == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hosts = append(hosts, server.ExternalURLs...)
|
||||||
|
hosts = append(hosts, fmt.Sprintf("http://%s:%d", server.Host, server.Port))
|
||||||
|
hosts = append(hosts, fmt.Sprintf("https://%s:%d", server.Host, server.Port))
|
||||||
|
hosts = append(hosts, fmt.Sprintf("http://%s:%d", "localhost", server.Port))
|
||||||
|
for _, ip := range ipsList {
|
||||||
|
hosts = append(hosts, fmt.Sprintf("http://%s:%d", ip.String(), server.Port))
|
||||||
|
hosts = append(hosts, fmt.Sprintf("https://%s:%d", ip.String(), server.Port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return CORSConfig{
|
return CORSConfig{
|
||||||
AllowedOrigins: []string{"*"},
|
AllowedOrigins: hosts,
|
||||||
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||||
AllowedHeaders: GetHeadSpecHeaders(),
|
AllowedHeaders: GetHeadSpecHeaders(),
|
||||||
MaxAge: 86400, // 24 hours
|
MaxAge: 86400, // 24 hours
|
||||||
@@ -90,11 +114,14 @@ func GetHeadSpecHeaders() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetCORSHeaders sets CORS headers on a response writer
|
// SetCORSHeaders sets CORS headers on a response writer
|
||||||
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
func SetCORSHeaders(w ResponseWriter, r Request, config CORSConfig) {
|
||||||
// Set allowed origins
|
// Set allowed origins
|
||||||
if len(config.AllowedOrigins) > 0 {
|
// if len(config.AllowedOrigins) > 0 {
|
||||||
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
// w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
// Todo origin list parsing
|
||||||
|
w.SetHeader("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
// Set allowed methods
|
// Set allowed methods
|
||||||
if len(config.AllowedMethods) > 0 {
|
if len(config.AllowedMethods) > 0 {
|
||||||
@@ -102,9 +129,10 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set allowed headers
|
// Set allowed headers
|
||||||
if len(config.AllowedHeaders) > 0 {
|
// if len(config.AllowedHeaders) > 0 {
|
||||||
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
// w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||||
}
|
// }
|
||||||
|
w.SetHeader("Access-Control-Allow-Headers", "*")
|
||||||
|
|
||||||
// Set max age
|
// Set max age
|
||||||
if config.MaxAge > 0 {
|
if config.MaxAge > 0 {
|
||||||
@@ -115,5 +143,7 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
|||||||
w.SetHeader("Access-Control-Allow-Credentials", "true")
|
w.SetHeader("Access-Control-Allow-Credentials", "true")
|
||||||
|
|
||||||
// Expose headers that clients can read
|
// Expose headers that clients can read
|
||||||
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
|
exposeHeaders := config.AllowedHeaders
|
||||||
|
exposeHeaders = append(exposeHeaders, "Content-Range", "X-Api-Range-Total", "X-Api-Range-Size")
|
||||||
|
w.SetHeader("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ", "))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package common
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateAndUnwrapModelResult contains the result of model validation
|
// ValidateAndUnwrapModelResult contains the result of model validation
|
||||||
@@ -45,3 +48,216 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e
|
|||||||
OriginalType: originalType,
|
OriginalType: originalType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractTagValue extracts the value for a given key from a struct tag string.
|
||||||
|
// It handles both semicolon and comma-separated tag formats (e.g., GORM and BUN tags).
|
||||||
|
// For tags like "json:name;validate:required" it will extract "name" for key "json".
|
||||||
|
// For tags like "rel:has-many,join:table" it will extract "table" for key "join".
|
||||||
|
func ExtractTagValue(tag, key string) string {
|
||||||
|
// Split by both semicolons and commas to handle different tag formats
|
||||||
|
// We need to be smart about this - commas can be part of values
|
||||||
|
// So we'll try semicolon first, then comma if needed
|
||||||
|
separators := []string{";", ","}
|
||||||
|
|
||||||
|
for _, sep := range separators {
|
||||||
|
parts := strings.Split(tag, sep)
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, key+":") {
|
||||||
|
return strings.TrimPrefix(part, key+":")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationshipInfo analyzes a model type and extracts relationship metadata
|
||||||
|
// for a specific relation field identified by its JSON name.
|
||||||
|
// Returns nil if the field is not found or is not a valid relationship.
|
||||||
|
func GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo {
|
||||||
|
// Ensure we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
jsonName := strings.Split(jsonTag, ",")[0]
|
||||||
|
|
||||||
|
if jsonName == relationName {
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
info := &RelationshipInfo{
|
||||||
|
FieldName: field.Name,
|
||||||
|
JSONName: jsonName,
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") {
|
||||||
|
//bun:"rel:has-many,join:rid_hub=rid_hub_division"
|
||||||
|
if strings.Contains(bunTag, "has-many") {
|
||||||
|
info.RelationType = "hasMany"
|
||||||
|
} else if strings.Contains(bunTag, "has-one") {
|
||||||
|
info.RelationType = "hasOne"
|
||||||
|
} else if strings.Contains(bunTag, "belongs-to") {
|
||||||
|
info.RelationType = "belongsTo"
|
||||||
|
} else if strings.Contains(bunTag, "many-to-many") {
|
||||||
|
info.RelationType = "many2many"
|
||||||
|
} else {
|
||||||
|
info.RelationType = "hasOne"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract join info
|
||||||
|
joinPart := ExtractTagValue(bunTag, "join")
|
||||||
|
if joinPart != "" && info.RelationType == "many2many" {
|
||||||
|
// For many2many, the join part is the join table name
|
||||||
|
info.JoinTable = joinPart
|
||||||
|
} else if joinPart != "" {
|
||||||
|
// For other relations, parse foreignKey and references
|
||||||
|
joinParts := strings.Split(joinPart, "=")
|
||||||
|
if len(joinParts) == 2 {
|
||||||
|
info.ForeignKey = joinParts[0]
|
||||||
|
info.References = joinParts[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get related model type
|
||||||
|
if field.Type.Kind() == reflect.Slice {
|
||||||
|
elemType := field.Type.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
|
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||||
|
elemType := field.Type
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse GORM tag to determine relationship type and keys
|
||||||
|
if strings.Contains(gormTag, "foreignKey") {
|
||||||
|
info.ForeignKey = ExtractTagValue(gormTag, "foreignKey")
|
||||||
|
info.References = ExtractTagValue(gormTag, "references")
|
||||||
|
|
||||||
|
// Determine if it's belongsTo or hasMany/hasOne
|
||||||
|
if field.Type.Kind() == reflect.Slice {
|
||||||
|
info.RelationType = "hasMany"
|
||||||
|
// Get the element type for slice
|
||||||
|
elemType := field.Type.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
|
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||||
|
info.RelationType = "belongsTo"
|
||||||
|
elemType := field.Type
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if strings.Contains(gormTag, "many2many") {
|
||||||
|
info.RelationType = "many2many"
|
||||||
|
info.JoinTable = ExtractTagValue(gormTag, "many2many")
|
||||||
|
// Get the element type for many2many (always slice)
|
||||||
|
if field.Type.Kind() == reflect.Slice {
|
||||||
|
elemType := field.Type.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Field has no GORM relationship tags, so it's not a relation
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelationPathToBunAlias converts a relation path (e.g., "Order.Customer") to a Bun alias format.
|
||||||
|
// It converts to lowercase and replaces dots with double underscores.
|
||||||
|
// For example: "Order.Customer" -> "order__customer"
|
||||||
|
func RelationPathToBunAlias(relationPath string) string {
|
||||||
|
if relationPath == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Convert to lowercase and replace dots with double underscores
|
||||||
|
alias := strings.ToLower(relationPath)
|
||||||
|
alias = strings.ReplaceAll(alias, ".", "__")
|
||||||
|
return alias
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceTableReferencesInSQL replaces references to a base table name in a SQL expression
|
||||||
|
// with the appropriate alias for the current preload level.
|
||||||
|
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
|
||||||
|
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
|
||||||
|
func ReplaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
|
||||||
|
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
|
||||||
|
return sqlExpr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace both quoted and unquoted table references
|
||||||
|
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
|
||||||
|
|
||||||
|
// Pattern 1: tablename.column (unquoted)
|
||||||
|
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
|
||||||
|
|
||||||
|
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
|
||||||
|
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTableNameFromModel extracts the table name from a model.
|
||||||
|
// It checks the bun tag first, then falls back to converting the struct name to snake_case.
|
||||||
|
func GetTableNameFromModel(model interface{}) string {
|
||||||
|
if model == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers
|
||||||
|
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for bun tag on embedded BaseModel
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if field.Anonymous {
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.HasPrefix(bunTag, "table:") {
|
||||||
|
return strings.TrimPrefix(bunTag, "table:")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: convert struct name to lowercase (simple heuristic)
|
||||||
|
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||||
|
return strings.ToLower(modelType.Name())
|
||||||
|
}
|
||||||
|
|||||||
108
pkg/common/handler_utils_test.go
Normal file
108
pkg/common/handler_utils_test.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractTagValue(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tag string
|
||||||
|
key string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Extract existing key",
|
||||||
|
tag: "json:name;validate:required",
|
||||||
|
key: "json",
|
||||||
|
expected: "name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract key with spaces",
|
||||||
|
tag: "json:name ; validate:required",
|
||||||
|
key: "validate",
|
||||||
|
expected: "required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract key at end",
|
||||||
|
tag: "json:name;validate:required;db:column_name",
|
||||||
|
key: "db",
|
||||||
|
expected: "column_name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract key at beginning",
|
||||||
|
tag: "primary:true;json:id;db:user_id",
|
||||||
|
key: "primary",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Key not found",
|
||||||
|
tag: "json:name;validate:required",
|
||||||
|
key: "db",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty tag",
|
||||||
|
tag: "",
|
||||||
|
key: "json",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single key-value pair",
|
||||||
|
tag: "json:name",
|
||||||
|
key: "json",
|
||||||
|
expected: "name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Key with empty value",
|
||||||
|
tag: "json:;validate:required",
|
||||||
|
key: "json",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Key with complex value",
|
||||||
|
tag: "json:user_name,omitempty;validate:required,min=3",
|
||||||
|
key: "json",
|
||||||
|
expected: "user_name,omitempty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple semicolons",
|
||||||
|
tag: "json:name;;validate:required",
|
||||||
|
key: "validate",
|
||||||
|
expected: "required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "BUN Tag with comma separator",
|
||||||
|
tag: "rel:has-many,join:rid_hub=rid_hub_child",
|
||||||
|
key: "join",
|
||||||
|
expected: "rid_hub=rid_hub_child",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract foreignKey",
|
||||||
|
tag: "foreignKey:UserID;references:ID",
|
||||||
|
key: "foreignKey",
|
||||||
|
expected: "UserID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract references",
|
||||||
|
tag: "foreignKey:UserID;references:ID",
|
||||||
|
key: "references",
|
||||||
|
expected: "ID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract many2many",
|
||||||
|
tag: "many2many:user_roles",
|
||||||
|
key: "many2many",
|
||||||
|
expected: "user_roles",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ExtractTagValue(tt.tag, tt.key)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -20,17 +20,6 @@ type RelationshipInfoProvider interface {
|
|||||||
GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo
|
GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// RelationshipInfo contains information about a model relationship
|
|
||||||
type RelationshipInfo struct {
|
|
||||||
FieldName string
|
|
||||||
JSONName string
|
|
||||||
RelationType string // "belongsTo", "hasMany", "hasOne", "many2many"
|
|
||||||
ForeignKey string
|
|
||||||
References string
|
|
||||||
JoinTable string
|
|
||||||
RelatedModel interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NestedCUDProcessor handles recursive processing of nested object graphs
|
// NestedCUDProcessor handles recursive processing of nested object graphs
|
||||||
type NestedCUDProcessor struct {
|
type NestedCUDProcessor struct {
|
||||||
db Database
|
db Database
|
||||||
@@ -85,6 +74,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Invalid model type: operation=%s, table=%s, modelType=%v, expected struct", operation, tableName, modelType)
|
||||||
return nil, fmt.Errorf("model must be a struct type, got %v", modelType)
|
return nil, fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,17 +98,27 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Filter regularData to only include fields that exist in the model
|
||||||
|
// Use MapToStruct to validate and filter fields
|
||||||
|
regularData = p.filterValidFields(regularData, model)
|
||||||
|
|
||||||
// Inject parent IDs for foreign key resolution
|
// Inject parent IDs for foreign key resolution
|
||||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||||
|
|
||||||
// Get the primary key name for this model
|
// Get the primary key name for this model
|
||||||
pkName := reflection.GetPrimaryKeyName(model)
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
// Check if we have any data to process (besides _request)
|
||||||
|
hasData := len(regularData) > 0
|
||||||
|
|
||||||
// Process based on operation
|
// Process based on operation
|
||||||
switch strings.ToLower(operation) {
|
switch strings.ToLower(operation) {
|
||||||
case "insert", "create":
|
case "insert", "create":
|
||||||
|
// Only perform insert if we have data to insert
|
||||||
|
if hasData {
|
||||||
id, err := p.processInsert(ctx, regularData, tableName)
|
id, err := p.processInsert(ctx, regularData, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Insert failed for table=%s, data=%+v, error=%v", tableName, regularData, err)
|
||||||
return nil, fmt.Errorf("insert failed: %w", err)
|
return nil, fmt.Errorf("insert failed: %w", err)
|
||||||
}
|
}
|
||||||
result.ID = id
|
result.ID = id
|
||||||
@@ -126,13 +126,20 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
// Process child relations after parent insert (to get parent ID)
|
// Process child relations after parent insert (to get parent ID)
|
||||||
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil {
|
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||||
|
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
|
||||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
||||||
|
}
|
||||||
|
|
||||||
case "update":
|
case "update":
|
||||||
|
// Only perform update if we have data to update
|
||||||
|
if hasData {
|
||||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Update failed for table=%s, id=%v, data=%+v, error=%v", tableName, data[pkName], regularData, err)
|
||||||
return nil, fmt.Errorf("update failed: %w", err)
|
return nil, fmt.Errorf("update failed: %w", err)
|
||||||
}
|
}
|
||||||
result.ID = data[pkName]
|
result.ID = data[pkName]
|
||||||
@@ -140,18 +147,25 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
// Process child relations for update
|
// Process child relations for update
|
||||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||||
|
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
||||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("Skipping update for %s - no data columns besides _request", tableName)
|
||||||
|
result.ID = data[pkName]
|
||||||
|
}
|
||||||
|
|
||||||
case "delete":
|
case "delete":
|
||||||
// Process child relations first (for referential integrity)
|
// Process child relations first (for referential integrity)
|
||||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||||
|
logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
||||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Delete failed for table=%s, id=%v, error=%v", tableName, data[pkName], err)
|
||||||
return nil, fmt.Errorf("delete failed: %w", err)
|
return nil, fmt.Errorf("delete failed: %w", err)
|
||||||
}
|
}
|
||||||
result.ID = data[pkName]
|
result.ID = data[pkName]
|
||||||
@@ -159,6 +173,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
logger.Error("Unsupported operation: %s for table=%s", operation, tableName)
|
||||||
return nil, fmt.Errorf("unsupported operation: %s", operation)
|
return nil, fmt.Errorf("unsupported operation: %s", operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,6 +191,115 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterValidFields filters input data to only include fields that exist in the model
|
||||||
|
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
|
||||||
|
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new instance of the model to use with MapToStruct
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new instance of the model
|
||||||
|
tempModel := reflect.New(modelType).Interface()
|
||||||
|
|
||||||
|
// Use MapToStruct to map the data - this will only map valid fields
|
||||||
|
err := reflection.MapToStruct(data, tempModel)
|
||||||
|
if err != nil {
|
||||||
|
logger.Debug("Error mapping data to model: %v", err)
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the mapped fields back into a map
|
||||||
|
// This effectively filters out any fields that don't exist in the model
|
||||||
|
filteredData := make(map[string]interface{})
|
||||||
|
tempModelValue := reflect.ValueOf(tempModel).Elem()
|
||||||
|
|
||||||
|
for key, value := range data {
|
||||||
|
// Check if the field was successfully mapped
|
||||||
|
if fieldWasMapped(tempModelValue, modelType, key) {
|
||||||
|
filteredData[key] = value
|
||||||
|
} else {
|
||||||
|
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredData
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldWasMapped checks if a field with the given key was mapped to the model
|
||||||
|
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
|
||||||
|
// Look for the field by JSON tag or field name
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Skip unexported fields
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check JSON tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] == key {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check bun tag
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && bunTag != "-" {
|
||||||
|
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check gorm tag
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if gormTag != "" && gormTag != "-" {
|
||||||
|
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check lowercase field name
|
||||||
|
if strings.EqualFold(field.Name, key) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle embedded structs recursively
|
||||||
|
if field.Anonymous {
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
embeddedValue := modelValue.Field(i)
|
||||||
|
if embeddedValue.Kind() == reflect.Ptr {
|
||||||
|
if embeddedValue.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
embeddedValue = embeddedValue.Elem()
|
||||||
|
}
|
||||||
|
if fieldWasMapped(embeddedValue, fieldType, key) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// injectForeignKeys injects parent IDs into data for foreign key fields
|
// injectForeignKeys injects parent IDs into data for foreign key fields
|
||||||
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
||||||
if len(parentIDs) == 0 {
|
if len(parentIDs) == 0 {
|
||||||
@@ -218,12 +342,13 @@ func (p *NestedCUDProcessor) processInsert(
|
|||||||
for key, value := range data {
|
for key, value := range data {
|
||||||
query = query.Value(key, value)
|
query = query.Value(key, value)
|
||||||
}
|
}
|
||||||
|
pkName := reflection.GetPrimaryKeyName(tableName)
|
||||||
// Add RETURNING clause to get the inserted ID
|
// Add RETURNING clause to get the inserted ID
|
||||||
query = query.Returning("id")
|
query = query.Returning(pkName)
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err)
|
||||||
return nil, fmt.Errorf("insert exec failed: %w", err)
|
return nil, fmt.Errorf("insert exec failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,8 +356,8 @@ func (p *NestedCUDProcessor) processInsert(
|
|||||||
var id interface{}
|
var id interface{}
|
||||||
if lastID, err := result.LastInsertId(); err == nil && lastID > 0 {
|
if lastID, err := result.LastInsertId(); err == nil && lastID > 0 {
|
||||||
id = lastID
|
id = lastID
|
||||||
} else if data["id"] != nil {
|
} else if data[pkName] != nil {
|
||||||
id = data["id"]
|
id = data[pkName]
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Insert successful, ID: %v, rows affected: %d", id, result.RowsAffected())
|
logger.Debug("Insert successful, ID: %v, rows affected: %d", id, result.RowsAffected())
|
||||||
@@ -247,6 +372,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
|||||||
id interface{},
|
id interface{},
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
if id == nil {
|
if id == nil {
|
||||||
|
logger.Error("Update requires an ID: table=%s, data=%+v", tableName, data)
|
||||||
return 0, fmt.Errorf("update requires an ID")
|
return 0, fmt.Errorf("update requires an ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,6 +382,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
|||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Update execution failed: table=%s, id=%v, data=%+v, error=%v", tableName, id, data, err)
|
||||||
return 0, fmt.Errorf("update exec failed: %w", err)
|
return 0, fmt.Errorf("update exec failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,6 +394,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
|||||||
// processDelete handles delete operation
|
// processDelete handles delete operation
|
||||||
func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) {
|
func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) {
|
||||||
if id == nil {
|
if id == nil {
|
||||||
|
logger.Error("Delete requires an ID: table=%s", tableName)
|
||||||
return 0, fmt.Errorf("delete requires an ID")
|
return 0, fmt.Errorf("delete requires an ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,6 +404,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string
|
|||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Delete execution failed: table=%s, id=%v, error=%v", tableName, id, err)
|
||||||
return 0, fmt.Errorf("delete exec failed: %w", err)
|
return 0, fmt.Errorf("delete exec failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,6 +421,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
|||||||
relationFields map[string]*RelationshipInfo,
|
relationFields map[string]*RelationshipInfo,
|
||||||
relationData map[string]interface{},
|
relationData map[string]interface{},
|
||||||
parentModelType reflect.Type,
|
parentModelType reflect.Type,
|
||||||
|
incomingParentIDs map[string]interface{}, // IDs from all ancestors
|
||||||
) error {
|
) error {
|
||||||
for relationName, relInfo := range relationFields {
|
for relationName, relInfo := range relationFields {
|
||||||
relationValue, exists := relationData[relationName]
|
relationValue, exists := relationData[relationName]
|
||||||
@@ -304,7 +434,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
|||||||
// Get the related model
|
// Get the related model
|
||||||
field, found := parentModelType.FieldByName(relInfo.FieldName)
|
field, found := parentModelType.FieldByName(relInfo.FieldName)
|
||||||
if !found {
|
if !found {
|
||||||
logger.Warn("Field %s not found in model", relInfo.FieldName)
|
logger.Error("Field %s not found in model type %v for relation %s", relInfo.FieldName, parentModelType, relationName)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,20 +454,89 @@ func (p *NestedCUDProcessor) processChildRelations(
|
|||||||
relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName)
|
relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName)
|
||||||
|
|
||||||
// Prepare parent IDs for foreign key injection
|
// Prepare parent IDs for foreign key injection
|
||||||
|
// Start by copying all incoming parent IDs (from ancestors)
|
||||||
parentIDs := make(map[string]interface{})
|
parentIDs := make(map[string]interface{})
|
||||||
if relInfo.ForeignKey != "" {
|
for k, v := range incomingParentIDs {
|
||||||
|
parentIDs[k] = v
|
||||||
|
}
|
||||||
|
logger.Debug("Inherited %d parent IDs from ancestors: %+v", len(incomingParentIDs), incomingParentIDs)
|
||||||
|
|
||||||
|
// Add the current parent's primary key to the parentIDs map
|
||||||
|
// This ensures nested children have access to all ancestor IDs
|
||||||
|
if parentID != nil && parentModelType != nil {
|
||||||
|
// Get the parent model's primary key field name
|
||||||
|
parentPKFieldName := reflection.GetPrimaryKeyName(parentModelType)
|
||||||
|
if parentPKFieldName != "" {
|
||||||
|
// Get the JSON name for the primary key field
|
||||||
|
parentPKJSONName := reflection.GetJSONNameForField(parentModelType, parentPKFieldName)
|
||||||
|
baseName := ""
|
||||||
|
if len(parentPKJSONName) > 1 {
|
||||||
|
baseName = parentPKJSONName
|
||||||
|
} else {
|
||||||
|
// Add parent's PK to the map using the base model name
|
||||||
|
baseName = strings.TrimSuffix(parentPKFieldName, "ID")
|
||||||
|
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||||
|
if baseName == "" {
|
||||||
|
baseName = "parent"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parentIDs[baseName] = parentID
|
||||||
|
logger.Debug("Added current parent PK to parentIDs map: %s=%v (from field %s)", baseName, parentID, parentPKFieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also add the foreign key reference if specified
|
||||||
|
if relInfo.ForeignKey != "" && parentID != nil {
|
||||||
// Extract the base name from foreign key (e.g., "DepartmentID" -> "Department")
|
// Extract the base name from foreign key (e.g., "DepartmentID" -> "Department")
|
||||||
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
|
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
|
||||||
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||||
|
// Only add if different from what we already added
|
||||||
|
if _, exists := parentIDs[baseName]; !exists {
|
||||||
parentIDs[baseName] = parentID
|
parentIDs[baseName] = parentID
|
||||||
|
logger.Debug("Added foreign key to parentIDs map: %s=%v (from FK %s)", baseName, parentID, relInfo.ForeignKey)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Final parentIDs map for relation %s: %+v", relationName, parentIDs)
|
||||||
|
|
||||||
|
// Determine which field name to use for setting parent ID in child data
|
||||||
|
// Priority: Use foreign key field name if specified
|
||||||
|
var foreignKeyFieldName string
|
||||||
|
if relInfo.ForeignKey != "" {
|
||||||
|
// Get the JSON name for the foreign key field in the child model
|
||||||
|
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||||
|
if foreignKeyFieldName == "" {
|
||||||
|
// Fallback to lowercase field name
|
||||||
|
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||||
|
}
|
||||||
|
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||||
|
childPKName := reflection.GetPrimaryKeyName(relatedModel)
|
||||||
|
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
|
||||||
|
if childPKFieldName == "" {
|
||||||
|
childPKFieldName = strings.ToLower(childPKName)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Processing relation with foreignKeyField=%s, childPK=%s", foreignKeyFieldName, childPKFieldName)
|
||||||
|
|
||||||
// Process based on relation type and data structure
|
// Process based on relation type and data structure
|
||||||
switch v := relationValue.(type) {
|
switch v := relationValue.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
// Single related object
|
// Single related object - directly set foreign key if specified
|
||||||
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
|
v[foreignKeyFieldName] = parentID
|
||||||
|
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
|
||||||
|
} else if foreignKeyFieldName == childPKFieldName {
|
||||||
|
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
|
||||||
|
}
|
||||||
_, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
_, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to process single relation: name=%s, table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||||
|
relationName, relatedTableName, operation, parentID, v, err)
|
||||||
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
|
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,24 +544,46 @@ func (p *NestedCUDProcessor) processChildRelations(
|
|||||||
// Multiple related objects
|
// Multiple related objects
|
||||||
for i, item := range v {
|
for i, item := range v {
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
// Directly set foreign key if specified
|
||||||
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
|
itemMap[foreignKeyFieldName] = parentID
|
||||||
|
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||||
|
} else if foreignKeyFieldName == childPKFieldName {
|
||||||
|
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||||
|
}
|
||||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to process relation array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||||
|
relationName, i, relatedTableName, operation, parentID, itemMap, err)
|
||||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
logger.Warn("Relation array item is not a map: name=%s[%d], type=%T", relationName, i, item)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
// Multiple related objects (typed slice)
|
// Multiple related objects (typed slice)
|
||||||
for i, itemMap := range v {
|
for i, itemMap := range v {
|
||||||
|
// Directly set foreign key if specified
|
||||||
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
|
itemMap[foreignKeyFieldName] = parentID
|
||||||
|
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||||
|
} else if foreignKeyFieldName == childPKFieldName {
|
||||||
|
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||||
|
}
|
||||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to process relation typed array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||||
|
relationName, i, relatedTableName, operation, parentID, itemMap, err)
|
||||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue)
|
logger.Error("Unsupported relation data type: name=%s, type=%T, value=%+v", relationName, relationValue, relationValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
720
pkg/common/recursive_crud_test.go
Normal file
720
pkg/common/recursive_crud_test.go
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mock Database for testing
|
||||||
|
type mockDatabase struct {
|
||||||
|
insertCalls []map[string]interface{}
|
||||||
|
updateCalls []map[string]interface{}
|
||||||
|
deleteCalls []interface{}
|
||||||
|
lastID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockDatabase() *mockDatabase {
|
||||||
|
return &mockDatabase{
|
||||||
|
insertCalls: make([]map[string]interface{}, 0),
|
||||||
|
updateCalls: make([]map[string]interface{}, 0),
|
||||||
|
deleteCalls: make([]interface{}, 0),
|
||||||
|
lastID: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDatabase) NewSelect() SelectQuery { return &mockSelectQuery{} }
|
||||||
|
func (m *mockDatabase) NewInsert() InsertQuery { return &mockInsertQuery{db: m} }
|
||||||
|
func (m *mockDatabase) NewUpdate() UpdateQuery { return &mockUpdateQuery{db: m} }
|
||||||
|
func (m *mockDatabase) NewDelete() DeleteQuery { return &mockDeleteQuery{db: m} }
|
||||||
|
func (m *mockDatabase) RunInTransaction(ctx context.Context, fn func(Database) error) error {
|
||||||
|
return fn(m)
|
||||||
|
}
|
||||||
|
func (m *mockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) {
|
||||||
|
return &mockResult{rowsAffected: 1}, nil
|
||||||
|
}
|
||||||
|
func (m *mockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockDatabase) BeginTx(ctx context.Context) (Database, error) {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
func (m *mockDatabase) CommitTx(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockDatabase) RollbackTx(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockDatabase) GetUnderlyingDB() interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock SelectQuery
|
||||||
|
type mockSelectQuery struct{}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Model(model interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Table(name string) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Column(columns ...string) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Where(condition string, args ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Join(query string, args ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Order(order string) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Limit(n int) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Offset(n int) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Group(group string) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Having(condition string, args ...interface{}) SelectQuery { return m }
|
||||||
|
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { return nil }
|
||||||
|
func (m *mockSelectQuery) ScanModel(ctx context.Context) error { return nil }
|
||||||
|
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { return 0, nil }
|
||||||
|
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { return false, nil }
|
||||||
|
|
||||||
|
// Mock InsertQuery
|
||||||
|
type mockInsertQuery struct {
|
||||||
|
db *mockDatabase
|
||||||
|
table string
|
||||||
|
values map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockInsertQuery) Model(model interface{}) InsertQuery { return m }
|
||||||
|
func (m *mockInsertQuery) Table(name string) InsertQuery {
|
||||||
|
m.table = name
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
func (m *mockInsertQuery) Value(column string, value interface{}) InsertQuery {
|
||||||
|
if m.values == nil {
|
||||||
|
m.values = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
m.values[column] = value
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m }
|
||||||
|
func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m }
|
||||||
|
func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) {
|
||||||
|
// Record the insert call
|
||||||
|
m.db.insertCalls = append(m.db.insertCalls, m.values)
|
||||||
|
m.db.lastID++
|
||||||
|
return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock UpdateQuery
|
||||||
|
type mockUpdateQuery struct {
|
||||||
|
db *mockDatabase
|
||||||
|
table string
|
||||||
|
setValues map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUpdateQuery) Model(model interface{}) UpdateQuery { return m }
|
||||||
|
func (m *mockUpdateQuery) Table(name string) UpdateQuery {
|
||||||
|
m.table = name
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
func (m *mockUpdateQuery) Set(column string, value interface{}) UpdateQuery { return m }
|
||||||
|
func (m *mockUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery {
|
||||||
|
m.setValues = values
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
func (m *mockUpdateQuery) Where(condition string, args ...interface{}) UpdateQuery { return m }
|
||||||
|
func (m *mockUpdateQuery) Returning(columns ...string) UpdateQuery { return m }
|
||||||
|
func (m *mockUpdateQuery) Exec(ctx context.Context) (Result, error) {
|
||||||
|
// Record the update call
|
||||||
|
m.db.updateCalls = append(m.db.updateCalls, m.setValues)
|
||||||
|
return &mockResult{rowsAffected: 1}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock DeleteQuery
|
||||||
|
type mockDeleteQuery struct {
|
||||||
|
db *mockDatabase
|
||||||
|
table string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDeleteQuery) Model(model interface{}) DeleteQuery { return m }
|
||||||
|
func (m *mockDeleteQuery) Table(name string) DeleteQuery {
|
||||||
|
m.table = name
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
func (m *mockDeleteQuery) Where(condition string, args ...interface{}) DeleteQuery { return m }
|
||||||
|
func (m *mockDeleteQuery) Exec(ctx context.Context) (Result, error) {
|
||||||
|
// Record the delete call
|
||||||
|
m.db.deleteCalls = append(m.db.deleteCalls, m.table)
|
||||||
|
return &mockResult{rowsAffected: 1}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock Result
|
||||||
|
type mockResult struct {
|
||||||
|
lastID int64
|
||||||
|
rowsAffected int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResult) LastInsertId() (int64, error) { return m.lastID, nil }
|
||||||
|
func (m *mockResult) RowsAffected() int64 { return m.rowsAffected }
|
||||||
|
|
||||||
|
// Mock ModelRegistry
|
||||||
|
type mockModelRegistry struct{}
|
||||||
|
|
||||||
|
func (m *mockModelRegistry) GetModel(name string) (interface{}, error) { return nil, nil }
|
||||||
|
func (m *mockModelRegistry) GetModelByEntity(schema, entity string) (interface{}, error) { return nil, nil }
|
||||||
|
func (m *mockModelRegistry) RegisterModel(name string, model interface{}) error { return nil }
|
||||||
|
func (m *mockModelRegistry) GetAllModels() map[string]interface{} { return make(map[string]interface{}) }
|
||||||
|
|
||||||
|
// Mock RelationshipInfoProvider
|
||||||
|
type mockRelationshipProvider struct {
|
||||||
|
relationships map[string]*RelationshipInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockRelationshipProvider() *mockRelationshipProvider {
|
||||||
|
return &mockRelationshipProvider{
|
||||||
|
relationships: make(map[string]*RelationshipInfo),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRelationshipProvider) GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo {
|
||||||
|
key := modelType.Name() + "." + relationName
|
||||||
|
return m.relationships[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRelationshipProvider) RegisterRelation(modelTypeName, relationName string, info *RelationshipInfo) {
|
||||||
|
key := modelTypeName + "." + relationName
|
||||||
|
m.relationships[key] = info
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Models
|
||||||
|
type Department struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Employees []*Employee `json:"employees,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Department) TableName() string { return "departments" }
|
||||||
|
func (d Department) GetIDName() string { return "ID" }
|
||||||
|
|
||||||
|
type Employee struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
DepartmentID int64 `json:"department_id"`
|
||||||
|
Tasks []*Task `json:"tasks,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Employee) TableName() string { return "employees" }
|
||||||
|
func (e Employee) GetIDName() string { return "ID" }
|
||||||
|
|
||||||
|
type Task struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
EmployeeID int64 `json:"employee_id"`
|
||||||
|
Comments []*Comment `json:"comments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Task) TableName() string { return "tasks" }
|
||||||
|
func (t Task) GetIDName() string { return "ID" }
|
||||||
|
|
||||||
|
type Comment struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
TaskID int64 `json:"task_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Comment) TableName() string { return "comments" }
|
||||||
|
func (c Comment) GetIDName() string { return "ID" }
|
||||||
|
|
||||||
|
// Test Cases
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_SingleLevelInsert(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
// Register Department -> Employees relationship
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "Engineering",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
map[string]interface{}{
|
||||||
|
"name": "Jane Smith",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"insert",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ID == nil {
|
||||||
|
t.Error("Expected result.ID to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify department was inserted
|
||||||
|
if len(db.insertCalls) != 3 {
|
||||||
|
t.Errorf("Expected 3 insert calls (1 dept + 2 employees), got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify first insert is department
|
||||||
|
if db.insertCalls[0]["name"] != "Engineering" {
|
||||||
|
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify employees were inserted with foreign key
|
||||||
|
if db.insertCalls[1]["department_id"] == nil {
|
||||||
|
t.Error("Expected employee to have department_id set")
|
||||||
|
}
|
||||||
|
if db.insertCalls[2]["department_id"] == nil {
|
||||||
|
t.Error("Expected employee to have department_id set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_MultiLevelInsert(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
// Register relationships
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
|
||||||
|
FieldName: "Tasks",
|
||||||
|
JSONName: "tasks",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "EmployeeID",
|
||||||
|
RelatedModel: Task{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "Engineering",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"name": "John Doe",
|
||||||
|
"tasks": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"title": "Task 1",
|
||||||
|
},
|
||||||
|
map[string]interface{}{
|
||||||
|
"title": "Task 2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"insert",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ID == nil {
|
||||||
|
t.Error("Expected result.ID to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify: 1 dept + 1 employee + 2 tasks = 4 inserts
|
||||||
|
if len(db.insertCalls) != 4 {
|
||||||
|
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify department
|
||||||
|
if db.insertCalls[0]["name"] != "Engineering" {
|
||||||
|
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify employee has department_id
|
||||||
|
if db.insertCalls[1]["department_id"] == nil {
|
||||||
|
t.Error("Expected employee to have department_id set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify tasks have employee_id
|
||||||
|
if db.insertCalls[2]["employee_id"] == nil {
|
||||||
|
t.Error("Expected task to have employee_id set")
|
||||||
|
}
|
||||||
|
if db.insertCalls[3]["employee_id"] == nil {
|
||||||
|
t.Error("Expected task to have employee_id set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_RequestFieldOverride(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "Engineering",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"_request": "update",
|
||||||
|
"ID": int64(10), // Use capital ID to match struct field
|
||||||
|
"name": "John Updated",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"insert",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify department was inserted (1 insert)
|
||||||
|
// Employee should be updated (1 update)
|
||||||
|
if len(db.insertCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 insert call for department, got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(db.updateCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 update call for employee, got %d", len(db.updateCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify update data
|
||||||
|
if db.updateCalls[0]["name"] != "John Updated" {
|
||||||
|
t.Errorf("Expected employee name 'John Updated', got %v", db.updateCalls[0]["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_SkipInsertWhenOnlyRequestField(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
// Data with only _request field for nested employee
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "Engineering",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"_request": "insert",
|
||||||
|
// No other fields besides _request
|
||||||
|
// Note: Foreign key will be injected, so employee WILL be inserted
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"insert",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Department + Employee (with injected FK) = 2 inserts
|
||||||
|
if len(db.insertCalls) != 2 {
|
||||||
|
t.Errorf("Expected 2 insert calls (department + employee with FK), got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.insertCalls[0]["name"] != "Engineering" {
|
||||||
|
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify employee has foreign key
|
||||||
|
if db.insertCalls[1]["department_id"] == nil {
|
||||||
|
t.Error("Expected employee to have department_id injected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_Update(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"ID": int64(1), // Use capital ID to match struct field
|
||||||
|
"name": "Engineering Updated",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"_request": "insert",
|
||||||
|
"name": "New Employee",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"update",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ID != int64(1) {
|
||||||
|
t.Errorf("Expected result.ID to be 1, got %v", result.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify department was updated
|
||||||
|
if len(db.updateCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 update call, got %d", len(db.updateCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify new employee was inserted
|
||||||
|
if len(db.insertCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 insert call for new employee, got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_Delete(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"ID": int64(1), // Use capital ID to match struct field
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"_request": "delete",
|
||||||
|
"ID": int64(10), // Use capital ID
|
||||||
|
},
|
||||||
|
map[string]interface{}{
|
||||||
|
"_request": "delete",
|
||||||
|
"ID": int64(11), // Use capital ID
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"delete",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify employees were deleted first, then department
|
||||||
|
// 2 employees + 1 department = 3 deletes
|
||||||
|
if len(db.deleteCalls) != 3 {
|
||||||
|
t.Errorf("Expected 3 delete calls, got %d", len(db.deleteCalls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_ParentIDPropagation(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
// Register 3-level relationships
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
|
||||||
|
FieldName: "Tasks",
|
||||||
|
JSONName: "tasks",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "EmployeeID",
|
||||||
|
RelatedModel: Task{},
|
||||||
|
})
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Task", "comments", &RelationshipInfo{
|
||||||
|
FieldName: "Comments",
|
||||||
|
JSONName: "comments",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "TaskID",
|
||||||
|
RelatedModel: Comment{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "Engineering",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"tasks": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"title": "Task 1",
|
||||||
|
"comments": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"text": "Great work!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"insert",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify: 1 dept + 1 employee + 1 task + 1 comment = 4 inserts
|
||||||
|
if len(db.insertCalls) != 4 {
|
||||||
|
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify department
|
||||||
|
if db.insertCalls[0]["name"] != "Engineering" {
|
||||||
|
t.Error("Expected department to be inserted first")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify employee has department_id
|
||||||
|
if db.insertCalls[1]["department_id"] == nil {
|
||||||
|
t.Error("Expected employee to have department_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify task has employee_id
|
||||||
|
if db.insertCalls[2]["employee_id"] == nil {
|
||||||
|
t.Error("Expected task to have employee_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify comment has task_id
|
||||||
|
if db.insertCalls[3]["task_id"] == nil {
|
||||||
|
t.Error("Expected comment to have task_id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInjectForeignKeys(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
}
|
||||||
|
|
||||||
|
parentIDs := map[string]interface{}{
|
||||||
|
"department": int64(5),
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(Employee{})
|
||||||
|
|
||||||
|
processor.injectForeignKeys(data, modelType, parentIDs)
|
||||||
|
|
||||||
|
// Should inject department_id based on the "department" key in parentIDs
|
||||||
|
if data["department_id"] == nil {
|
||||||
|
t.Error("Expected department_id to be injected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if data["department_id"] != int64(5) {
|
||||||
|
t.Errorf("Expected department_id to be 5, got %v", data["department_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyName(t *testing.T) {
|
||||||
|
dept := Department{}
|
||||||
|
pkName := reflection.GetPrimaryKeyName(dept)
|
||||||
|
|
||||||
|
if pkName != "ID" {
|
||||||
|
t.Errorf("Expected primary key name 'ID', got '%s'", pkName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with pointer
|
||||||
|
pkName2 := reflection.GetPrimaryKeyName(&dept)
|
||||||
|
if pkName2 != "ID" {
|
||||||
|
t.Errorf("Expected primary key name 'ID' from pointer, got '%s'", pkName2)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -130,6 +130,9 @@ func validateWhereClauseSecurity(where string) error {
|
|||||||
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
||||||
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
||||||
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
||||||
|
//
|
||||||
|
// IMPORTANT: Outer parentheses are preserved if the clause contains top-level OR operators
|
||||||
|
// to prevent OR logic from escaping and affecting the entire query incorrectly.
|
||||||
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
||||||
if where == "" {
|
if where == "" {
|
||||||
return ""
|
return ""
|
||||||
@@ -143,8 +146,19 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strip outer parentheses and re-trim
|
// Check if the original clause has outer parentheses and contains OR operators
|
||||||
where = stripOuterParentheses(where)
|
// If so, we need to preserve the outer parentheses to prevent OR logic from escaping
|
||||||
|
hasOuterParens := false
|
||||||
|
if len(where) > 0 && where[0] == '(' && where[len(where)-1] == ')' {
|
||||||
|
_, hasOuterParens = stripOneMatchingOuterParen(where)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip outer parentheses and re-trim for processing
|
||||||
|
whereWithoutParens := stripOuterParentheses(where)
|
||||||
|
shouldPreserveParens := hasOuterParens && containsTopLevelOR(whereWithoutParens)
|
||||||
|
|
||||||
|
// Use the stripped version for processing
|
||||||
|
where = whereWithoutParens
|
||||||
|
|
||||||
// Get valid columns from the model if tableName is provided
|
// Get valid columns from the model if tableName is provided
|
||||||
var validColumns map[string]bool
|
var validColumns map[string]bool
|
||||||
@@ -166,6 +180,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add join aliases as allowed prefixes
|
||||||
|
for _, alias := range options[0].JoinAliases {
|
||||||
|
if alias != "" {
|
||||||
|
allowedPrefixes[alias] = true
|
||||||
|
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split by AND to handle multiple conditions
|
// Split by AND to handle multiple conditions
|
||||||
@@ -221,7 +243,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
|
|
||||||
result := strings.Join(validConditions, " AND ")
|
result := strings.Join(validConditions, " AND ")
|
||||||
|
|
||||||
if result != where {
|
// If the original clause had outer parentheses and contains OR operators,
|
||||||
|
// restore the outer parentheses to prevent OR logic from escaping
|
||||||
|
if shouldPreserveParens {
|
||||||
|
result = "(" + result + ")"
|
||||||
|
logger.Debug("Preserved outer parentheses for OR conditions: '%s'", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != where && !shouldPreserveParens {
|
||||||
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,6 +311,93 @@ func stripOneMatchingOuterParen(s string) (string, bool) {
|
|||||||
return strings.TrimSpace(s[1 : len(s)-1]), true
|
return strings.TrimSpace(s[1 : len(s)-1]), true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnsureOuterParentheses ensures that a SQL clause is wrapped in parentheses
|
||||||
|
// to prevent OR logic from escaping. It checks if the clause already has
|
||||||
|
// matching outer parentheses and only adds them if they don't exist.
|
||||||
|
//
|
||||||
|
// This is particularly important for OR conditions and complex filters where
|
||||||
|
// the absence of parentheses could cause the logic to escape and affect
|
||||||
|
// the entire query incorrectly.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - clause: The SQL clause to check and potentially wrap
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - The clause with guaranteed outer parentheses, or empty string if input is empty
|
||||||
|
func EnsureOuterParentheses(clause string) string {
|
||||||
|
if clause == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
clause = strings.TrimSpace(clause)
|
||||||
|
if clause == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the clause already has matching outer parentheses
|
||||||
|
_, hasOuterParens := stripOneMatchingOuterParen(clause)
|
||||||
|
|
||||||
|
// If it already has matching outer parentheses, return as-is
|
||||||
|
if hasOuterParens {
|
||||||
|
return clause
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, wrap it in parentheses
|
||||||
|
return "(" + clause + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsTopLevelOR checks if a SQL clause contains OR operators at the top level
|
||||||
|
// (i.e., not inside parentheses or subqueries). This is used to determine if
|
||||||
|
// outer parentheses should be preserved to prevent OR logic from escaping.
|
||||||
|
func containsTopLevelOR(clause string) bool {
|
||||||
|
if clause == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
depth := 0
|
||||||
|
inSingleQuote := false
|
||||||
|
inDoubleQuote := false
|
||||||
|
lowerClause := strings.ToLower(clause)
|
||||||
|
|
||||||
|
for i := 0; i < len(clause); i++ {
|
||||||
|
ch := clause[i]
|
||||||
|
|
||||||
|
// Track quote state
|
||||||
|
if ch == '\'' && !inDoubleQuote {
|
||||||
|
inSingleQuote = !inSingleQuote
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '"' && !inSingleQuote {
|
||||||
|
inDoubleQuote = !inDoubleQuote
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if inside quotes
|
||||||
|
if inSingleQuote || inDoubleQuote {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track parenthesis depth
|
||||||
|
switch ch {
|
||||||
|
case '(':
|
||||||
|
depth++
|
||||||
|
case ')':
|
||||||
|
depth--
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only check for OR at depth 0 (not inside parentheses)
|
||||||
|
if depth == 0 && i+4 <= len(clause) {
|
||||||
|
// Check for " OR " (case-insensitive)
|
||||||
|
substring := lowerClause[i : i+4]
|
||||||
|
if substring == " or " {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||||
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
||||||
func splitByAND(where string) []string {
|
func splitByAND(where string) []string {
|
||||||
|
|||||||
103
pkg/common/sql_helpers_tablename_test.go
Normal file
103
pkg/common/sql_helpers_tablename_test.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSanitizeWhereClause_WithTableName tests that table prefixes in WHERE clauses
|
||||||
|
// are correctly handled when the tableName parameter matches the prefix
|
||||||
|
func TestSanitizeWhereClause_WithTableName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
options *RequestOptions
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Correct table prefix should not be changed",
|
||||||
|
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
tableName: "mastertaskitem",
|
||||||
|
options: nil,
|
||||||
|
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wrong table prefix should be fixed",
|
||||||
|
where: "wrong_table.rid_parentmastertaskitem is null",
|
||||||
|
tableName: "mastertaskitem",
|
||||||
|
options: nil,
|
||||||
|
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Relation name should not replace correct table prefix",
|
||||||
|
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
tableName: "mastertaskitem",
|
||||||
|
options: &RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||||
|
TableName: "mastertaskitem",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unqualified column should remain unqualified",
|
||||||
|
where: "rid_parentmastertaskitem is null",
|
||||||
|
tableName: "mastertaskitem",
|
||||||
|
options: nil,
|
||||||
|
expected: "rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeWhereClause(tt.where, tt.tableName, tt.options)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q, want %q",
|
||||||
|
tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddTablePrefixToColumns_WithTableName tests that table prefixes
|
||||||
|
// are correctly added to unqualified columns
|
||||||
|
func TestAddTablePrefixToColumns_WithTableName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Add prefix to unqualified column",
|
||||||
|
where: "rid_parentmastertaskitem is null",
|
||||||
|
tableName: "mastertaskitem",
|
||||||
|
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't change already qualified column",
|
||||||
|
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
tableName: "mastertaskitem",
|
||||||
|
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Don't change qualified column with different table",
|
||||||
|
where: "other_table.rid_something is null",
|
||||||
|
tableName: "mastertaskitem",
|
||||||
|
expected: "other_table.rid_something is null",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q, want %q",
|
||||||
|
tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -659,6 +659,179 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureOuterParentheses(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no parentheses",
|
||||||
|
input: "status = 'active'",
|
||||||
|
expected: "(status = 'active')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "already has outer parentheses",
|
||||||
|
input: "(status = 'active')",
|
||||||
|
expected: "(status = 'active')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OR condition without parentheses",
|
||||||
|
input: "status = 'active' OR status = 'pending'",
|
||||||
|
expected: "(status = 'active' OR status = 'pending')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OR condition with parentheses",
|
||||||
|
input: "(status = 'active' OR status = 'pending')",
|
||||||
|
expected: "(status = 'active' OR status = 'pending')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex condition with nested parentheses",
|
||||||
|
input: "(status = 'active' OR status = 'pending') AND (age > 18)",
|
||||||
|
expected: "((status = 'active' OR status = 'pending') AND (age > 18))",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace only",
|
||||||
|
input: " ",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mismatched parentheses - adds outer ones",
|
||||||
|
input: "(status = 'active' OR status = 'pending'",
|
||||||
|
expected: "((status = 'active' OR status = 'pending')",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := EnsureOuterParentheses(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("EnsureOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContainsTopLevelOR(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no OR operator",
|
||||||
|
input: "status = 'active' AND age > 18",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top-level OR",
|
||||||
|
input: "status = 'active' OR status = 'pending'",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OR inside parentheses",
|
||||||
|
input: "age > 18 AND (status = 'active' OR status = 'pending')",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OR in subquery",
|
||||||
|
input: "id IN (SELECT id FROM users WHERE status = 'active' OR status = 'pending')",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OR inside quotes",
|
||||||
|
input: "comment = 'this OR that'",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed - top-level OR and nested OR",
|
||||||
|
input: "name = 'test' OR (status = 'active' OR status = 'pending')",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lowercase or",
|
||||||
|
input: "status = 'active' or status = 'pending'",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase OR",
|
||||||
|
input: "status = 'active' OR status = 'pending'",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := containsTopLevelOR(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("containsTopLevelOR(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWhereClause_PreservesParenthesesWithOR(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "OR condition with outer parentheses - preserved",
|
||||||
|
where: "(status = 'active' OR status = 'pending')",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "(users.status = 'active' OR users.status = 'pending')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AND condition with outer parentheses - stripped (no OR)",
|
||||||
|
where: "(status = 'active' AND age > 18)",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex OR with nested conditions",
|
||||||
|
where: "((status = 'active' OR status = 'pending') AND age > 18)",
|
||||||
|
tableName: "users",
|
||||||
|
// Outer parens are stripped, but inner parens with OR are preserved
|
||||||
|
expected: "(users.status = 'active' OR users.status = 'pending') AND users.age > 18",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OR without outer parentheses - no parentheses added by SanitizeWhereClause",
|
||||||
|
where: "status = 'active' OR status = 'pending'",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active' OR users.status = 'pending'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple OR with parentheses - preserved",
|
||||||
|
where: "(users.status = 'active' OR users.status = 'pending')",
|
||||||
|
tableName: "users",
|
||||||
|
// Already has correct prefixes, parentheses preserved
|
||||||
|
expected: "(users.status = 'active' OR users.status = 'pending')",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||||
|
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
|
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ type RequestOptions struct {
|
|||||||
CursorForward string `json:"cursor_forward"`
|
CursorForward string `json:"cursor_forward"`
|
||||||
CursorBackward string `json:"cursor_backward"`
|
CursorBackward string `json:"cursor_backward"`
|
||||||
FetchRowNumber *string `json:"fetch_row_number"`
|
FetchRowNumber *string `json:"fetch_row_number"`
|
||||||
|
|
||||||
|
// Join table aliases (used for validation of prefixed columns in filters/sorts)
|
||||||
|
// Not serialized to JSON as it's internal validation state
|
||||||
|
JoinAliases []string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Parameter struct {
|
type Parameter struct {
|
||||||
@@ -33,6 +37,7 @@ type Parameter struct {
|
|||||||
|
|
||||||
type PreloadOption struct {
|
type PreloadOption struct {
|
||||||
Relation string `json:"relation"`
|
Relation string `json:"relation"`
|
||||||
|
TableName string `json:"table_name"` // Actual database table name (e.g., "mastertaskitem")
|
||||||
Columns []string `json:"columns"`
|
Columns []string `json:"columns"`
|
||||||
OmitColumns []string `json:"omit_columns"`
|
OmitColumns []string `json:"omit_columns"`
|
||||||
Sort []SortOption `json:"sort"`
|
Sort []SortOption `json:"sort"`
|
||||||
@@ -48,6 +53,11 @@ type PreloadOption struct {
|
|||||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||||
|
RecursiveChildKey string `json:"recursive_child_key"` // For recursive tables: FK column used for recursion (e.g., "rid_parentmastertaskitem")
|
||||||
|
|
||||||
|
// Custom SQL JOINs from XFiles - used when preload needs additional joins
|
||||||
|
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses
|
||||||
|
JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation
|
||||||
}
|
}
|
||||||
|
|
||||||
type FilterOption struct {
|
type FilterOption struct {
|
||||||
@@ -111,3 +121,14 @@ type TableMetadata struct {
|
|||||||
Columns []Column `json:"columns"`
|
Columns []Column `json:"columns"`
|
||||||
Relations []string `json:"relations"`
|
Relations []string `json:"relations"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RelationshipInfo contains information about a model relationship
|
||||||
|
type RelationshipInfo struct {
|
||||||
|
FieldName string `json:"field_name"`
|
||||||
|
JSONName string `json:"json_name"`
|
||||||
|
RelationType string `json:"relation_type"` // "belongsTo", "hasMany", "hasOne", "many2many"
|
||||||
|
ForeignKey string `json:"foreign_key"`
|
||||||
|
References string `json:"references"`
|
||||||
|
JoinTable string `json:"join_table"`
|
||||||
|
RelatedModel interface{} `json:"related_model"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -237,17 +237,31 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
for _, sort := range options.Sort {
|
for _, sort := range options.Sort {
|
||||||
if v.IsValidColumn(sort.Column) {
|
if v.IsValidColumn(sort.Column) {
|
||||||
validSorts = append(validSorts, sort)
|
validSorts = append(validSorts, sort)
|
||||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
} else {
|
||||||
|
foundJoin := false
|
||||||
|
for _, j := range options.JoinAliases {
|
||||||
|
if strings.Contains(sort.Column, j) {
|
||||||
|
foundJoin = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundJoin {
|
||||||
|
validSorts = append(validSorts, sort)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
// Allow sort by expression/subquery, but validate for security
|
// Allow sort by expression/subquery, but validate for security
|
||||||
if IsSafeSortExpression(sort.Column) {
|
if IsSafeSortExpression(sort.Column) {
|
||||||
validSorts = append(validSorts, sort)
|
validSorts = append(validSorts, sort)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
filtered.Sort = validSorts
|
filtered.Sort = validSorts
|
||||||
|
|
||||||
// Filter Preload columns
|
// Filter Preload columns
|
||||||
@@ -258,15 +272,31 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||||
|
|
||||||
|
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||||
|
filteredPreload.SqlJoins = preload.SqlJoins
|
||||||
|
filteredPreload.JoinAliases = preload.JoinAliases
|
||||||
|
|
||||||
// Filter preload filters
|
// Filter preload filters
|
||||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||||
for _, filter := range preload.Filters {
|
for _, filter := range preload.Filters {
|
||||||
if v.IsValidColumn(filter.Column) {
|
if v.IsValidColumn(filter.Column) {
|
||||||
validPreloadFilters = append(validPreloadFilters, filter)
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
|
} else {
|
||||||
|
// Check if the filter column references a joined table alias
|
||||||
|
foundJoin := false
|
||||||
|
for _, alias := range preload.JoinAliases {
|
||||||
|
if strings.Contains(filter.Column, alias) {
|
||||||
|
foundJoin = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundJoin {
|
||||||
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
filteredPreload.Filters = validPreloadFilters
|
filteredPreload.Filters = validPreloadFilters
|
||||||
|
|
||||||
// Filter preload sort columns
|
// Filter preload sort columns
|
||||||
@@ -291,6 +321,9 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
}
|
}
|
||||||
filtered.Preload = validPreloads
|
filtered.Preload = validPreloads
|
||||||
|
|
||||||
|
// Clear JoinAliases - this is an internal validation field and should not be persisted
|
||||||
|
filtered.JoinAliases = nil
|
||||||
|
|
||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -362,6 +362,29 @@ func TestFilterRequestOptions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilterRequestOptions_ClearsJoinAliases(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Columns: []string{"id", "name"},
|
||||||
|
// Set JoinAliases - this should be cleared by FilterRequestOptions
|
||||||
|
JoinAliases: []string{"d", "u", "r"},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
// Verify that JoinAliases was cleared (internal field should not persist)
|
||||||
|
if filtered.JoinAliases != nil {
|
||||||
|
t.Errorf("Expected JoinAliases to be nil after filtering, got %v", filtered.JoinAliases)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that other fields are still properly filtered
|
||||||
|
if len(filtered.Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsSafeSortExpression(t *testing.T) {
|
func TestIsSafeSortExpression(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -4,20 +4,28 @@ import "time"
|
|||||||
|
|
||||||
// Config represents the complete application configuration
|
// Config represents the complete application configuration
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Servers ServersConfig `mapstructure:"servers"`
|
||||||
Tracing TracingConfig `mapstructure:"tracing"`
|
Tracing TracingConfig `mapstructure:"tracing"`
|
||||||
Cache CacheConfig `mapstructure:"cache"`
|
Cache CacheConfig `mapstructure:"cache"`
|
||||||
Logger LoggerConfig `mapstructure:"logger"`
|
Logger LoggerConfig `mapstructure:"logger"`
|
||||||
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
|
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
|
||||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
|
||||||
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
||||||
|
DBManager DBManagerConfig `mapstructure:"dbmanager"`
|
||||||
|
Paths PathsConfig `mapstructure:"paths"`
|
||||||
|
Extensions map[string]interface{} `mapstructure:"extensions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig holds server-related configuration
|
// ServersConfig contains configuration for the server manager
|
||||||
type ServerConfig struct {
|
type ServersConfig struct {
|
||||||
Addr string `mapstructure:"addr"`
|
// DefaultServer is the name of the default server to use
|
||||||
|
DefaultServer string `mapstructure:"default_server"`
|
||||||
|
|
||||||
|
// Instances is a map of server name to server configuration
|
||||||
|
Instances map[string]ServerInstanceConfig `mapstructure:"instances"`
|
||||||
|
|
||||||
|
// Global timeout defaults (can be overridden per instance)
|
||||||
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
|
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
|
||||||
DrainTimeout time.Duration `mapstructure:"drain_timeout"`
|
DrainTimeout time.Duration `mapstructure:"drain_timeout"`
|
||||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||||
@@ -25,6 +33,51 @@ type ServerConfig struct {
|
|||||||
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
|
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ServerInstanceConfig defines configuration for a single server instance
|
||||||
|
type ServerInstanceConfig struct {
|
||||||
|
// Name is the unique name of this server instance
|
||||||
|
Name string `mapstructure:"name"`
|
||||||
|
|
||||||
|
// Host is the host to bind to (e.g., "localhost", "0.0.0.0", "")
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
|
||||||
|
// Port is the port number to listen on
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
|
||||||
|
// Description is a human-readable description of this server
|
||||||
|
Description string `mapstructure:"description"`
|
||||||
|
|
||||||
|
// GZIP enables GZIP compression middleware
|
||||||
|
GZIP bool `mapstructure:"gzip"`
|
||||||
|
|
||||||
|
// TLS/HTTPS configuration options (mutually exclusive)
|
||||||
|
// Option 1: Provide certificate and key files directly
|
||||||
|
SSLCert string `mapstructure:"ssl_cert"`
|
||||||
|
SSLKey string `mapstructure:"ssl_key"`
|
||||||
|
|
||||||
|
// Option 2: Use self-signed certificate (for development/testing)
|
||||||
|
SelfSignedSSL bool `mapstructure:"self_signed_ssl"`
|
||||||
|
|
||||||
|
// Option 3: Use Let's Encrypt / AutoTLS
|
||||||
|
AutoTLS bool `mapstructure:"auto_tls"`
|
||||||
|
AutoTLSDomains []string `mapstructure:"auto_tls_domains"`
|
||||||
|
AutoTLSCacheDir string `mapstructure:"auto_tls_cache_dir"`
|
||||||
|
AutoTLSEmail string `mapstructure:"auto_tls_email"`
|
||||||
|
|
||||||
|
// Timeout configurations (overrides global defaults)
|
||||||
|
ShutdownTimeout *time.Duration `mapstructure:"shutdown_timeout"`
|
||||||
|
DrainTimeout *time.Duration `mapstructure:"drain_timeout"`
|
||||||
|
ReadTimeout *time.Duration `mapstructure:"read_timeout"`
|
||||||
|
WriteTimeout *time.Duration `mapstructure:"write_timeout"`
|
||||||
|
IdleTimeout *time.Duration `mapstructure:"idle_timeout"`
|
||||||
|
|
||||||
|
// Tags for organization and filtering
|
||||||
|
Tags map[string]string `mapstructure:"tags"`
|
||||||
|
|
||||||
|
// ExternalURLs are additional URLs that this server instance is accessible from (for CORS) for proxy setups
|
||||||
|
ExternalURLs []string `mapstructure:"external_urls"`
|
||||||
|
}
|
||||||
|
|
||||||
// TracingConfig holds OpenTelemetry tracing configuration
|
// TracingConfig holds OpenTelemetry tracing configuration
|
||||||
type TracingConfig struct {
|
type TracingConfig struct {
|
||||||
Enabled bool `mapstructure:"enabled"`
|
Enabled bool `mapstructure:"enabled"`
|
||||||
@@ -76,11 +129,6 @@ type CORSConfig struct {
|
|||||||
MaxAge int `mapstructure:"max_age"`
|
MaxAge int `mapstructure:"max_age"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseConfig holds database configuration (primarily for testing)
|
|
||||||
type DatabaseConfig struct {
|
|
||||||
URL string `mapstructure:"url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrorTrackingConfig holds error tracking configuration
|
// ErrorTrackingConfig holds error tracking configuration
|
||||||
type ErrorTrackingConfig struct {
|
type ErrorTrackingConfig struct {
|
||||||
Enabled bool `mapstructure:"enabled"`
|
Enabled bool `mapstructure:"enabled"`
|
||||||
@@ -141,3 +189,8 @@ type EventBrokerRetryPolicyConfig struct {
|
|||||||
MaxDelay time.Duration `mapstructure:"max_delay"`
|
MaxDelay time.Duration `mapstructure:"max_delay"`
|
||||||
BackoffFactor float64 `mapstructure:"backoff_factor"`
|
BackoffFactor float64 `mapstructure:"backoff_factor"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PathsConfig contains configuration for named file system paths
|
||||||
|
// This is a map of path name to file system path
|
||||||
|
// Example: "data_dir": "/var/lib/myapp/data"
|
||||||
|
type PathsConfig map[string]string
|
||||||
|
|||||||
264
pkg/config/dbmanager.go
Normal file
264
pkg/config/dbmanager.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DBManagerConfig contains configuration for the database connection manager
|
||||||
|
type DBManagerConfig struct {
|
||||||
|
// DefaultConnection is the name of the default connection to use
|
||||||
|
DefaultConnection string `mapstructure:"default_connection"`
|
||||||
|
|
||||||
|
// Connections is a map of connection name to connection configuration
|
||||||
|
Connections map[string]DBConnectionConfig `mapstructure:"connections"`
|
||||||
|
|
||||||
|
// Global connection pool defaults
|
||||||
|
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||||
|
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||||
|
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
|
||||||
|
ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"`
|
||||||
|
|
||||||
|
// Retry policy
|
||||||
|
RetryAttempts int `mapstructure:"retry_attempts"`
|
||||||
|
RetryDelay time.Duration `mapstructure:"retry_delay"`
|
||||||
|
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
|
||||||
|
|
||||||
|
// Health checks
|
||||||
|
HealthCheckInterval time.Duration `mapstructure:"health_check_interval"`
|
||||||
|
EnableAutoReconnect bool `mapstructure:"enable_auto_reconnect"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBConnectionConfig defines configuration for a single database connection
|
||||||
|
type DBConnectionConfig struct {
|
||||||
|
// Name is the unique name of this connection
|
||||||
|
Name string `mapstructure:"name"`
|
||||||
|
|
||||||
|
// Type is the database type (postgres, sqlite, mssql, mongodb)
|
||||||
|
Type string `mapstructure:"type"`
|
||||||
|
|
||||||
|
// DSN is the complete Data Source Name / connection string
|
||||||
|
// If provided, this takes precedence over individual connection parameters
|
||||||
|
DSN string `mapstructure:"dsn"`
|
||||||
|
|
||||||
|
// Connection parameters (used if DSN is not provided)
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
User string `mapstructure:"user"`
|
||||||
|
Password string `mapstructure:"password"`
|
||||||
|
Database string `mapstructure:"database"`
|
||||||
|
|
||||||
|
// PostgreSQL/MSSQL specific
|
||||||
|
SSLMode string `mapstructure:"sslmode"` // disable, require, verify-ca, verify-full
|
||||||
|
Schema string `mapstructure:"schema"` // Default schema
|
||||||
|
|
||||||
|
// SQLite specific
|
||||||
|
FilePath string `mapstructure:"filepath"`
|
||||||
|
|
||||||
|
// MongoDB specific
|
||||||
|
AuthSource string `mapstructure:"auth_source"`
|
||||||
|
ReplicaSet string `mapstructure:"replica_set"`
|
||||||
|
ReadPreference string `mapstructure:"read_preference"` // primary, secondary, etc.
|
||||||
|
|
||||||
|
// Connection pool settings (overrides global defaults)
|
||||||
|
MaxOpenConns *int `mapstructure:"max_open_conns"`
|
||||||
|
MaxIdleConns *int `mapstructure:"max_idle_conns"`
|
||||||
|
ConnMaxLifetime *time.Duration `mapstructure:"conn_max_lifetime"`
|
||||||
|
ConnMaxIdleTime *time.Duration `mapstructure:"conn_max_idle_time"`
|
||||||
|
|
||||||
|
// Timeouts
|
||||||
|
ConnectTimeout time.Duration `mapstructure:"connect_timeout"`
|
||||||
|
QueryTimeout time.Duration `mapstructure:"query_timeout"`
|
||||||
|
|
||||||
|
// Features
|
||||||
|
EnableTracing bool `mapstructure:"enable_tracing"`
|
||||||
|
EnableMetrics bool `mapstructure:"enable_metrics"`
|
||||||
|
EnableLogging bool `mapstructure:"enable_logging"`
|
||||||
|
|
||||||
|
// DefaultORM specifies which ORM to use for the Database() method
|
||||||
|
// Options: "bun", "gorm", "native"
|
||||||
|
DefaultORM string `mapstructure:"default_orm"`
|
||||||
|
|
||||||
|
// Tags for organization and filtering
|
||||||
|
Tags map[string]string `mapstructure:"tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToManagerConfig converts config.DBManagerConfig to dbmanager.ManagerConfig
|
||||||
|
// This is used to avoid circular dependencies
|
||||||
|
func (c *DBManagerConfig) ToManagerConfig() interface{} {
|
||||||
|
// This will be implemented in the dbmanager package
|
||||||
|
// to convert from config types to dbmanager types
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopulateFromDSN parses a DSN and populates the connection fields
|
||||||
|
func (cc *DBConnectionConfig) PopulateFromDSN() error {
|
||||||
|
if cc.DSN == "" {
|
||||||
|
return nil // Nothing to populate
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cc.Type {
|
||||||
|
case "postgres":
|
||||||
|
return cc.populatePostgresDSN()
|
||||||
|
case "mongodb":
|
||||||
|
return cc.populateMongoDSN()
|
||||||
|
case "mssql":
|
||||||
|
return cc.populateMSSQLDSN()
|
||||||
|
case "sqlite":
|
||||||
|
return cc.populateSQLiteDSN()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot parse DSN for unsupported database type: %s", cc.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// populatePostgresDSN parses PostgreSQL DSN format
|
||||||
|
// Example: host=localhost port=5432 user=postgres password=secret dbname=mydb sslmode=disable
|
||||||
|
func (cc *DBConnectionConfig) populatePostgresDSN() error {
|
||||||
|
parts := strings.Fields(cc.DSN)
|
||||||
|
for _, part := range parts {
|
||||||
|
kv := strings.SplitN(part, "=", 2)
|
||||||
|
if len(kv) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key, value := kv[0], kv[1]
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
case "host":
|
||||||
|
cc.Host = value
|
||||||
|
case "port":
|
||||||
|
port, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port in DSN: %w", err)
|
||||||
|
}
|
||||||
|
cc.Port = port
|
||||||
|
case "user":
|
||||||
|
cc.User = value
|
||||||
|
case "password":
|
||||||
|
cc.Password = value
|
||||||
|
case "dbname":
|
||||||
|
cc.Database = value
|
||||||
|
case "sslmode":
|
||||||
|
cc.SSLMode = value
|
||||||
|
case "search_path":
|
||||||
|
cc.Schema = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// populateMongoDSN parses MongoDB DSN format
|
||||||
|
// Example: mongodb://user:password@host:port/database?authSource=admin&replicaSet=rs0
|
||||||
|
func (cc *DBConnectionConfig) populateMongoDSN() error {
|
||||||
|
u, err := url.Parse(cc.DSN)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid MongoDB DSN: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract user and password
|
||||||
|
if u.User != nil {
|
||||||
|
cc.User = u.User.Username()
|
||||||
|
if password, ok := u.User.Password(); ok {
|
||||||
|
cc.Password = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract host and port
|
||||||
|
if u.Host != "" {
|
||||||
|
host := u.Host
|
||||||
|
if strings.Contains(host, ":") {
|
||||||
|
hostPort := strings.SplitN(host, ":", 2)
|
||||||
|
cc.Host = hostPort[0]
|
||||||
|
if port, err := strconv.Atoi(hostPort[1]); err == nil {
|
||||||
|
cc.Port = port
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cc.Host = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract database
|
||||||
|
if u.Path != "" {
|
||||||
|
cc.Database = strings.TrimPrefix(u.Path, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract query parameters
|
||||||
|
params := u.Query()
|
||||||
|
if authSource := params.Get("authSource"); authSource != "" {
|
||||||
|
cc.AuthSource = authSource
|
||||||
|
}
|
||||||
|
if replicaSet := params.Get("replicaSet"); replicaSet != "" {
|
||||||
|
cc.ReplicaSet = replicaSet
|
||||||
|
}
|
||||||
|
if readPref := params.Get("readPreference"); readPref != "" {
|
||||||
|
cc.ReadPreference = readPref
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// populateMSSQLDSN parses MSSQL DSN format
|
||||||
|
// Example: sqlserver://username:password@host:port?database=dbname&schema=dbo
|
||||||
|
func (cc *DBConnectionConfig) populateMSSQLDSN() error {
|
||||||
|
u, err := url.Parse(cc.DSN)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid MSSQL DSN: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract user and password
|
||||||
|
if u.User != nil {
|
||||||
|
cc.User = u.User.Username()
|
||||||
|
if password, ok := u.User.Password(); ok {
|
||||||
|
cc.Password = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract host and port
|
||||||
|
if u.Host != "" {
|
||||||
|
host := u.Host
|
||||||
|
if strings.Contains(host, ":") {
|
||||||
|
hostPort := strings.SplitN(host, ":", 2)
|
||||||
|
cc.Host = hostPort[0]
|
||||||
|
if port, err := strconv.Atoi(hostPort[1]); err == nil {
|
||||||
|
cc.Port = port
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cc.Host = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract query parameters
|
||||||
|
params := u.Query()
|
||||||
|
if database := params.Get("database"); database != "" {
|
||||||
|
cc.Database = database
|
||||||
|
}
|
||||||
|
if schema := params.Get("schema"); schema != "" {
|
||||||
|
cc.Schema = schema
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// populateSQLiteDSN parses SQLite DSN format
|
||||||
|
// Example: /path/to/database.db or :memory:
|
||||||
|
func (cc *DBConnectionConfig) populateSQLiteDSN() error {
|
||||||
|
cc.FilePath = cc.DSN
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the DBManager configuration
|
||||||
|
func (c *DBManagerConfig) Validate() error {
|
||||||
|
if len(c.Connections) == 0 {
|
||||||
|
return fmt.Errorf("at least one connection must be configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.DefaultConnection != "" {
|
||||||
|
if _, ok := c.Connections[c.DefaultConnection]; !ok {
|
||||||
|
return fmt.Errorf("default connection '%s' not found in connections", c.DefaultConnection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -12,6 +12,16 @@ type Manager struct {
|
|||||||
v *viper.Viper
|
v *viper.Viper
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var configInstance *Manager
|
||||||
|
|
||||||
|
// GetConfigManager returns a singleton configuration manager instance
|
||||||
|
func GetConfigManager() *Manager {
|
||||||
|
if configInstance == nil {
|
||||||
|
configInstance = NewManager()
|
||||||
|
}
|
||||||
|
return configInstance
|
||||||
|
}
|
||||||
|
|
||||||
// NewManager creates a new configuration manager with defaults
|
// NewManager creates a new configuration manager with defaults
|
||||||
func NewManager() *Manager {
|
func NewManager() *Manager {
|
||||||
v := viper.New()
|
v := viper.New()
|
||||||
@@ -32,7 +42,8 @@ func NewManager() *Manager {
|
|||||||
// Set default values
|
// Set default values
|
||||||
setDefaults(v)
|
setDefaults(v)
|
||||||
|
|
||||||
return &Manager{v: v}
|
configInstance = &Manager{v: v}
|
||||||
|
return configInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManagerWithOptions creates a new configuration manager with custom options
|
// NewManagerWithOptions creates a new configuration manager with custom options
|
||||||
@@ -97,6 +108,31 @@ func (m *Manager) GetConfig() (*Config, error) {
|
|||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetConfig sets the complete configuration
|
||||||
|
func (m *Manager) SetConfig(cfg *Config) error {
|
||||||
|
configMap := make(map[string]interface{})
|
||||||
|
|
||||||
|
// Marshal the config to a map structure that viper can use
|
||||||
|
if err := m.v.Unmarshal(&configMap); err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare config map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use viper's merge to apply the config
|
||||||
|
m.v.Set("servers", cfg.Servers)
|
||||||
|
m.v.Set("tracing", cfg.Tracing)
|
||||||
|
m.v.Set("cache", cfg.Cache)
|
||||||
|
m.v.Set("logger", cfg.Logger)
|
||||||
|
m.v.Set("error_tracking", cfg.ErrorTracking)
|
||||||
|
m.v.Set("middleware", cfg.Middleware)
|
||||||
|
m.v.Set("cors", cfg.CORS)
|
||||||
|
m.v.Set("event_broker", cfg.EventBroker)
|
||||||
|
m.v.Set("dbmanager", cfg.DBManager)
|
||||||
|
m.v.Set("paths", cfg.Paths)
|
||||||
|
m.v.Set("extensions", cfg.Extensions)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get returns a configuration value by key
|
// Get returns a configuration value by key
|
||||||
func (m *Manager) Get(key string) interface{} {
|
func (m *Manager) Get(key string) interface{} {
|
||||||
return m.v.Get(key)
|
return m.v.Get(key)
|
||||||
@@ -122,15 +158,32 @@ func (m *Manager) Set(key string, value interface{}) {
|
|||||||
m.v.Set(key, value)
|
m.v.Set(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveConfig writes the current configuration to the specified path
|
||||||
|
func (m *Manager) SaveConfig(path string) error {
|
||||||
|
if err := m.v.WriteConfigAs(path); err != nil {
|
||||||
|
return fmt.Errorf("failed to save config to %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// setDefaults sets default configuration values
|
// setDefaults sets default configuration values
|
||||||
func setDefaults(v *viper.Viper) {
|
func setDefaults(v *viper.Viper) {
|
||||||
// Server defaults
|
// Server defaults - new structure
|
||||||
v.SetDefault("server.addr", ":8080")
|
v.SetDefault("servers.default_server", "default")
|
||||||
v.SetDefault("server.shutdown_timeout", "30s")
|
|
||||||
v.SetDefault("server.drain_timeout", "25s")
|
// Global server timeout defaults
|
||||||
v.SetDefault("server.read_timeout", "10s")
|
v.SetDefault("servers.shutdown_timeout", "30s")
|
||||||
v.SetDefault("server.write_timeout", "10s")
|
v.SetDefault("servers.drain_timeout", "25s")
|
||||||
v.SetDefault("server.idle_timeout", "120s")
|
v.SetDefault("servers.read_timeout", "10s")
|
||||||
|
v.SetDefault("servers.write_timeout", "10s")
|
||||||
|
v.SetDefault("servers.idle_timeout", "120s")
|
||||||
|
|
||||||
|
// Default server instance
|
||||||
|
v.SetDefault("servers.instances.default.name", "default")
|
||||||
|
v.SetDefault("servers.instances.default.host", "")
|
||||||
|
v.SetDefault("servers.instances.default.port", 8080)
|
||||||
|
v.SetDefault("servers.instances.default.description", "Default HTTP server")
|
||||||
|
v.SetDefault("servers.instances.default.gzip", false)
|
||||||
|
|
||||||
// Tracing defaults
|
// Tracing defaults
|
||||||
v.SetDefault("tracing.enabled", false)
|
v.SetDefault("tracing.enabled", false)
|
||||||
@@ -166,6 +219,34 @@ func setDefaults(v *viper.Viper) {
|
|||||||
// Database defaults
|
// Database defaults
|
||||||
v.SetDefault("database.url", "")
|
v.SetDefault("database.url", "")
|
||||||
|
|
||||||
|
// Database Manager defaults
|
||||||
|
v.SetDefault("dbmanager.default_connection", "default")
|
||||||
|
v.SetDefault("dbmanager.max_open_conns", 25)
|
||||||
|
v.SetDefault("dbmanager.max_idle_conns", 5)
|
||||||
|
v.SetDefault("dbmanager.conn_max_lifetime", "30m")
|
||||||
|
v.SetDefault("dbmanager.conn_max_idle_time", "5m")
|
||||||
|
v.SetDefault("dbmanager.retry_attempts", 3)
|
||||||
|
v.SetDefault("dbmanager.retry_delay", "1s")
|
||||||
|
v.SetDefault("dbmanager.retry_max_delay", "10s")
|
||||||
|
v.SetDefault("dbmanager.health_check_interval", "30s")
|
||||||
|
v.SetDefault("dbmanager.enable_auto_reconnect", true)
|
||||||
|
|
||||||
|
// Default PostgreSQL connection
|
||||||
|
v.SetDefault("dbmanager.connections.default.name", "default")
|
||||||
|
v.SetDefault("dbmanager.connections.default.type", "postgres")
|
||||||
|
v.SetDefault("dbmanager.connections.default.host", "localhost")
|
||||||
|
v.SetDefault("dbmanager.connections.default.port", 5432)
|
||||||
|
v.SetDefault("dbmanager.connections.default.user", "postgres")
|
||||||
|
v.SetDefault("dbmanager.connections.default.password", "")
|
||||||
|
v.SetDefault("dbmanager.connections.default.database", "resolvespec")
|
||||||
|
v.SetDefault("dbmanager.connections.default.sslmode", "disable")
|
||||||
|
v.SetDefault("dbmanager.connections.default.connect_timeout", "10s")
|
||||||
|
v.SetDefault("dbmanager.connections.default.query_timeout", "30s")
|
||||||
|
v.SetDefault("dbmanager.connections.default.enable_tracing", false)
|
||||||
|
v.SetDefault("dbmanager.connections.default.enable_metrics", false)
|
||||||
|
v.SetDefault("dbmanager.connections.default.enable_logging", false)
|
||||||
|
v.SetDefault("dbmanager.connections.default.default_orm", "bun")
|
||||||
|
|
||||||
// Event Broker defaults
|
// Event Broker defaults
|
||||||
v.SetDefault("event_broker.enabled", false)
|
v.SetDefault("event_broker.enabled", false)
|
||||||
v.SetDefault("event_broker.provider", "memory")
|
v.SetDefault("event_broker.provider", "memory")
|
||||||
@@ -200,4 +281,13 @@ func setDefaults(v *viper.Viper) {
|
|||||||
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
|
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
|
||||||
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
|
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
|
||||||
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
|
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
|
||||||
|
|
||||||
|
// Paths defaults (common directory paths)
|
||||||
|
v.SetDefault("paths.data_dir", "./data")
|
||||||
|
v.SetDefault("paths.config_dir", "./config")
|
||||||
|
v.SetDefault("paths.logs_dir", "./logs")
|
||||||
|
v.SetDefault("paths.temp_dir", "./tmp")
|
||||||
|
|
||||||
|
// Extensions defaults (empty map)
|
||||||
|
v.SetDefault("extensions", map[string]interface{}{})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,8 +34,8 @@ func TestDefaultValues(t *testing.T) {
|
|||||||
got interface{}
|
got interface{}
|
||||||
expected interface{}
|
expected interface{}
|
||||||
}{
|
}{
|
||||||
{"server.addr", cfg.Server.Addr, ":8080"},
|
{"servers.default_server", cfg.Servers.DefaultServer, "default"},
|
||||||
{"server.shutdown_timeout", cfg.Server.ShutdownTimeout, 30 * time.Second},
|
{"servers.shutdown_timeout", cfg.Servers.ShutdownTimeout, 30 * time.Second},
|
||||||
{"tracing.enabled", cfg.Tracing.Enabled, false},
|
{"tracing.enabled", cfg.Tracing.Enabled, false},
|
||||||
{"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"},
|
{"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"},
|
||||||
{"cache.provider", cfg.Cache.Provider, "memory"},
|
{"cache.provider", cfg.Cache.Provider, "memory"},
|
||||||
@@ -46,6 +46,18 @@ func TestDefaultValues(t *testing.T) {
|
|||||||
{"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200},
|
{"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test default server instance
|
||||||
|
defaultServer, ok := cfg.Servers.Instances["default"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Default server instance not found")
|
||||||
|
}
|
||||||
|
if defaultServer.Port != 8080 {
|
||||||
|
t.Errorf("default server port: got %d, want 8080", defaultServer.Port)
|
||||||
|
}
|
||||||
|
if defaultServer.Name != "default" {
|
||||||
|
t.Errorf("default server name: got %s, want default", defaultServer.Name)
|
||||||
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if tt.got != tt.expected {
|
if tt.got != tt.expected {
|
||||||
@@ -57,12 +69,12 @@ func TestDefaultValues(t *testing.T) {
|
|||||||
|
|
||||||
func TestEnvironmentVariableOverrides(t *testing.T) {
|
func TestEnvironmentVariableOverrides(t *testing.T) {
|
||||||
// Set environment variables
|
// Set environment variables
|
||||||
os.Setenv("RESOLVESPEC_SERVER_ADDR", ":9090")
|
os.Setenv("RESOLVESPEC_SERVERS_INSTANCES_DEFAULT_PORT", "9090")
|
||||||
os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true")
|
os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true")
|
||||||
os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis")
|
os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis")
|
||||||
os.Setenv("RESOLVESPEC_LOGGER_DEV", "true")
|
os.Setenv("RESOLVESPEC_LOGGER_DEV", "true")
|
||||||
defer func() {
|
defer func() {
|
||||||
os.Unsetenv("RESOLVESPEC_SERVER_ADDR")
|
os.Unsetenv("RESOLVESPEC_SERVERS_INSTANCES_DEFAULT_PORT")
|
||||||
os.Unsetenv("RESOLVESPEC_TRACING_ENABLED")
|
os.Unsetenv("RESOLVESPEC_TRACING_ENABLED")
|
||||||
os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER")
|
os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER")
|
||||||
os.Unsetenv("RESOLVESPEC_LOGGER_DEV")
|
os.Unsetenv("RESOLVESPEC_LOGGER_DEV")
|
||||||
@@ -84,7 +96,6 @@ func TestEnvironmentVariableOverrides(t *testing.T) {
|
|||||||
got interface{}
|
got interface{}
|
||||||
expected interface{}
|
expected interface{}
|
||||||
}{
|
}{
|
||||||
{"server.addr", cfg.Server.Addr, ":9090"},
|
|
||||||
{"tracing.enabled", cfg.Tracing.Enabled, true},
|
{"tracing.enabled", cfg.Tracing.Enabled, true},
|
||||||
{"cache.provider", cfg.Cache.Provider, "redis"},
|
{"cache.provider", cfg.Cache.Provider, "redis"},
|
||||||
{"logger.dev", cfg.Logger.Dev, true},
|
{"logger.dev", cfg.Logger.Dev, true},
|
||||||
@@ -97,11 +108,17 @@ func TestEnvironmentVariableOverrides(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test server port override
|
||||||
|
defaultServer := cfg.Servers.Instances["default"]
|
||||||
|
if defaultServer.Port != 9090 {
|
||||||
|
t.Errorf("server port: got %d, want 9090", defaultServer.Port)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProgrammaticConfiguration(t *testing.T) {
|
func TestProgrammaticConfiguration(t *testing.T) {
|
||||||
mgr := NewManager()
|
mgr := NewManager()
|
||||||
mgr.Set("server.addr", ":7070")
|
mgr.Set("servers.instances.default.port", 7070)
|
||||||
mgr.Set("tracing.service_name", "test-service")
|
mgr.Set("tracing.service_name", "test-service")
|
||||||
|
|
||||||
cfg, err := mgr.GetConfig()
|
cfg, err := mgr.GetConfig()
|
||||||
@@ -109,8 +126,8 @@ func TestProgrammaticConfiguration(t *testing.T) {
|
|||||||
t.Fatalf("Failed to get config: %v", err)
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Server.Addr != ":7070" {
|
if cfg.Servers.Instances["default"].Port != 7070 {
|
||||||
t.Errorf("server.addr: got %s, want :7070", cfg.Server.Addr)
|
t.Errorf("server port: got %d, want 7070", cfg.Servers.Instances["default"].Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Tracing.ServiceName != "test-service" {
|
if cfg.Tracing.ServiceName != "test-service" {
|
||||||
@@ -148,8 +165,8 @@ func TestWithOptions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set environment variable with custom prefix
|
// Set environment variable with custom prefix
|
||||||
os.Setenv("MYAPP_SERVER_ADDR", ":5000")
|
os.Setenv("MYAPP_SERVERS_INSTANCES_DEFAULT_PORT", "5000")
|
||||||
defer os.Unsetenv("MYAPP_SERVER_ADDR")
|
defer os.Unsetenv("MYAPP_SERVERS_INSTANCES_DEFAULT_PORT")
|
||||||
|
|
||||||
if err := mgr.Load(); err != nil {
|
if err := mgr.Load(); err != nil {
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
@@ -160,7 +177,432 @@ func TestWithOptions(t *testing.T) {
|
|||||||
t.Fatalf("Failed to get config: %v", err)
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Server.Addr != ":5000" {
|
if cfg.Servers.Instances["default"].Port != 5000 {
|
||||||
t.Errorf("server.addr: got %s, want :5000", cfg.Server.Addr)
|
t.Errorf("server port: got %d, want 5000", cfg.Servers.Instances["default"].Port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServersConfig(t *testing.T) {
|
||||||
|
mgr := NewManager()
|
||||||
|
if err := mgr.Load(); err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := mgr.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test default server exists
|
||||||
|
if cfg.Servers.DefaultServer != "default" {
|
||||||
|
t.Errorf("Expected default_server to be 'default', got %s", cfg.Servers.DefaultServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test default instance
|
||||||
|
defaultServer, ok := cfg.Servers.Instances["default"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Default server instance not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultServer.Port != 8080 {
|
||||||
|
t.Errorf("Expected default port 8080, got %d", defaultServer.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultServer.Name != "default" {
|
||||||
|
t.Errorf("Expected default name 'default', got %s", defaultServer.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultServer.Description != "Default HTTP server" {
|
||||||
|
t.Errorf("Expected description 'Default HTTP server', got %s", defaultServer.Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleServerInstances(t *testing.T) {
|
||||||
|
mgr := NewManager()
|
||||||
|
|
||||||
|
// Add additional server instances (default instance exists from defaults)
|
||||||
|
mgr.Set("servers.default_server", "api")
|
||||||
|
mgr.Set("servers.instances.api.name", "api")
|
||||||
|
mgr.Set("servers.instances.api.host", "0.0.0.0")
|
||||||
|
mgr.Set("servers.instances.api.port", 8080)
|
||||||
|
mgr.Set("servers.instances.admin.name", "admin")
|
||||||
|
mgr.Set("servers.instances.admin.host", "localhost")
|
||||||
|
mgr.Set("servers.instances.admin.port", 8081)
|
||||||
|
|
||||||
|
cfg, err := mgr.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have default + api + admin = 3 instances
|
||||||
|
if len(cfg.Servers.Instances) < 2 {
|
||||||
|
t.Errorf("Expected at least 2 server instances, got %d", len(cfg.Servers.Instances))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify api instance
|
||||||
|
apiServer, ok := cfg.Servers.Instances["api"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("API server instance not found")
|
||||||
|
}
|
||||||
|
if apiServer.Port != 8080 {
|
||||||
|
t.Errorf("Expected API port 8080, got %d", apiServer.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify admin instance
|
||||||
|
adminServer, ok := cfg.Servers.Instances["admin"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Admin server instance not found")
|
||||||
|
}
|
||||||
|
if adminServer.Port != 8081 {
|
||||||
|
t.Errorf("Expected admin port 8081, got %d", adminServer.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate default server
|
||||||
|
if err := cfg.Servers.Validate(); err != nil {
|
||||||
|
t.Errorf("Server config validation failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get default
|
||||||
|
defaultSrv, err := cfg.Servers.GetDefault()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get default server: %v", err)
|
||||||
|
}
|
||||||
|
if defaultSrv.Name != "api" {
|
||||||
|
t.Errorf("Expected default server 'api', got '%s'", defaultSrv.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtensionsField(t *testing.T) {
|
||||||
|
mgr := NewManager()
|
||||||
|
|
||||||
|
// Set custom extensions
|
||||||
|
mgr.Set("extensions.custom_feature.enabled", true)
|
||||||
|
mgr.Set("extensions.custom_feature.api_key", "test-key")
|
||||||
|
mgr.Set("extensions.another_extension.timeout", "5s")
|
||||||
|
|
||||||
|
cfg, err := mgr.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Extensions == nil {
|
||||||
|
t.Fatal("Extensions should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify extensions are accessible
|
||||||
|
customFeature := mgr.Get("extensions.custom_feature")
|
||||||
|
if customFeature == nil {
|
||||||
|
t.Error("custom_feature extension not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify via config manager methods
|
||||||
|
if !mgr.GetBool("extensions.custom_feature.enabled") {
|
||||||
|
t.Error("Expected custom_feature.enabled to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgr.GetString("extensions.custom_feature.api_key") != "test-key" {
|
||||||
|
t.Error("Expected api_key to be 'test-key'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerInstanceValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
instance ServerInstanceConfig
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid basic config",
|
||||||
|
instance: ServerInstanceConfig{
|
||||||
|
Name: "test",
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid port - too high",
|
||||||
|
instance: ServerInstanceConfig{
|
||||||
|
Name: "test",
|
||||||
|
Port: 99999,
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid port - zero",
|
||||||
|
instance: ServerInstanceConfig{
|
||||||
|
Name: "test",
|
||||||
|
Port: 0,
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty name",
|
||||||
|
instance: ServerInstanceConfig{
|
||||||
|
Name: "",
|
||||||
|
Port: 8080,
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "conflicting TLS options",
|
||||||
|
instance: ServerInstanceConfig{
|
||||||
|
Name: "test",
|
||||||
|
Port: 8080,
|
||||||
|
SelfSignedSSL: true,
|
||||||
|
AutoTLS: true,
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete SSL cert config",
|
||||||
|
instance: ServerInstanceConfig{
|
||||||
|
Name: "test",
|
||||||
|
Port: 8080,
|
||||||
|
SSLCert: "/path/to/cert.pem",
|
||||||
|
// Missing SSLKey
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AutoTLS without domains",
|
||||||
|
instance: ServerInstanceConfig{
|
||||||
|
Name: "test",
|
||||||
|
Port: 8080,
|
||||||
|
AutoTLS: true,
|
||||||
|
// Missing AutoTLSDomains
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.instance.Validate()
|
||||||
|
if tt.expectErr && err == nil {
|
||||||
|
t.Error("Expected validation error, got nil")
|
||||||
|
}
|
||||||
|
if !tt.expectErr && err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyGlobalDefaults(t *testing.T) {
|
||||||
|
globals := ServersConfig{
|
||||||
|
ShutdownTimeout: 30 * time.Second,
|
||||||
|
DrainTimeout: 25 * time.Second,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
instance := ServerInstanceConfig{
|
||||||
|
Name: "test",
|
||||||
|
Port: 8080,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply global defaults
|
||||||
|
instance.ApplyGlobalDefaults(globals)
|
||||||
|
|
||||||
|
// Check that defaults were applied
|
||||||
|
if instance.ShutdownTimeout == nil || *instance.ShutdownTimeout != 30*time.Second {
|
||||||
|
t.Error("ShutdownTimeout not applied correctly")
|
||||||
|
}
|
||||||
|
if instance.DrainTimeout == nil || *instance.DrainTimeout != 25*time.Second {
|
||||||
|
t.Error("DrainTimeout not applied correctly")
|
||||||
|
}
|
||||||
|
if instance.ReadTimeout == nil || *instance.ReadTimeout != 10*time.Second {
|
||||||
|
t.Error("ReadTimeout not applied correctly")
|
||||||
|
}
|
||||||
|
if instance.WriteTimeout == nil || *instance.WriteTimeout != 10*time.Second {
|
||||||
|
t.Error("WriteTimeout not applied correctly")
|
||||||
|
}
|
||||||
|
if instance.IdleTimeout == nil || *instance.IdleTimeout != 120*time.Second {
|
||||||
|
t.Error("IdleTimeout not applied correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that explicit overrides are not replaced
|
||||||
|
customTimeout := 60 * time.Second
|
||||||
|
instance2 := ServerInstanceConfig{
|
||||||
|
Name: "test2",
|
||||||
|
Port: 8081,
|
||||||
|
ShutdownTimeout: &customTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
instance2.ApplyGlobalDefaults(globals)
|
||||||
|
|
||||||
|
if instance2.ShutdownTimeout == nil || *instance2.ShutdownTimeout != 60*time.Second {
|
||||||
|
t.Error("Custom ShutdownTimeout was overridden")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathsConfig(t *testing.T) {
|
||||||
|
mgr := NewManager()
|
||||||
|
if err := mgr.Load(); err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := mgr.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test default paths exist
|
||||||
|
if !cfg.Paths.Has("data_dir") {
|
||||||
|
t.Error("Expected data_dir path to exist")
|
||||||
|
}
|
||||||
|
if !cfg.Paths.Has("config_dir") {
|
||||||
|
t.Error("Expected config_dir path to exist")
|
||||||
|
}
|
||||||
|
if !cfg.Paths.Has("logs_dir") {
|
||||||
|
t.Error("Expected logs_dir path to exist")
|
||||||
|
}
|
||||||
|
if !cfg.Paths.Has("temp_dir") {
|
||||||
|
t.Error("Expected temp_dir path to exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Get method
|
||||||
|
dataDir, err := cfg.Paths.Get("data_dir")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get data_dir: %v", err)
|
||||||
|
}
|
||||||
|
if dataDir != "./data" {
|
||||||
|
t.Errorf("Expected data_dir to be './data', got '%s'", dataDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetOrDefault
|
||||||
|
existing := cfg.Paths.GetOrDefault("data_dir", "/default/path")
|
||||||
|
if existing != "./data" {
|
||||||
|
t.Errorf("Expected existing path, got '%s'", existing)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonExisting := cfg.Paths.GetOrDefault("nonexistent", "/default/path")
|
||||||
|
if nonExisting != "/default/path" {
|
||||||
|
t.Errorf("Expected default path, got '%s'", nonExisting)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathsConfigMethods(t *testing.T) {
|
||||||
|
pc := PathsConfig{
|
||||||
|
"base": "/var/myapp",
|
||||||
|
"data": "/var/myapp/data",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Get
|
||||||
|
path, err := pc.Get("base")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get path: %v", err)
|
||||||
|
}
|
||||||
|
if path != "/var/myapp" {
|
||||||
|
t.Errorf("Expected '/var/myapp', got '%s'", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Get non-existent
|
||||||
|
_, err = pc.Get("nonexistent")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for non-existent path")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Set
|
||||||
|
pc.Set("new_path", "/new/location")
|
||||||
|
newPath, err := pc.Get("new_path")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get newly set path: %v", err)
|
||||||
|
}
|
||||||
|
if newPath != "/new/location" {
|
||||||
|
t.Errorf("Expected '/new/location', got '%s'", newPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Has
|
||||||
|
if !pc.Has("base") {
|
||||||
|
t.Error("Expected 'base' path to exist")
|
||||||
|
}
|
||||||
|
if pc.Has("nonexistent") {
|
||||||
|
t.Error("Expected 'nonexistent' path to not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test List
|
||||||
|
names := pc.List()
|
||||||
|
if len(names) != 3 {
|
||||||
|
t.Errorf("Expected 3 paths, got %d", len(names))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Join
|
||||||
|
joined, err := pc.Join("base", "subdir", "file.txt")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to join paths: %v", err)
|
||||||
|
}
|
||||||
|
expected := "/var/myapp/subdir/file.txt"
|
||||||
|
if joined != expected {
|
||||||
|
t.Errorf("Expected '%s', got '%s'", expected, joined)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathsConfigEnvironmentVariables(t *testing.T) {
|
||||||
|
// Set environment variables for paths
|
||||||
|
os.Setenv("RESOLVESPEC_PATHS_DATA_DIR", "/custom/data")
|
||||||
|
os.Setenv("RESOLVESPEC_PATHS_LOGS_DIR", "/custom/logs")
|
||||||
|
defer func() {
|
||||||
|
os.Unsetenv("RESOLVESPEC_PATHS_DATA_DIR")
|
||||||
|
os.Unsetenv("RESOLVESPEC_PATHS_LOGS_DIR")
|
||||||
|
}()
|
||||||
|
|
||||||
|
mgr := NewManager()
|
||||||
|
if err := mgr.Load(); err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := mgr.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test environment variable override of existing default path
|
||||||
|
dataDir, err := cfg.Paths.Get("data_dir")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get data_dir: %v", err)
|
||||||
|
}
|
||||||
|
if dataDir != "/custom/data" {
|
||||||
|
t.Errorf("Expected '/custom/data', got '%s'", dataDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test another environment variable override
|
||||||
|
logsDir, err := cfg.Paths.Get("logs_dir")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get logs_dir: %v", err)
|
||||||
|
}
|
||||||
|
if logsDir != "/custom/logs" {
|
||||||
|
t.Errorf("Expected '/custom/logs', got '%s'", logsDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathsConfigProgrammatic(t *testing.T) {
|
||||||
|
mgr := NewManager()
|
||||||
|
|
||||||
|
// Set custom paths programmatically
|
||||||
|
mgr.Set("paths.custom_dir", "/my/custom/dir")
|
||||||
|
mgr.Set("paths.cache_dir", "/var/cache/myapp")
|
||||||
|
|
||||||
|
cfg, err := mgr.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify custom paths
|
||||||
|
customDir, err := cfg.Paths.Get("custom_dir")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get custom_dir: %v", err)
|
||||||
|
}
|
||||||
|
if customDir != "/my/custom/dir" {
|
||||||
|
t.Errorf("Expected '/my/custom/dir', got '%s'", customDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheDir, err := cfg.Paths.Get("cache_dir")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to get cache_dir: %v", err)
|
||||||
|
}
|
||||||
|
if cacheDir != "/var/cache/myapp" {
|
||||||
|
t.Errorf("Expected '/var/cache/myapp', got '%s'", cacheDir)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
117
pkg/config/paths.go
Normal file
117
pkg/config/paths.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get retrieves a path by name
|
||||||
|
func (pc PathsConfig) Get(name string) (string, error) {
|
||||||
|
if pc == nil {
|
||||||
|
return "", fmt.Errorf("paths not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
path, ok := pc[name]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("path '%s' not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrDefault retrieves a path by name, returning defaultPath if not found
|
||||||
|
func (pc PathsConfig) GetOrDefault(name, defaultPath string) string {
|
||||||
|
if pc == nil {
|
||||||
|
return defaultPath
|
||||||
|
}
|
||||||
|
|
||||||
|
path, ok := pc[name]
|
||||||
|
if !ok {
|
||||||
|
return defaultPath
|
||||||
|
}
|
||||||
|
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets a path by name
|
||||||
|
func (pc PathsConfig) Set(name, path string) {
|
||||||
|
pc[name] = path
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has checks if a path exists by name
|
||||||
|
func (pc PathsConfig) Has(name string) bool {
|
||||||
|
if pc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := pc[name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureDir ensures a directory exists at the specified path name
|
||||||
|
// Creates the directory if it doesn't exist with the given permissions
|
||||||
|
func (pc PathsConfig) EnsureDir(name string, perm os.FileMode) error {
|
||||||
|
path, err := pc.Get(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if directory exists
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err == nil {
|
||||||
|
// Path exists, check if it's a directory
|
||||||
|
if !info.IsDir() {
|
||||||
|
return fmt.Errorf("path '%s' exists but is not a directory: %s", name, path)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Directory doesn't exist, create it
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
if err := os.MkdirAll(path, perm); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory for '%s' at %s: %w", name, path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("failed to stat path '%s' at %s: %w", name, path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AbsPath returns the absolute path for a named path
|
||||||
|
func (pc PathsConfig) AbsPath(name string) (string, error) {
|
||||||
|
path, err := pc.Get(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
absPath, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get absolute path for '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return absPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join joins path segments with a named base path
|
||||||
|
func (pc PathsConfig) Join(name string, elem ...string) (string, error) {
|
||||||
|
base, err := pc.Get(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := append([]string{base}, elem...)
|
||||||
|
return filepath.Join(parts...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all configured path names
|
||||||
|
func (pc PathsConfig) List() []string {
|
||||||
|
if pc == nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make([]string, 0, len(pc))
|
||||||
|
for name := range pc {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
149
pkg/config/server.go
Normal file
149
pkg/config/server.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApplyGlobalDefaults applies global server defaults to this instance
|
||||||
|
// Called for instances that don't specify their own timeout values
|
||||||
|
func (sic *ServerInstanceConfig) ApplyGlobalDefaults(globals ServersConfig) {
|
||||||
|
if sic.ShutdownTimeout == nil && globals.ShutdownTimeout > 0 {
|
||||||
|
t := globals.ShutdownTimeout
|
||||||
|
sic.ShutdownTimeout = &t
|
||||||
|
}
|
||||||
|
if sic.DrainTimeout == nil && globals.DrainTimeout > 0 {
|
||||||
|
t := globals.DrainTimeout
|
||||||
|
sic.DrainTimeout = &t
|
||||||
|
}
|
||||||
|
if sic.ReadTimeout == nil && globals.ReadTimeout > 0 {
|
||||||
|
t := globals.ReadTimeout
|
||||||
|
sic.ReadTimeout = &t
|
||||||
|
}
|
||||||
|
if sic.WriteTimeout == nil && globals.WriteTimeout > 0 {
|
||||||
|
t := globals.WriteTimeout
|
||||||
|
sic.WriteTimeout = &t
|
||||||
|
}
|
||||||
|
if sic.IdleTimeout == nil && globals.IdleTimeout > 0 {
|
||||||
|
t := globals.IdleTimeout
|
||||||
|
sic.IdleTimeout = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the ServerInstanceConfig
|
||||||
|
func (sic *ServerInstanceConfig) Validate() error {
|
||||||
|
if sic.Name == "" {
|
||||||
|
return fmt.Errorf("server instance name cannot be empty")
|
||||||
|
}
|
||||||
|
if sic.Port <= 0 || sic.Port > 65535 {
|
||||||
|
return fmt.Errorf("invalid port: %d (must be 1-65535)", sic.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate TLS options are mutually exclusive
|
||||||
|
tlsCount := 0
|
||||||
|
if sic.SSLCert != "" || sic.SSLKey != "" {
|
||||||
|
tlsCount++
|
||||||
|
}
|
||||||
|
if sic.SelfSignedSSL {
|
||||||
|
tlsCount++
|
||||||
|
}
|
||||||
|
if sic.AutoTLS {
|
||||||
|
tlsCount++
|
||||||
|
}
|
||||||
|
if tlsCount > 1 {
|
||||||
|
return fmt.Errorf("server '%s': only one TLS option can be enabled", sic.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If using certificate files, both must be provided
|
||||||
|
if (sic.SSLCert != "" && sic.SSLKey == "") || (sic.SSLCert == "" && sic.SSLKey != "") {
|
||||||
|
return fmt.Errorf("server '%s': both ssl_cert and ssl_key must be provided", sic.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If using AutoTLS, domains must be specified
|
||||||
|
if sic.AutoTLS && len(sic.AutoTLSDomains) == 0 {
|
||||||
|
return fmt.Errorf("server '%s': auto_tls_domains must be specified when auto_tls is enabled", sic.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the ServersConfig
|
||||||
|
func (sc *ServersConfig) Validate() error {
|
||||||
|
if len(sc.Instances) == 0 {
|
||||||
|
return fmt.Errorf("at least one server instance must be configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sc.DefaultServer != "" {
|
||||||
|
if _, ok := sc.Instances[sc.DefaultServer]; !ok {
|
||||||
|
return fmt.Errorf("default server '%s' not found in instances", sc.DefaultServer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate each instance
|
||||||
|
for name := range sc.Instances {
|
||||||
|
instance := sc.Instances[name]
|
||||||
|
if instance.Name != name {
|
||||||
|
return fmt.Errorf("server instance name mismatch: key='%s', name='%s'", name, instance.Name)
|
||||||
|
}
|
||||||
|
if err := instance.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefault returns the default server instance configuration
|
||||||
|
func (sc *ServersConfig) GetDefault() (*ServerInstanceConfig, error) {
|
||||||
|
if sc.DefaultServer == "" {
|
||||||
|
return nil, fmt.Errorf("no default server configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
instance, ok := sc.Instances[sc.DefaultServer]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("default server '%s' not found", sc.DefaultServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &instance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIPs - GetIP for pc
|
||||||
|
func GetIPs() (hostname string, ipList string, ipNetList []net.IP) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
fmt.Println("Recovered in GetIPs", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
hostname, _ = os.Hostname()
|
||||||
|
ipaddrlist := make([]net.IP, 0)
|
||||||
|
iplist := ""
|
||||||
|
addrs, err := net.LookupIP(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return hostname, iplist, ipaddrlist
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, a := range addrs {
|
||||||
|
// cfg.LogInfo("\nFound IP Host Address: %s", a)
|
||||||
|
if strings.Contains(a.String(), "127.0.0.1") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
iplist = fmt.Sprintf("%s,%s", iplist, a)
|
||||||
|
ipaddrlist = append(ipaddrlist, a)
|
||||||
|
}
|
||||||
|
if iplist == "" {
|
||||||
|
iff, _ := net.InterfaceAddrs()
|
||||||
|
for _, a := range iff {
|
||||||
|
// cfg.LogInfo("\nFound IP Address: %s", a)
|
||||||
|
if strings.Contains(a.String(), "127.0.0.1") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
iplist = fmt.Sprintf("%s,%s", iplist, a)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
iplist = strings.TrimLeft(iplist, ",")
|
||||||
|
return hostname, iplist, ipaddrlist
|
||||||
|
}
|
||||||
531
pkg/dbmanager/README.md
Normal file
531
pkg/dbmanager/README.md
Normal file
@@ -0,0 +1,531 @@
|
|||||||
|
# Database Connection Manager (dbmanager)
|
||||||
|
|
||||||
|
A comprehensive database connection manager for Go that provides centralized management of multiple named database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Multiple Named Connections**: Manage multiple database connections with names like `primary`, `analytics`, `cache-db`
|
||||||
|
- **Multi-Database Support**: PostgreSQL, SQLite, Microsoft SQL Server, and MongoDB
|
||||||
|
- **Multi-ORM Access**: Each SQL connection provides access through:
|
||||||
|
- **Bun ORM** - Modern, lightweight ORM
|
||||||
|
- **GORM** - Popular Go ORM
|
||||||
|
- **Native** - Standard library `*sql.DB`
|
||||||
|
- All three share the same underlying connection pool
|
||||||
|
- **Configuration-Driven**: YAML configuration with Viper integration
|
||||||
|
- **Production-Ready Features**:
|
||||||
|
- Automatic health checks and reconnection
|
||||||
|
- Prometheus metrics
|
||||||
|
- Connection pooling with configurable limits
|
||||||
|
- Retry logic with exponential backoff
|
||||||
|
- Graceful shutdown
|
||||||
|
- OpenTelemetry tracing support
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get github.com/bitechdev/ResolveSpec/pkg/dbmanager
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Configuration
|
||||||
|
|
||||||
|
Create a configuration file (e.g., `config.yaml`):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
dbmanager:
|
||||||
|
default_connection: "primary"
|
||||||
|
|
||||||
|
# Global connection pool defaults
|
||||||
|
max_open_conns: 25
|
||||||
|
max_idle_conns: 5
|
||||||
|
conn_max_lifetime: 30m
|
||||||
|
conn_max_idle_time: 5m
|
||||||
|
|
||||||
|
# Retry configuration
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 1s
|
||||||
|
retry_max_delay: 10s
|
||||||
|
|
||||||
|
# Health checks
|
||||||
|
health_check_interval: 30s
|
||||||
|
enable_auto_reconnect: true
|
||||||
|
|
||||||
|
connections:
|
||||||
|
# Primary PostgreSQL connection
|
||||||
|
primary:
|
||||||
|
type: postgres
|
||||||
|
host: localhost
|
||||||
|
port: 5432
|
||||||
|
user: myuser
|
||||||
|
password: mypassword
|
||||||
|
database: myapp
|
||||||
|
sslmode: disable
|
||||||
|
default_orm: bun
|
||||||
|
enable_metrics: true
|
||||||
|
enable_tracing: true
|
||||||
|
enable_logging: true
|
||||||
|
|
||||||
|
# Read replica for analytics
|
||||||
|
analytics:
|
||||||
|
type: postgres
|
||||||
|
dsn: "postgres://readonly:pass@analytics:5432/analytics"
|
||||||
|
default_orm: bun
|
||||||
|
enable_metrics: true
|
||||||
|
|
||||||
|
# SQLite cache
|
||||||
|
cache-db:
|
||||||
|
type: sqlite
|
||||||
|
filepath: /var/lib/app/cache.db
|
||||||
|
max_open_conns: 1
|
||||||
|
|
||||||
|
# MongoDB for documents
|
||||||
|
documents:
|
||||||
|
type: mongodb
|
||||||
|
host: localhost
|
||||||
|
port: 27017
|
||||||
|
database: documents
|
||||||
|
user: mongouser
|
||||||
|
password: mongopass
|
||||||
|
auth_source: admin
|
||||||
|
enable_metrics: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Initialize Manager
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration
|
||||||
|
cfgMgr := config.NewManager()
|
||||||
|
if err := cfgMgr.Load(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
cfg, _ := cfgMgr.GetConfig()
|
||||||
|
|
||||||
|
// Create database manager
|
||||||
|
mgr, err := dbmanager.NewManager(cfg.DBManager)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
// Connect all databases
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := mgr.Connect(ctx); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Your application code here...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Use Database Connections
|
||||||
|
|
||||||
|
#### Get Default Database
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get the default database (as configured common.Database interface)
|
||||||
|
db, err := mgr.GetDefaultDatabase()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use it with any query
|
||||||
|
var users []User
|
||||||
|
err = db.NewSelect().
|
||||||
|
Model(&users).
|
||||||
|
Where("active = ?", true).
|
||||||
|
Scan(ctx, &users)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Get Named Connection with Specific ORM
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get primary connection
|
||||||
|
primary, err := mgr.Get("primary")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use with Bun
|
||||||
|
bunDB, err := primary.Bun()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
err = bunDB.NewSelect().Model(&users).Scan(ctx)
|
||||||
|
|
||||||
|
// Use with GORM (same underlying connection!)
|
||||||
|
gormDB, err := primary.GORM()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
gormDB.Where("active = ?", true).Find(&users)
|
||||||
|
|
||||||
|
// Use native *sql.DB
|
||||||
|
nativeDB, err := primary.Native()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Use MongoDB
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get MongoDB connection
|
||||||
|
docs, err := mgr.Get("documents")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mongoClient, err := docs.MongoDB()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
collection := mongoClient.Database("documents").Collection("articles")
|
||||||
|
// Use MongoDB driver...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Change Default Database
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Switch to analytics database as default
|
||||||
|
err := mgr.SetDefaultDatabase("analytics")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now GetDefaultDatabase() returns the analytics connection
|
||||||
|
db, _ := mgr.GetDefaultDatabase()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Reference
|
||||||
|
|
||||||
|
### Manager Configuration
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `default_connection` | string | "" | Name of the default connection |
|
||||||
|
| `connections` | map | {} | Map of connection name to ConnectionConfig |
|
||||||
|
| `max_open_conns` | int | 25 | Global default for max open connections |
|
||||||
|
| `max_idle_conns` | int | 5 | Global default for max idle connections |
|
||||||
|
| `conn_max_lifetime` | duration | 30m | Global default for connection max lifetime |
|
||||||
|
| `conn_max_idle_time` | duration | 5m | Global default for connection max idle time |
|
||||||
|
| `retry_attempts` | int | 3 | Number of connection retry attempts |
|
||||||
|
| `retry_delay` | duration | 1s | Initial retry delay |
|
||||||
|
| `retry_max_delay` | duration | 10s | Maximum retry delay |
|
||||||
|
| `health_check_interval` | duration | 30s | Interval between health checks |
|
||||||
|
| `enable_auto_reconnect` | bool | true | Auto-reconnect on health check failure |
|
||||||
|
|
||||||
|
### Connection Configuration
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `name` | string | Unique connection name |
|
||||||
|
| `type` | string | Database type: `postgres`, `sqlite`, `mssql`, `mongodb` |
|
||||||
|
| `dsn` | string | Complete connection string (overrides individual params) |
|
||||||
|
| `host` | string | Database host |
|
||||||
|
| `port` | int | Database port |
|
||||||
|
| `user` | string | Username |
|
||||||
|
| `password` | string | Password |
|
||||||
|
| `database` | string | Database name |
|
||||||
|
| `sslmode` | string | SSL mode (postgres/mssql): `disable`, `require`, etc. |
|
||||||
|
| `schema` | string | Default schema (postgres/mssql) |
|
||||||
|
| `filepath` | string | File path (sqlite only) |
|
||||||
|
| `auth_source` | string | Auth source (mongodb) |
|
||||||
|
| `replica_set` | string | Replica set name (mongodb) |
|
||||||
|
| `read_preference` | string | Read preference (mongodb): `primary`, `secondary`, etc. |
|
||||||
|
| `max_open_conns` | int | Override global max open connections |
|
||||||
|
| `max_idle_conns` | int | Override global max idle connections |
|
||||||
|
| `conn_max_lifetime` | duration | Override global connection max lifetime |
|
||||||
|
| `conn_max_idle_time` | duration | Override global connection max idle time |
|
||||||
|
| `connect_timeout` | duration | Connection timeout (default: 10s) |
|
||||||
|
| `query_timeout` | duration | Query timeout (default: 30s) |
|
||||||
|
| `enable_tracing` | bool | Enable OpenTelemetry tracing |
|
||||||
|
| `enable_metrics` | bool | Enable Prometheus metrics |
|
||||||
|
| `enable_logging` | bool | Enable connection logging |
|
||||||
|
| `default_orm` | string | Default ORM for Database(): `bun`, `gorm`, `native` |
|
||||||
|
| `tags` | map[string]string | Custom tags for filtering/organization |
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Health Checks
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Manual health check
|
||||||
|
if err := mgr.HealthCheck(ctx); err != nil {
|
||||||
|
log.Printf("Health check failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-connection health check
|
||||||
|
primary, _ := mgr.Get("primary")
|
||||||
|
if err := primary.HealthCheck(ctx); err != nil {
|
||||||
|
log.Printf("Primary connection unhealthy: %v", err)
|
||||||
|
|
||||||
|
// Manual reconnect
|
||||||
|
if err := primary.Reconnect(ctx); err != nil {
|
||||||
|
log.Printf("Reconnection failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection Statistics
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get overall statistics
|
||||||
|
stats := mgr.Stats()
|
||||||
|
fmt.Printf("Total connections: %d\n", stats.TotalConnections)
|
||||||
|
fmt.Printf("Healthy: %d, Unhealthy: %d\n", stats.HealthyCount, stats.UnhealthyCount)
|
||||||
|
|
||||||
|
// Per-connection stats
|
||||||
|
for name, connStats := range stats.ConnectionStats {
|
||||||
|
fmt.Printf("%s: %d open, %d in use, %d idle\n",
|
||||||
|
name,
|
||||||
|
connStats.OpenConnections,
|
||||||
|
connStats.InUse,
|
||||||
|
connStats.Idle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Individual connection stats
|
||||||
|
primary, _ := mgr.Get("primary")
|
||||||
|
stats := primary.Stats()
|
||||||
|
fmt.Printf("Wait count: %d, Wait duration: %v\n",
|
||||||
|
stats.WaitCount,
|
||||||
|
stats.WaitDuration)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prometheus Metrics
|
||||||
|
|
||||||
|
The package automatically exports Prometheus metrics:
|
||||||
|
|
||||||
|
- `dbmanager_connections_total` - Total configured connections by type
|
||||||
|
- `dbmanager_connection_status` - Connection health status (1=healthy, 0=unhealthy)
|
||||||
|
- `dbmanager_connection_pool_size` - Connection pool statistics by state
|
||||||
|
- `dbmanager_connection_wait_count` - Times connections waited for availability
|
||||||
|
- `dbmanager_connection_wait_duration_seconds` - Total wait duration
|
||||||
|
- `dbmanager_health_check_duration_seconds` - Health check execution time
|
||||||
|
- `dbmanager_reconnect_attempts_total` - Reconnection attempts and results
|
||||||
|
- `dbmanager_connection_lifetime_closed_total` - Connections closed due to max lifetime
|
||||||
|
- `dbmanager_connection_idle_closed_total` - Connections closed due to max idle time
|
||||||
|
|
||||||
|
Metrics are automatically updated during health checks. To manually publish metrics:
|
||||||
|
|
||||||
|
```go
|
||||||
|
if mgr, ok := mgr.(*connectionManager); ok {
|
||||||
|
mgr.PublishMetrics()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Single Connection Pool, Multiple ORMs
|
||||||
|
|
||||||
|
A key design principle is that Bun, GORM, and Native all wrap the **same underlying `*sql.DB`** connection pool:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ SQL Connection │
|
||||||
|
├─────────────────────────────────────┤
|
||||||
|
│ ┌─────────┐ ┌──────┐ ┌────────┐ │
|
||||||
|
│ │ Bun │ │ GORM │ │ Native │ │
|
||||||
|
│ └────┬────┘ └───┬──┘ └───┬────┘ │
|
||||||
|
│ │ │ │ │
|
||||||
|
│ └───────────┴─────────┘ │
|
||||||
|
│ *sql.DB │
|
||||||
|
│ (single pool) │
|
||||||
|
└─────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
- No connection duplication
|
||||||
|
- Consistent pool limits across all ORMs
|
||||||
|
- Unified connection statistics
|
||||||
|
- Lower resource usage
|
||||||
|
|
||||||
|
### Provider Pattern
|
||||||
|
|
||||||
|
Each database type has a dedicated provider:
|
||||||
|
|
||||||
|
- **PostgresProvider** - Uses `pgx` driver
|
||||||
|
- **SQLiteProvider** - Uses `glebarez/sqlite` (pure Go)
|
||||||
|
- **MSSQLProvider** - Uses `go-mssqldb`
|
||||||
|
- **MongoProvider** - Uses official `mongo-driver`
|
||||||
|
|
||||||
|
Providers handle:
|
||||||
|
- Connection establishment with retry logic
|
||||||
|
- Health checking
|
||||||
|
- Connection statistics
|
||||||
|
- Connection cleanup
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use Named Connections**: Be explicit about which database you're accessing
|
||||||
|
```go
|
||||||
|
primary, _ := mgr.Get("primary") // Good
|
||||||
|
db, _ := mgr.GetDefaultDatabase() // Risky if default changes
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Configure Connection Pools**: Tune based on your workload
|
||||||
|
```yaml
|
||||||
|
connections:
|
||||||
|
primary:
|
||||||
|
max_open_conns: 100 # High traffic API
|
||||||
|
max_idle_conns: 25
|
||||||
|
analytics:
|
||||||
|
max_open_conns: 10 # Background analytics
|
||||||
|
max_idle_conns: 2
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Enable Health Checks**: Catch connection issues early
|
||||||
|
```yaml
|
||||||
|
health_check_interval: 30s
|
||||||
|
enable_auto_reconnect: true
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Use Appropriate ORM**: Choose based on your needs
|
||||||
|
- **Bun**: Modern, fast, type-safe - recommended for new code
|
||||||
|
- **GORM**: Mature, feature-rich - good for existing GORM code
|
||||||
|
- **Native**: Maximum control - use for performance-critical queries
|
||||||
|
|
||||||
|
5. **Monitor Metrics**: Watch connection pool utilization
|
||||||
|
- If `wait_count` is high, increase `max_open_conns`
|
||||||
|
- If `idle` is always high, decrease `max_idle_conns`
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Connection Failures
|
||||||
|
|
||||||
|
If connections fail to establish:
|
||||||
|
|
||||||
|
1. Check configuration:
|
||||||
|
```bash
|
||||||
|
# Test connection manually
|
||||||
|
psql -h localhost -U myuser -d myapp
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Enable logging:
|
||||||
|
```yaml
|
||||||
|
connections:
|
||||||
|
primary:
|
||||||
|
enable_logging: true
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Check retry attempts:
|
||||||
|
```yaml
|
||||||
|
retry_attempts: 5 # Increase retries
|
||||||
|
retry_max_delay: 30s
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pool Exhaustion
|
||||||
|
|
||||||
|
If you see "too many connections" errors:
|
||||||
|
|
||||||
|
1. Increase pool size:
|
||||||
|
```yaml
|
||||||
|
max_open_conns: 50 # Increase from default 25
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Reduce connection lifetime:
|
||||||
|
```yaml
|
||||||
|
conn_max_lifetime: 15m # Recycle faster
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Monitor wait stats:
|
||||||
|
```go
|
||||||
|
stats := primary.Stats()
|
||||||
|
if stats.WaitCount > 1000 {
|
||||||
|
log.Warn("High connection wait count")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### MongoDB vs SQL Confusion
|
||||||
|
|
||||||
|
MongoDB connections don't support SQL ORMs:
|
||||||
|
|
||||||
|
```go
|
||||||
|
docs, _ := mgr.Get("documents")
|
||||||
|
|
||||||
|
// ✓ Correct
|
||||||
|
mongoClient, _ := docs.MongoDB()
|
||||||
|
|
||||||
|
// ✗ Error: ErrNotSQLDatabase
|
||||||
|
bunDB, err := docs.Bun() // Won't work!
|
||||||
|
```
|
||||||
|
|
||||||
|
SQL connections don't support MongoDB:
|
||||||
|
|
||||||
|
```go
|
||||||
|
primary, _ := mgr.Get("primary")
|
||||||
|
|
||||||
|
// ✓ Correct
|
||||||
|
bunDB, _ := primary.Bun()
|
||||||
|
|
||||||
|
// ✗ Error: ErrNotMongoDB
|
||||||
|
mongoClient, err := primary.MongoDB() // Won't work!
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration Guide
|
||||||
|
|
||||||
|
### From Raw `database/sql`
|
||||||
|
|
||||||
|
Before:
|
||||||
|
```go
|
||||||
|
db, err := sql.Open("postgres", dsn)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users")
|
||||||
|
```
|
||||||
|
|
||||||
|
After:
|
||||||
|
```go
|
||||||
|
mgr, _ := dbmanager.NewManager(cfg.DBManager)
|
||||||
|
mgr.Connect(ctx)
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
primary, _ := mgr.Get("primary")
|
||||||
|
nativeDB, _ := primary.Native()
|
||||||
|
|
||||||
|
rows, err := nativeDB.Query("SELECT * FROM users")
|
||||||
|
```
|
||||||
|
|
||||||
|
### From Direct Bun/GORM
|
||||||
|
|
||||||
|
Before:
|
||||||
|
```go
|
||||||
|
sqldb, _ := sql.Open("pgx", dsn)
|
||||||
|
bunDB := bun.NewDB(sqldb, pgdialect.New())
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
bunDB.NewSelect().Model(&users).Scan(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
After:
|
||||||
|
```go
|
||||||
|
mgr, _ := dbmanager.NewManager(cfg.DBManager)
|
||||||
|
mgr.Connect(ctx)
|
||||||
|
|
||||||
|
primary, _ := mgr.Get("primary")
|
||||||
|
bunDB, _ := primary.Bun()
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
bunDB.NewSelect().Model(&users).Scan(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Same as the parent project.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please submit issues and pull requests to the main repository.
|
||||||
489
pkg/dbmanager/config.go
Normal file
489
pkg/dbmanager/config.go
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DatabaseType represents the type of database
|
||||||
|
type DatabaseType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DatabaseTypePostgreSQL represents PostgreSQL database
|
||||||
|
DatabaseTypePostgreSQL DatabaseType = "postgres"
|
||||||
|
|
||||||
|
// DatabaseTypeSQLite represents SQLite database
|
||||||
|
DatabaseTypeSQLite DatabaseType = "sqlite"
|
||||||
|
|
||||||
|
// DatabaseTypeMSSQL represents Microsoft SQL Server database
|
||||||
|
DatabaseTypeMSSQL DatabaseType = "mssql"
|
||||||
|
|
||||||
|
// DatabaseTypeMongoDB represents MongoDB database
|
||||||
|
DatabaseTypeMongoDB DatabaseType = "mongodb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ORMType represents the ORM to use for database operations
|
||||||
|
type ORMType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ORMTypeBun represents Bun ORM
|
||||||
|
ORMTypeBun ORMType = "bun"
|
||||||
|
|
||||||
|
// ORMTypeGORM represents GORM
|
||||||
|
ORMTypeGORM ORMType = "gorm"
|
||||||
|
|
||||||
|
// ORMTypeNative represents native database/sql
|
||||||
|
ORMTypeNative ORMType = "native"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ManagerConfig contains configuration for the database connection manager.
// It holds the set of named connections plus manager-wide defaults for
// pooling, retries, and health checking that individual connections may
// override (see ConnectionConfig.ApplyDefaults).
type ManagerConfig struct {
	// DefaultConnection is the name of the default connection to use.
	// When set, it must name a key of Connections (enforced by Validate).
	DefaultConnection string `mapstructure:"default_connection"`

	// Connections is a map of connection name to connection configuration.
	Connections map[string]ConnectionConfig `mapstructure:"connections"`

	// Global connection pool defaults, inherited by connections that do not
	// set their own pointer-valued overrides.
	MaxOpenConns    int           `mapstructure:"max_open_conns"`
	MaxIdleConns    int           `mapstructure:"max_idle_conns"`
	ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
	ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"`

	// Retry policy for connection attempts.
	RetryAttempts int           `mapstructure:"retry_attempts"`
	RetryDelay    time.Duration `mapstructure:"retry_delay"`
	RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`

	// Health checks.
	HealthCheckInterval time.Duration `mapstructure:"health_check_interval"`
	EnableAutoReconnect bool          `mapstructure:"enable_auto_reconnect"`
}
|
||||||
|
|
||||||
|
// ConnectionConfig defines configuration for a single database connection.
// Either a complete DSN or the individual connection parameters must be
// supplied (enforced by Validate); BuildDSN assembles the final string.
type ConnectionConfig struct {
	// Name is the unique name of this connection.
	Name string `mapstructure:"name"`

	// Type is the database type (postgres, sqlite, mssql, mongodb).
	Type DatabaseType `mapstructure:"type"`

	// DSN is the complete Data Source Name / connection string.
	// If provided, this takes precedence over individual connection parameters.
	DSN string `mapstructure:"dsn"`

	// Connection parameters (used if DSN is not provided).
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	User     string `mapstructure:"user"`
	Password string `mapstructure:"password"`
	Database string `mapstructure:"database"`

	// PostgreSQL/MSSQL specific.
	SSLMode string `mapstructure:"sslmode"` // disable, require, verify-ca, verify-full
	Schema  string `mapstructure:"schema"`  // Default schema

	// SQLite specific.
	FilePath string `mapstructure:"filepath"`

	// MongoDB specific.
	AuthSource     string `mapstructure:"auth_source"`
	ReplicaSet     string `mapstructure:"replica_set"`
	ReadPreference string `mapstructure:"read_preference"` // primary, secondary, etc.

	// Connection pool settings; nil means "inherit the manager-wide default"
	// (resolved by ApplyDefaults), overriding the ManagerConfig values when set.
	MaxOpenConns    *int           `mapstructure:"max_open_conns"`
	MaxIdleConns    *int           `mapstructure:"max_idle_conns"`
	ConnMaxLifetime *time.Duration `mapstructure:"conn_max_lifetime"`
	ConnMaxIdleTime *time.Duration `mapstructure:"conn_max_idle_time"`

	// Timeouts. ApplyDefaults clamps QueryTimeout to a minimum of 2 minutes
	// and defaults ConnectTimeout to 10 seconds.
	ConnectTimeout time.Duration `mapstructure:"connect_timeout"`
	QueryTimeout   time.Duration `mapstructure:"query_timeout"`

	// Feature toggles.
	EnableTracing bool `mapstructure:"enable_tracing"`
	EnableMetrics bool `mapstructure:"enable_metrics"`
	EnableLogging bool `mapstructure:"enable_logging"`

	// DefaultORM specifies which ORM to use for the Database() method.
	// Options: "bun", "gorm", "native".
	DefaultORM string `mapstructure:"default_orm"`

	// Tags for organization and filtering.
	Tags map[string]string `mapstructure:"tags"`
}
|
||||||
|
|
||||||
|
// DefaultManagerConfig returns a ManagerConfig with sensible defaults
|
||||||
|
func DefaultManagerConfig() ManagerConfig {
|
||||||
|
return ManagerConfig{
|
||||||
|
DefaultConnection: "",
|
||||||
|
Connections: make(map[string]ConnectionConfig),
|
||||||
|
MaxOpenConns: 25,
|
||||||
|
MaxIdleConns: 5,
|
||||||
|
ConnMaxLifetime: 30 * time.Minute,
|
||||||
|
ConnMaxIdleTime: 5 * time.Minute,
|
||||||
|
RetryAttempts: 3,
|
||||||
|
RetryDelay: 1 * time.Second,
|
||||||
|
RetryMaxDelay: 10 * time.Second,
|
||||||
|
HealthCheckInterval: 15 * time.Second,
|
||||||
|
EnableAutoReconnect: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDefaults applies default values to the manager configuration
|
||||||
|
func (c *ManagerConfig) ApplyDefaults() {
|
||||||
|
defaults := DefaultManagerConfig()
|
||||||
|
|
||||||
|
if c.MaxOpenConns == 0 {
|
||||||
|
c.MaxOpenConns = defaults.MaxOpenConns
|
||||||
|
}
|
||||||
|
if c.MaxIdleConns == 0 {
|
||||||
|
c.MaxIdleConns = defaults.MaxIdleConns
|
||||||
|
}
|
||||||
|
if c.ConnMaxLifetime == 0 {
|
||||||
|
c.ConnMaxLifetime = defaults.ConnMaxLifetime
|
||||||
|
}
|
||||||
|
if c.ConnMaxIdleTime == 0 {
|
||||||
|
c.ConnMaxIdleTime = defaults.ConnMaxIdleTime
|
||||||
|
}
|
||||||
|
if c.RetryAttempts == 0 {
|
||||||
|
c.RetryAttempts = defaults.RetryAttempts
|
||||||
|
}
|
||||||
|
if c.RetryDelay == 0 {
|
||||||
|
c.RetryDelay = defaults.RetryDelay
|
||||||
|
}
|
||||||
|
if c.RetryMaxDelay == 0 {
|
||||||
|
c.RetryMaxDelay = defaults.RetryMaxDelay
|
||||||
|
}
|
||||||
|
if c.HealthCheckInterval == 0 {
|
||||||
|
c.HealthCheckInterval = defaults.HealthCheckInterval
|
||||||
|
}
|
||||||
|
// EnableAutoReconnect defaults to true - apply if not explicitly set
|
||||||
|
// Since this is a boolean, we apply the default unconditionally when it's false
|
||||||
|
if !c.EnableAutoReconnect {
|
||||||
|
c.EnableAutoReconnect = defaults.EnableAutoReconnect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the manager configuration
|
||||||
|
func (c *ManagerConfig) Validate() error {
|
||||||
|
if len(c.Connections) == 0 {
|
||||||
|
return NewConfigurationError("connections", fmt.Errorf("at least one connection must be configured"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.DefaultConnection != "" {
|
||||||
|
if _, ok := c.Connections[c.DefaultConnection]; !ok {
|
||||||
|
return NewConfigurationError("default_connection", fmt.Errorf("default connection '%s' not found in connections", c.DefaultConnection))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate each connection
|
||||||
|
for name := range c.Connections {
|
||||||
|
conn := c.Connections[name]
|
||||||
|
if err := conn.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("connection '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDefaults applies default values and global settings to the connection configuration
|
||||||
|
func (cc *ConnectionConfig) ApplyDefaults(global *ManagerConfig) {
|
||||||
|
// Set name if not already set
|
||||||
|
if cc.Name == "" {
|
||||||
|
cc.Name = "unnamed"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply global pool settings if not overridden
|
||||||
|
if cc.MaxOpenConns == nil && global != nil {
|
||||||
|
maxOpen := global.MaxOpenConns
|
||||||
|
cc.MaxOpenConns = &maxOpen
|
||||||
|
}
|
||||||
|
if cc.MaxIdleConns == nil && global != nil {
|
||||||
|
maxIdle := global.MaxIdleConns
|
||||||
|
cc.MaxIdleConns = &maxIdle
|
||||||
|
}
|
||||||
|
if cc.ConnMaxLifetime == nil && global != nil {
|
||||||
|
lifetime := global.ConnMaxLifetime
|
||||||
|
cc.ConnMaxLifetime = &lifetime
|
||||||
|
}
|
||||||
|
if cc.ConnMaxIdleTime == nil && global != nil {
|
||||||
|
idleTime := global.ConnMaxIdleTime
|
||||||
|
cc.ConnMaxIdleTime = &idleTime
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default timeouts
|
||||||
|
if cc.ConnectTimeout == 0 {
|
||||||
|
cc.ConnectTimeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
if cc.QueryTimeout == 0 {
|
||||||
|
cc.QueryTimeout = 2 * time.Minute // Default to 2 minutes
|
||||||
|
} else if cc.QueryTimeout < 2*time.Minute {
|
||||||
|
// Enforce minimum of 2 minutes
|
||||||
|
cc.QueryTimeout = 2 * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default ORM
|
||||||
|
if cc.DefaultORM == "" {
|
||||||
|
cc.DefaultORM = string(ORMTypeBun)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default PostgreSQL port
|
||||||
|
if cc.Type == DatabaseTypePostgreSQL && cc.Port == 0 && cc.DSN == "" {
|
||||||
|
cc.Port = 5432
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default MSSQL port
|
||||||
|
if cc.Type == DatabaseTypeMSSQL && cc.Port == 0 && cc.DSN == "" {
|
||||||
|
cc.Port = 1433
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default MongoDB port
|
||||||
|
if cc.Type == DatabaseTypeMongoDB && cc.Port == 0 && cc.DSN == "" {
|
||||||
|
cc.Port = 27017
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default MongoDB auth source
|
||||||
|
if cc.Type == DatabaseTypeMongoDB && cc.AuthSource == "" {
|
||||||
|
cc.AuthSource = "admin"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the connection configuration
|
||||||
|
func (cc *ConnectionConfig) Validate() error {
|
||||||
|
// Validate database type
|
||||||
|
switch cc.Type {
|
||||||
|
case DatabaseTypePostgreSQL, DatabaseTypeSQLite, DatabaseTypeMSSQL, DatabaseTypeMongoDB:
|
||||||
|
// Valid types
|
||||||
|
default:
|
||||||
|
return NewConfigurationError("type", fmt.Errorf("unsupported database type: %s", cc.Type))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that either DSN or connection parameters are provided
|
||||||
|
if cc.DSN == "" {
|
||||||
|
switch cc.Type {
|
||||||
|
case DatabaseTypePostgreSQL, DatabaseTypeMSSQL, DatabaseTypeMongoDB:
|
||||||
|
if cc.Host == "" {
|
||||||
|
return NewConfigurationError("host", fmt.Errorf("host is required when DSN is not provided"))
|
||||||
|
}
|
||||||
|
if cc.Database == "" {
|
||||||
|
return NewConfigurationError("database", fmt.Errorf("database is required when DSN is not provided"))
|
||||||
|
}
|
||||||
|
case DatabaseTypeSQLite:
|
||||||
|
if cc.FilePath == "" {
|
||||||
|
return NewConfigurationError("filepath", fmt.Errorf("filepath is required for SQLite when DSN is not provided"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ORM type
|
||||||
|
if cc.DefaultORM != "" {
|
||||||
|
switch ORMType(cc.DefaultORM) {
|
||||||
|
case ORMTypeBun, ORMTypeGORM, ORMTypeNative:
|
||||||
|
// Valid ORM types
|
||||||
|
default:
|
||||||
|
return NewConfigurationError("default_orm", fmt.Errorf("unsupported ORM type: %s", cc.DefaultORM))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildDSN builds a connection string from individual parameters
|
||||||
|
func (cc *ConnectionConfig) BuildDSN() (string, error) {
|
||||||
|
// If DSN is already provided, use it
|
||||||
|
if cc.DSN != "" {
|
||||||
|
return cc.DSN, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cc.Type {
|
||||||
|
case DatabaseTypePostgreSQL:
|
||||||
|
return cc.buildPostgresDSN(), nil
|
||||||
|
case DatabaseTypeSQLite:
|
||||||
|
return cc.buildSQLiteDSN(), nil
|
||||||
|
case DatabaseTypeMSSQL:
|
||||||
|
return cc.buildMSSQLDSN(), nil
|
||||||
|
case DatabaseTypeMongoDB:
|
||||||
|
return cc.buildMongoDSN(), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("cannot build DSN for database type: %s", cc.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ConnectionConfig) buildPostgresDSN() string {
|
||||||
|
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
|
||||||
|
cc.Host, cc.Port, cc.User, cc.Password, cc.Database)
|
||||||
|
|
||||||
|
if cc.SSLMode != "" {
|
||||||
|
dsn += fmt.Sprintf(" sslmode=%s", cc.SSLMode)
|
||||||
|
} else {
|
||||||
|
dsn += " sslmode=disable"
|
||||||
|
}
|
||||||
|
|
||||||
|
if cc.Schema != "" {
|
||||||
|
dsn += fmt.Sprintf(" search_path=%s", cc.Schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add statement_timeout for query execution timeout (in milliseconds)
|
||||||
|
if cc.QueryTimeout > 0 {
|
||||||
|
timeoutMs := int(cc.QueryTimeout.Milliseconds())
|
||||||
|
dsn += fmt.Sprintf(" statement_timeout=%d", timeoutMs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ConnectionConfig) buildSQLiteDSN() string {
|
||||||
|
filepath := cc.FilePath
|
||||||
|
if filepath == "" {
|
||||||
|
filepath = ":memory:"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add query parameters for timeouts
|
||||||
|
// Note: SQLite driver supports _timeout parameter (in milliseconds)
|
||||||
|
if cc.QueryTimeout > 0 {
|
||||||
|
timeoutMs := int(cc.QueryTimeout.Milliseconds())
|
||||||
|
filepath += fmt.Sprintf("?_timeout=%d", timeoutMs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ConnectionConfig) buildMSSQLDSN() string {
|
||||||
|
// Format: sqlserver://username:password@host:port?database=dbname
|
||||||
|
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s",
|
||||||
|
cc.User, cc.Password, cc.Host, cc.Port, cc.Database)
|
||||||
|
|
||||||
|
if cc.Schema != "" {
|
||||||
|
dsn += fmt.Sprintf("&schema=%s", cc.Schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add connection timeout (in seconds)
|
||||||
|
if cc.ConnectTimeout > 0 {
|
||||||
|
timeoutSec := int(cc.ConnectTimeout.Seconds())
|
||||||
|
dsn += fmt.Sprintf("&connection timeout=%d", timeoutSec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add dial timeout for TCP connection (in seconds)
|
||||||
|
if cc.ConnectTimeout > 0 {
|
||||||
|
dialTimeoutSec := int(cc.ConnectTimeout.Seconds())
|
||||||
|
dsn += fmt.Sprintf("&dial timeout=%d", dialTimeoutSec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add read timeout (in seconds) - enforces timeout for reading data
|
||||||
|
if cc.QueryTimeout > 0 {
|
||||||
|
readTimeoutSec := int(cc.QueryTimeout.Seconds())
|
||||||
|
dsn += fmt.Sprintf("&read timeout=%d", readTimeoutSec)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ConnectionConfig) buildMongoDSN() string {
|
||||||
|
// Format: mongodb://username:password@host:port/database?authSource=admin
|
||||||
|
var dsn string
|
||||||
|
|
||||||
|
if cc.User != "" && cc.Password != "" {
|
||||||
|
dsn = fmt.Sprintf("mongodb://%s:%s@%s:%d/%s",
|
||||||
|
cc.User, cc.Password, cc.Host, cc.Port, cc.Database)
|
||||||
|
} else {
|
||||||
|
dsn = fmt.Sprintf("mongodb://%s:%d/%s", cc.Host, cc.Port, cc.Database)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := ""
|
||||||
|
if cc.AuthSource != "" {
|
||||||
|
params += fmt.Sprintf("authSource=%s", cc.AuthSource)
|
||||||
|
}
|
||||||
|
if cc.ReplicaSet != "" {
|
||||||
|
if params != "" {
|
||||||
|
params += "&"
|
||||||
|
}
|
||||||
|
params += fmt.Sprintf("replicaSet=%s", cc.ReplicaSet)
|
||||||
|
}
|
||||||
|
if cc.ReadPreference != "" {
|
||||||
|
if params != "" {
|
||||||
|
params += "&"
|
||||||
|
}
|
||||||
|
params += fmt.Sprintf("readPreference=%s", cc.ReadPreference)
|
||||||
|
}
|
||||||
|
|
||||||
|
if params != "" {
|
||||||
|
dsn += "?" + params
|
||||||
|
}
|
||||||
|
|
||||||
|
return dsn
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromConfig converts config.DBManagerConfig to internal ManagerConfig
|
||||||
|
func FromConfig(cfg config.DBManagerConfig) ManagerConfig {
|
||||||
|
mgr := ManagerConfig{
|
||||||
|
DefaultConnection: cfg.DefaultConnection,
|
||||||
|
Connections: make(map[string]ConnectionConfig),
|
||||||
|
MaxOpenConns: cfg.MaxOpenConns,
|
||||||
|
MaxIdleConns: cfg.MaxIdleConns,
|
||||||
|
ConnMaxLifetime: cfg.ConnMaxLifetime,
|
||||||
|
ConnMaxIdleTime: cfg.ConnMaxIdleTime,
|
||||||
|
RetryAttempts: cfg.RetryAttempts,
|
||||||
|
RetryDelay: cfg.RetryDelay,
|
||||||
|
RetryMaxDelay: cfg.RetryMaxDelay,
|
||||||
|
HealthCheckInterval: cfg.HealthCheckInterval,
|
||||||
|
EnableAutoReconnect: cfg.EnableAutoReconnect,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert connections
|
||||||
|
for name := range cfg.Connections {
|
||||||
|
connCfg := cfg.Connections[name]
|
||||||
|
mgr.Connections[name] = ConnectionConfig{
|
||||||
|
Name: connCfg.Name,
|
||||||
|
Type: DatabaseType(connCfg.Type),
|
||||||
|
DSN: connCfg.DSN,
|
||||||
|
Host: connCfg.Host,
|
||||||
|
Port: connCfg.Port,
|
||||||
|
User: connCfg.User,
|
||||||
|
Password: connCfg.Password,
|
||||||
|
Database: connCfg.Database,
|
||||||
|
SSLMode: connCfg.SSLMode,
|
||||||
|
Schema: connCfg.Schema,
|
||||||
|
FilePath: connCfg.FilePath,
|
||||||
|
AuthSource: connCfg.AuthSource,
|
||||||
|
ReplicaSet: connCfg.ReplicaSet,
|
||||||
|
ReadPreference: connCfg.ReadPreference,
|
||||||
|
MaxOpenConns: connCfg.MaxOpenConns,
|
||||||
|
MaxIdleConns: connCfg.MaxIdleConns,
|
||||||
|
ConnMaxLifetime: connCfg.ConnMaxLifetime,
|
||||||
|
ConnMaxIdleTime: connCfg.ConnMaxIdleTime,
|
||||||
|
ConnectTimeout: connCfg.ConnectTimeout,
|
||||||
|
QueryTimeout: connCfg.QueryTimeout,
|
||||||
|
EnableTracing: connCfg.EnableTracing,
|
||||||
|
EnableMetrics: connCfg.EnableMetrics,
|
||||||
|
EnableLogging: connCfg.EnableLogging,
|
||||||
|
DefaultORM: connCfg.DefaultORM,
|
||||||
|
Tags: connCfg.Tags,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getter methods to implement the providers.ConnectionConfig interface.
// The Get-prefixed names are mandated by that external interface (Go's usual
// convention would drop the prefix); each is a trivial field accessor with
// no locking — callers must not mutate the config concurrently.
func (cc *ConnectionConfig) GetName() string                    { return cc.Name }
func (cc *ConnectionConfig) GetType() string                    { return string(cc.Type) }
func (cc *ConnectionConfig) GetHost() string                    { return cc.Host }
func (cc *ConnectionConfig) GetPort() int                       { return cc.Port }
func (cc *ConnectionConfig) GetUser() string                    { return cc.User }
func (cc *ConnectionConfig) GetPassword() string                { return cc.Password }
func (cc *ConnectionConfig) GetDatabase() string                { return cc.Database }
func (cc *ConnectionConfig) GetFilePath() string                { return cc.FilePath }
func (cc *ConnectionConfig) GetConnectTimeout() time.Duration   { return cc.ConnectTimeout }
func (cc *ConnectionConfig) GetEnableLogging() bool             { return cc.EnableLogging }
func (cc *ConnectionConfig) GetMaxOpenConns() *int              { return cc.MaxOpenConns }
func (cc *ConnectionConfig) GetMaxIdleConns() *int              { return cc.MaxIdleConns }
func (cc *ConnectionConfig) GetConnMaxLifetime() *time.Duration { return cc.ConnMaxLifetime }
func (cc *ConnectionConfig) GetConnMaxIdleTime() *time.Duration { return cc.ConnMaxIdleTime }
func (cc *ConnectionConfig) GetQueryTimeout() time.Duration     { return cc.QueryTimeout }
func (cc *ConnectionConfig) GetEnableMetrics() bool             { return cc.EnableMetrics }
func (cc *ConnectionConfig) GetReadPreference() string          { return cc.ReadPreference }
|
||||||
667
pkg/dbmanager/connection.go
Normal file
667
pkg/dbmanager/connection.go
Normal file
@@ -0,0 +1,667 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"github.com/uptrace/bun/schema"
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connection represents a single named database connection. SQL backends
// expose the same underlying handle through several ORM views (Bun, GORM,
// native database/sql); MongoDB backends expose a *mongo.Client instead.
// Accessors for the wrong backend kind return a sentinel error (e.g.
// MongoDB() on a SQL connection returns ErrNotMongoDB).
type Connection interface {
	// Metadata.
	Name() string
	Type() DatabaseType

	// ORM access (SQL databases only).
	Bun() (*bun.DB, error)
	GORM() (*gorm.DB, error)
	Native() (*sql.DB, error)

	// Database returns the common.Database view selected by the connection's
	// DefaultORM configuration (SQL databases only).
	Database() (common.Database, error)

	// MongoDB access (MongoDB only).
	MongoDB() (*mongo.Client, error)

	// Lifecycle.
	Connect(ctx context.Context) error
	Close() error
	HealthCheck(ctx context.Context) error
	Reconnect(ctx context.Context) error

	// Stats returns a snapshot of connection and pool statistics.
	Stats() *ConnectionStats
}
|
||||||
|
|
||||||
|
// ConnectionStats contains statistics about a database connection: identity,
// liveness, the outcome of the most recent health check, and — for SQL
// backends — a snapshot of the driver's connection-pool counters.
type ConnectionStats struct {
	Name              string
	Type              DatabaseType
	Connected         bool
	LastHealthCheck   time.Time
	HealthCheckStatus string // "healthy", "disconnected", or "unhealthy: <err>"

	// SQL connection pool stats (zero for non-SQL backends or when the
	// connection is down).
	OpenConnections   int
	InUse             int
	Idle              int
	WaitCount         int64
	WaitDuration      time.Duration
	MaxIdleClosed     int64
	MaxLifetimeClosed int64
}
|
||||||
|
|
||||||
|
// sqlConnection implements Connection for SQL databases (PostgreSQL, SQLite,
// MSSQL). The provider owns the raw connection; the ORM handles (native,
// Bun, GORM) and their common.Database adapters are created lazily on first
// request, all wrapping the same *sql.DB. All state is guarded by mu.
type sqlConnection struct {
	name     string
	dbType   DatabaseType
	config   ConnectionConfig
	provider Provider

	// Lazy-initialized ORM instances (all wrap the same sql.DB).
	nativeDB *sql.DB
	bunDB    *bun.DB
	gormDB   *gorm.DB

	// Lazy-initialized adapters for the common.Database interface.
	bunAdapter    *database.BunAdapter
	gormAdapter   *database.GormAdapter
	nativeAdapter common.Database

	// State; mu also protects every lazily-initialized field above.
	connected bool
	mu        sync.RWMutex

	// Health-check bookkeeping (written by HealthCheck, read by Stats).
	lastHealthCheck   time.Time
	healthCheckStatus string
}
|
||||||
|
|
||||||
|
// newSQLConnection creates a new SQL connection
|
||||||
|
func newSQLConnection(name string, dbType DatabaseType, config ConnectionConfig, provider Provider) *sqlConnection {
|
||||||
|
return &sqlConnection{
|
||||||
|
name: name,
|
||||||
|
dbType: dbType,
|
||||||
|
config: config,
|
||||||
|
provider: provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the connection name. The name is immutable after
// construction, so no locking is required.
func (c *sqlConnection) Name() string {
	return c.name
}
|
||||||
|
|
||||||
|
// Type returns the database type. The type is immutable after construction,
// so no locking is required.
func (c *sqlConnection) Type() DatabaseType {
	return c.dbType
}
|
||||||
|
|
||||||
|
// Connect establishes the database connection
|
||||||
|
func (c *sqlConnection) Connect(ctx context.Context) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.connected {
|
||||||
|
return ErrAlreadyConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.provider.Connect(ctx, &c.config); err != nil {
|
||||||
|
return NewConnectionError(c.name, "connect", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connected = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the database connection and all ORM instances
|
||||||
|
func (c *sqlConnection) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if !c.connected {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close Bun if initialized
|
||||||
|
if c.bunDB != nil {
|
||||||
|
if err := c.bunDB.Close(); err != nil {
|
||||||
|
return NewConnectionError(c.name, "close bun", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GORM doesn't have a separate close - it uses the underlying sql.DB
|
||||||
|
|
||||||
|
// Close the provider (which closes the underlying sql.DB)
|
||||||
|
if err := c.provider.Close(); err != nil {
|
||||||
|
return NewConnectionError(c.name, "close", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connected = false
|
||||||
|
c.nativeDB = nil
|
||||||
|
c.bunDB = nil
|
||||||
|
c.gormDB = nil
|
||||||
|
c.bunAdapter = nil
|
||||||
|
c.gormAdapter = nil
|
||||||
|
c.nativeAdapter = nil
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheck verifies the connection is alive
|
||||||
|
func (c *sqlConnection) HealthCheck(ctx context.Context) error {
|
||||||
|
if c == nil {
|
||||||
|
return fmt.Errorf("connection is nil")
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.lastHealthCheck = time.Now()
|
||||||
|
|
||||||
|
if !c.connected {
|
||||||
|
c.healthCheckStatus = "disconnected"
|
||||||
|
return ErrConnectionClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.provider.HealthCheck(ctx); err != nil {
|
||||||
|
c.healthCheckStatus = "unhealthy: " + err.Error()
|
||||||
|
return NewConnectionError(c.name, "health check", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.healthCheckStatus = "healthy"
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconnect closes and re-establishes the connection
|
||||||
|
func (c *sqlConnection) Reconnect(ctx context.Context) error {
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Native returns the native *sql.DB connection.
// The handle is fetched lazily from the provider on first use and cached.
// Uses double-checked locking: a read-lock fast path for the cached handle,
// then a re-check after upgrading to the write lock. Returns
// ErrConnectionClosed when Connect has not succeeded (or Close has run).
func (c *sqlConnection) Native() (*sql.DB, error) {
	if c == nil {
		return nil, fmt.Errorf("connection is nil")
	}
	// Fast path: handle already cached — only a read lock is needed.
	c.mu.RLock()
	if c.nativeDB != nil {
		defer c.mu.RUnlock()
		return c.nativeDB, nil
	}
	c.mu.RUnlock()

	c.mu.Lock()
	defer c.mu.Unlock()

	// Double-check after acquiring write lock: another goroutine may have
	// initialized the handle between the RUnlock and Lock above.
	if c.nativeDB != nil {
		return c.nativeDB, nil
	}

	if !c.connected {
		return nil, ErrConnectionClosed
	}

	// Get native connection from provider.
	db, err := c.provider.GetNative()
	if err != nil {
		return nil, NewConnectionError(c.name, "get native", err)
	}

	c.nativeDB = db
	return c.nativeDB, nil
}
|
||||||
|
|
||||||
|
// Bun returns a Bun ORM instance wrapping the native connection
|
||||||
|
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("connection is nil")
|
||||||
|
}
|
||||||
|
c.mu.RLock()
|
||||||
|
if c.bunDB != nil {
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.bunDB, nil
|
||||||
|
}
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
// Double-check after acquiring write lock
|
||||||
|
if c.bunDB != nil {
|
||||||
|
return c.bunDB, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get native connection first
|
||||||
|
native, err := c.provider.GetNative()
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewConnectionError(c.name, "get bun", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Bun DB wrapping the same sql.DB
|
||||||
|
dialect := c.getBunDialect()
|
||||||
|
c.bunDB = bun.NewDB(native, dialect)
|
||||||
|
|
||||||
|
return c.bunDB, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GORM returns a GORM instance wrapping the native connection
|
||||||
|
func (c *sqlConnection) GORM() (*gorm.DB, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("connection is nil")
|
||||||
|
}
|
||||||
|
c.mu.RLock()
|
||||||
|
if c.gormDB != nil {
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
return c.gormDB, nil
|
||||||
|
}
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
// Double-check after acquiring write lock
|
||||||
|
if c.gormDB != nil {
|
||||||
|
return c.gormDB, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get native connection first
|
||||||
|
native, err := c.provider.GetNative()
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewConnectionError(c.name, "get gorm", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create GORM DB wrapping the same sql.DB
|
||||||
|
dialector := c.getGORMDialector(native)
|
||||||
|
db, err := gorm.Open(dialector, &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewConnectionError(c.name, "initialize gorm", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.gormDB = db
|
||||||
|
return c.gormDB, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Database returns the common.Database interface using the configured default ORM
|
||||||
|
func (c *sqlConnection) Database() (common.Database, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("connection is nil")
|
||||||
|
}
|
||||||
|
c.mu.RLock()
|
||||||
|
defaultORM := c.config.DefaultORM
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
switch ORMType(defaultORM) {
|
||||||
|
case ORMTypeBun:
|
||||||
|
return c.getBunAdapter()
|
||||||
|
case ORMTypeGORM:
|
||||||
|
return c.getGORMAdapter()
|
||||||
|
case ORMTypeNative:
|
||||||
|
return c.getNativeAdapter()
|
||||||
|
default:
|
||||||
|
// Default to Bun
|
||||||
|
return c.getBunAdapter()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MongoDB always returns ErrNotMongoDB: SQL-backed connections cannot
// provide a *mongo.Client.
func (c *sqlConnection) MongoDB() (*mongo.Client, error) {
	return nil, ErrNotMongoDB
}
|
||||||
|
|
||||||
|
// Stats returns connection statistics
|
||||||
|
func (c *sqlConnection) Stats() *ConnectionStats {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
stats := &ConnectionStats{
|
||||||
|
Name: c.name,
|
||||||
|
Type: c.dbType,
|
||||||
|
Connected: c.connected,
|
||||||
|
LastHealthCheck: c.lastHealthCheck,
|
||||||
|
HealthCheckStatus: c.healthCheckStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get SQL stats if connected
|
||||||
|
if c.connected && c.provider != nil {
|
||||||
|
if providerStats := c.provider.Stats(); providerStats != nil {
|
||||||
|
stats.OpenConnections = providerStats.OpenConnections
|
||||||
|
stats.InUse = providerStats.InUse
|
||||||
|
stats.Idle = providerStats.Idle
|
||||||
|
stats.WaitCount = providerStats.WaitCount
|
||||||
|
stats.WaitDuration = providerStats.WaitDuration
|
||||||
|
stats.MaxIdleClosed = providerStats.MaxIdleClosed
|
||||||
|
stats.MaxLifetimeClosed = providerStats.MaxLifetimeClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBunAdapter returns the cached common.Database adapter backed by Bun,
// creating the *bun.DB and the adapter lazily. Uses double-checked locking:
// a read-lock fast path, then a re-check after upgrading to the write lock.
// Note: unlike Native(), this path does not consult c.connected before
// calling the provider; a provider error is returned instead.
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
	if c == nil {
		return nil, fmt.Errorf("connection is nil")
	}
	// Fast path: adapter already built — only a read lock is needed.
	c.mu.RLock()
	if c.bunAdapter != nil {
		defer c.mu.RUnlock()
		return c.bunAdapter, nil
	}
	c.mu.RUnlock()

	c.mu.Lock()
	defer c.mu.Unlock()

	// Re-check: another goroutine may have built the adapter meanwhile.
	if c.bunAdapter != nil {
		return c.bunAdapter, nil
	}

	// Double-check bunDB exists (while already holding write lock).
	if c.bunDB == nil {
		// Get native connection first.
		native, err := c.provider.GetNative()
		if err != nil {
			return nil, NewConnectionError(c.name, "get bun", err)
		}

		// Create Bun DB wrapping the same sql.DB.
		dialect := c.getBunDialect()
		c.bunDB = bun.NewDB(native, dialect)
	}

	c.bunAdapter = database.NewBunAdapter(c.bunDB)
	return c.bunAdapter, nil
}
|
||||||
|
|
||||||
|
// getGORMAdapter returns the cached common.Database adapter backed by GORM,
// creating the *gorm.DB and the adapter lazily. Uses double-checked locking:
// a read-lock fast path, then a re-check after upgrading to the write lock.
// Note: unlike Native(), this path does not consult c.connected before
// calling the provider; a provider error is returned instead.
func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
	if c == nil {
		return nil, fmt.Errorf("connection is nil")
	}
	// Fast path: adapter already built — only a read lock is needed.
	c.mu.RLock()
	if c.gormAdapter != nil {
		defer c.mu.RUnlock()
		return c.gormAdapter, nil
	}
	c.mu.RUnlock()

	c.mu.Lock()
	defer c.mu.Unlock()

	// Re-check: another goroutine may have built the adapter meanwhile.
	if c.gormAdapter != nil {
		return c.gormAdapter, nil
	}

	// Double-check gormDB exists (while already holding write lock).
	if c.gormDB == nil {
		// Get native connection first.
		native, err := c.provider.GetNative()
		if err != nil {
			return nil, NewConnectionError(c.name, "get gorm", err)
		}

		// Create GORM DB wrapping the same sql.DB.
		dialector := c.getGORMDialector(native)
		db, err := gorm.Open(dialector, &gorm.Config{})
		if err != nil {
			return nil, NewConnectionError(c.name, "initialize gorm", err)
		}

		c.gormDB = db
	}

	c.gormAdapter = database.NewGormAdapter(c.gormDB)
	return c.gormAdapter, nil
}
|
||||||
|
|
||||||
|
// getNativeAdapter returns or creates the native adapter for this connection.
// Uses double-checked locking; requires the connection to be connected before
// it will fetch the underlying *sql.DB from the provider.
func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
	if c == nil {
		return nil, fmt.Errorf("connection is nil")
	}
	// Fast path: cached adapter returned under the read lock.
	c.mu.RLock()
	if c.nativeAdapter != nil {
		defer c.mu.RUnlock()
		return c.nativeAdapter, nil
	}
	c.mu.RUnlock()

	// Slow path: re-check under the write lock before building anything.
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.nativeAdapter != nil {
		return c.nativeAdapter, nil
	}

	// Double-check nativeDB exists (while already holding write lock)
	if c.nativeDB == nil {
		if !c.connected {
			return nil, ErrConnectionClosed
		}

		// Get native connection from provider
		db, err := c.provider.GetNative()
		if err != nil {
			return nil, NewConnectionError(c.name, "get native", err)
		}

		c.nativeDB = db
	}

	// Create a native adapter based on database type
	switch c.dbType {
	case DatabaseTypePostgreSQL:
		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
	case DatabaseTypeSQLite:
		// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB
		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
	case DatabaseTypeMSSQL:
		// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB
		c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
	default:
		return nil, ErrUnsupportedDatabase
	}

	return c.nativeAdapter, nil
}
|
||||||
|
|
||||||
|
// getBunDialect returns the appropriate Bun dialect for the database type
|
||||||
|
func (c *sqlConnection) getBunDialect() schema.Dialect {
|
||||||
|
|
||||||
|
switch c.dbType {
|
||||||
|
case DatabaseTypePostgreSQL:
|
||||||
|
return database.GetPostgresDialect()
|
||||||
|
case DatabaseTypeSQLite:
|
||||||
|
return database.GetSQLiteDialect()
|
||||||
|
case DatabaseTypeMSSQL:
|
||||||
|
return database.GetMSSQLDialect()
|
||||||
|
default:
|
||||||
|
// Default to PostgreSQL
|
||||||
|
return database.GetPostgresDialect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getGORMDialector returns the appropriate GORM dialector for the database type
|
||||||
|
func (c *sqlConnection) getGORMDialector(db *sql.DB) gorm.Dialector {
|
||||||
|
switch c.dbType {
|
||||||
|
case DatabaseTypePostgreSQL:
|
||||||
|
return database.GetPostgresDialector(db)
|
||||||
|
case DatabaseTypeSQLite:
|
||||||
|
return database.GetSQLiteDialector(db)
|
||||||
|
case DatabaseTypeMSSQL:
|
||||||
|
return database.GetMSSQLDialector(db)
|
||||||
|
default:
|
||||||
|
// Default to PostgreSQL
|
||||||
|
return database.GetPostgresDialector(db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mongoConnection implements Connection for MongoDB
type mongoConnection struct {
	name     string
	config   ConnectionConfig
	provider Provider

	// MongoDB client; set by Connect, cleared by Close.
	client *mongo.Client

	// State
	connected bool
	mu        sync.RWMutex // guards client, connected, and the health-check fields

	// Health check bookkeeping, updated by HealthCheck.
	lastHealthCheck   time.Time
	healthCheckStatus string // "healthy", "disconnected", or "unhealthy: <err>"
}
|
||||||
|
|
||||||
|
// newMongoConnection creates a new MongoDB connection
|
||||||
|
func newMongoConnection(name string, config ConnectionConfig, provider Provider) *mongoConnection {
|
||||||
|
return &mongoConnection{
|
||||||
|
name: name,
|
||||||
|
config: config,
|
||||||
|
provider: provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the connection's configured name.
func (c *mongoConnection) Name() string {
	return c.name
}

// Type returns the database type, which is always MongoDB for this wrapper.
func (c *mongoConnection) Type() DatabaseType {
	return DatabaseTypeMongoDB
}
|
||||||
|
|
||||||
|
// Connect establishes the MongoDB connection.
// Returns ErrAlreadyConnected if already connected; on success the client is
// cached and the connection is marked connected. The whole sequence runs
// under the write lock so concurrent Connect/Close calls serialize.
func (c *mongoConnection) Connect(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		return ErrAlreadyConnected
	}

	// Delegate the actual dial to the provider.
	if err := c.provider.Connect(ctx, &c.config); err != nil {
		return NewConnectionError(c.name, "connect", err)
	}

	// Get the mongo client
	client, err := c.provider.GetMongo()
	if err != nil {
		// NOTE(review): provider.Connect succeeded but the state stays
		// disconnected here; the provider is not closed — confirm intended.
		return NewConnectionError(c.name, "get mongo client", err)
	}

	c.client = client
	c.connected = true
	return nil
}
|
||||||
|
|
||||||
|
// Close closes the MongoDB connection.
// Idempotent: returns nil without side effects when already disconnected.
// On provider error the connected flag and client are left untouched.
func (c *mongoConnection) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return nil
	}

	if err := c.provider.Close(); err != nil {
		return NewConnectionError(c.name, "close", err)
	}

	c.connected = false
	c.client = nil
	return nil
}
|
||||||
|
|
||||||
|
// HealthCheck verifies the MongoDB connection is alive.
// Takes the write lock (not a read lock) because it mutates
// lastHealthCheck and healthCheckStatus. The timestamp is recorded even
// when the check fails.
func (c *mongoConnection) HealthCheck(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.lastHealthCheck = time.Now()

	if !c.connected {
		c.healthCheckStatus = "disconnected"
		return ErrConnectionClosed
	}

	// Delegate the actual ping to the provider.
	if err := c.provider.HealthCheck(ctx); err != nil {
		c.healthCheckStatus = "unhealthy: " + err.Error()
		return NewConnectionError(c.name, "health check", err)
	}

	c.healthCheckStatus = "healthy"
	return nil
}
|
||||||
|
|
||||||
|
// Reconnect closes and re-establishes the MongoDB connection
|
||||||
|
func (c *mongoConnection) Reconnect(ctx context.Context) error {
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MongoDB returns the MongoDB client
|
||||||
|
func (c *mongoConnection) MongoDB() (*mongo.Client, error) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
if !c.connected || c.client == nil {
|
||||||
|
return nil, ErrConnectionClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bun returns an error for MongoDB connections: Bun is a SQL ORM and has
// no MongoDB backend.
func (c *mongoConnection) Bun() (*bun.DB, error) {
	return nil, ErrNotSQLDatabase
}

// GORM returns an error for MongoDB connections: GORM is a SQL ORM and has
// no MongoDB backend.
func (c *mongoConnection) GORM() (*gorm.DB, error) {
	return nil, ErrNotSQLDatabase
}

// Native returns an error for MongoDB connections: there is no *sql.DB
// behind a MongoDB client.
func (c *mongoConnection) Native() (*sql.DB, error) {
	return nil, ErrNotSQLDatabase
}

// Database returns an error for MongoDB connections; the common.Database
// abstraction is SQL-only here.
func (c *mongoConnection) Database() (common.Database, error) {
	return nil, ErrNotSQLDatabase
}
|
||||||
|
|
||||||
|
// Stats returns connection statistics for MongoDB
|
||||||
|
func (c *mongoConnection) Stats() *ConnectionStats {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: c.name,
|
||||||
|
Type: DatabaseTypeMongoDB,
|
||||||
|
Connected: c.connected,
|
||||||
|
LastHealthCheck: c.lastHealthCheck,
|
||||||
|
HealthCheckStatus: c.healthCheckStatus,
|
||||||
|
}
|
||||||
|
}
|
||||||
82
pkg/dbmanager/errors.go
Normal file
82
pkg/dbmanager/errors.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common sentinel errors for the dbmanager package; compare with errors.Is.
var (
	// ErrConnectionNotFound is returned when a connection with the given name doesn't exist
	ErrConnectionNotFound = errors.New("connection not found")

	// ErrInvalidConfiguration is returned when the configuration is invalid
	ErrInvalidConfiguration = errors.New("invalid configuration")

	// ErrConnectionClosed is returned when attempting to use a closed connection
	ErrConnectionClosed = errors.New("connection is closed")

	// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
	ErrNotSQLDatabase = errors.New("not a SQL database")

	// ErrNotMongoDB is returned when attempting MongoDB operations on a non-MongoDB connection
	ErrNotMongoDB = errors.New("not a MongoDB connection")

	// ErrUnsupportedDatabase is returned when the database type is not supported
	ErrUnsupportedDatabase = errors.New("unsupported database type")

	// ErrNoDefaultConnection is returned when no default connection is configured
	ErrNoDefaultConnection = errors.New("no default connection configured")

	// ErrAlreadyConnected is returned when attempting to connect an already connected connection
	ErrAlreadyConnected = errors.New("already connected")
)
|
||||||
|
|
||||||
|
// ConnectionError wraps an error that occurred while performing an
// operation on a named connection. It supports errors.Is/As via Unwrap.
type ConnectionError struct {
	Name      string // connection name
	Operation string // operation that failed, e.g. "connect"
	Err       error  // underlying cause
}

// Error formats the error as "connection '<name>' <operation>: <cause>".
func (ce *ConnectionError) Error() string {
	return fmt.Sprintf("connection '%s' %s: %v", ce.Name, ce.Operation, ce.Err)
}

// Unwrap exposes the underlying cause for errors.Is / errors.As.
func (ce *ConnectionError) Unwrap() error {
	return ce.Err
}

// NewConnectionError creates a ConnectionError for the given connection
// name, operation, and cause.
func NewConnectionError(name, operation string, err error) *ConnectionError {
	ce := &ConnectionError{}
	ce.Name = name
	ce.Operation = operation
	ce.Err = err
	return ce
}
|
||||||
|
|
||||||
|
// ConfigurationError wraps a configuration-related error, optionally
// scoped to a specific field. It supports errors.Is/As via Unwrap.
type ConfigurationError struct {
	Field string // offending field name; empty for a general error
	Err   error  // underlying cause
}

// Error mentions the field when one was provided, otherwise reports a
// general configuration error.
func (ce *ConfigurationError) Error() string {
	if ce.Field == "" {
		return fmt.Sprintf("configuration error: %v", ce.Err)
	}
	return fmt.Sprintf("configuration error in field '%s': %v", ce.Field, ce.Err)
}

// Unwrap exposes the underlying cause for errors.Is / errors.As.
func (ce *ConfigurationError) Unwrap() error {
	return ce.Err
}

// NewConfigurationError creates a ConfigurationError for the given field
// (may be empty) and cause.
func NewConfigurationError(field string, err error) *ConfigurationError {
	ce := &ConfigurationError{}
	ce.Field = field
	ce.Err = err
	return ce
}
|
||||||
67
pkg/dbmanager/factory.go
Normal file
67
pkg/dbmanager/factory.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createConnection creates a database connection based on the configuration
|
||||||
|
func createConnection(cfg ConnectionConfig) (Connection, error) {
|
||||||
|
// Validate configuration
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid connection configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create provider based on database type
|
||||||
|
provider, err := createProvider(cfg.Type)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create connection wrapper based on database type
|
||||||
|
switch cfg.Type {
|
||||||
|
case DatabaseTypePostgreSQL, DatabaseTypeSQLite, DatabaseTypeMSSQL:
|
||||||
|
return newSQLConnection(cfg.Name, cfg.Type, cfg, provider), nil
|
||||||
|
case DatabaseTypeMongoDB:
|
||||||
|
return newMongoConnection(cfg.Name, cfg, provider), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrUnsupportedDatabase, cfg.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createProvider creates a database provider based on the database type
|
||||||
|
func createProvider(dbType DatabaseType) (Provider, error) {
|
||||||
|
switch dbType {
|
||||||
|
case DatabaseTypePostgreSQL:
|
||||||
|
return providers.NewPostgresProvider(), nil
|
||||||
|
case DatabaseTypeSQLite:
|
||||||
|
return providers.NewSQLiteProvider(), nil
|
||||||
|
case DatabaseTypeMSSQL:
|
||||||
|
return providers.NewMSSQLProvider(), nil
|
||||||
|
case DatabaseTypeMongoDB:
|
||||||
|
return providers.NewMongoProvider(), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrUnsupportedDatabase, dbType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider is an alias to the providers.Provider interface.
// This allows dbmanager package consumers to use Provider without
// importing the providers package directly.
type Provider = providers.Provider
|
||||||
|
|
||||||
|
// NewConnectionFromDB creates a new Connection from an existing *sql.DB
|
||||||
|
// This allows you to use dbmanager features (ORM wrappers, health checks, etc.)
|
||||||
|
// with a database connection that was opened outside of dbmanager
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - name: A unique name for this connection
|
||||||
|
// - dbType: The database type (DatabaseTypePostgreSQL, DatabaseTypeSQLite, or DatabaseTypeMSSQL)
|
||||||
|
// - db: An existing *sql.DB connection
|
||||||
|
//
|
||||||
|
// Returns a Connection that wraps the existing *sql.DB
|
||||||
|
func NewConnectionFromDB(name string, dbType DatabaseType, db *sql.DB) Connection {
|
||||||
|
provider := providers.NewExistingDBProvider(db, name)
|
||||||
|
return newSQLConnection(name, dbType, ConnectionConfig{Name: name, Type: dbType}, provider)
|
||||||
|
}
|
||||||
210
pkg/dbmanager/factory_test.go
Normal file
210
pkg/dbmanager/factory_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB(t *testing.T) {
|
||||||
|
// Open a SQLite in-memory database
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Create a connection from the existing database
|
||||||
|
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||||
|
if conn == nil {
|
||||||
|
t.Fatal("Expected connection to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify connection properties
|
||||||
|
if conn.Name() != "test-connection" {
|
||||||
|
t.Errorf("Expected name 'test-connection', got '%s'", conn.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.Type() != DatabaseTypeSQLite {
|
||||||
|
t.Errorf("Expected type DatabaseTypeSQLite, got '%s'", conn.Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB_Connect(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Connect should verify the existing connection works
|
||||||
|
err = conn.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected Connect to succeed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB_Native(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = conn.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Get native DB
|
||||||
|
nativeDB, err := conn.Native()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected Native to succeed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nativeDB != db {
|
||||||
|
t.Error("Expected Native to return the same database instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB_Bun(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = conn.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Get Bun ORM
|
||||||
|
bunDB, err := conn.Bun()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected Bun to succeed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bunDB == nil {
|
||||||
|
t.Error("Expected Bun to return a non-nil instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB_GORM(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = conn.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Get GORM
|
||||||
|
gormDB, err := conn.GORM()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected GORM to succeed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gormDB == nil {
|
||||||
|
t.Error("Expected GORM to return a non-nil instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB_HealthCheck(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = conn.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Health check should succeed
|
||||||
|
err = conn.HealthCheck(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB_Stats(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = conn.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
stats := conn.Stats()
|
||||||
|
if stats == nil {
|
||||||
|
t.Fatal("Expected stats to be returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.Name != "test-connection" {
|
||||||
|
t.Errorf("Expected stats.Name to be 'test-connection', got '%s'", stats.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.Type != DatabaseTypeSQLite {
|
||||||
|
t.Errorf("Expected stats.Type to be DatabaseTypeSQLite, got '%s'", stats.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stats.Connected {
|
||||||
|
t.Error("Expected stats.Connected to be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
||||||
|
// This test just verifies the factory works with PostgreSQL type
|
||||||
|
// It won't actually connect since we're using SQLite
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
conn := NewConnectionFromDB("test-pg", DatabaseTypePostgreSQL, db)
|
||||||
|
if conn == nil {
|
||||||
|
t.Fatal("Expected connection to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.Type() != DatabaseTypePostgreSQL {
|
||||||
|
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
380
pkg/dbmanager/manager.go
Normal file
380
pkg/dbmanager/manager.go
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager manages multiple named database connections.
// Implementations are expected to be safe for concurrent use.
type Manager interface {
	// Connection retrieval
	Get(name string) (Connection, error)  // fetch a connection by name
	GetDefault() (Connection, error)      // fetch the configured default connection
	GetAll() map[string]Connection        // snapshot of all connections

	// Default database management
	GetDefaultDatabase() (common.Database, error) // common.Database of the default connection
	SetDefaultDatabase(name string) error         // change which connection is the default

	// Lifecycle
	Connect(ctx context.Context) error     // establish all configured connections
	Close() error                          // close all connections
	HealthCheck(ctx context.Context) error // check every connection

	// Stats
	Stats() *ManagerStats
}
|
||||||
|
|
||||||
|
// ManagerStats contains statistics about the connection manager.
type ManagerStats struct {
	TotalConnections int // total number of managed connections
	HealthyCount     int // connections that are connected and last reported "healthy"
	UnhealthyCount   int // all remaining connections
	ConnectionStats  map[string]*ConnectionStats // per-connection detail, keyed by name
}
|
||||||
|
|
||||||
|
// connectionManager implements Manager.
type connectionManager struct {
	connections map[string]Connection // named connections, keyed by config name
	config      ManagerConfig
	mu          sync.RWMutex // guards connections and config

	// Background health check
	healthTicker *time.Ticker  // nil while the checker is not running
	stopChan     chan struct{} // closed to tell the checker goroutine to exit
	wg           sync.WaitGroup // waits for the checker goroutine on shutdown
}
|
||||||
|
|
||||||
|
var (
	// instance is the process-wide singleton manager, set by SetupManager
	// and cleared by ResetInstance.
	instance Manager
	// instanceMu protects the singleton instance
	instanceMu sync.RWMutex
)
|
||||||
|
|
||||||
|
// SetupManager initializes the singleton database manager with the provided configuration.
|
||||||
|
// This function must be called before GetInstance().
|
||||||
|
// Returns an error if the manager is already initialized or if configuration is invalid.
|
||||||
|
func SetupManager(cfg ManagerConfig) error {
|
||||||
|
instanceMu.Lock()
|
||||||
|
defer instanceMu.Unlock()
|
||||||
|
|
||||||
|
if instance != nil {
|
||||||
|
return fmt.Errorf("manager already initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
instance = mgr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInstance returns the singleton instance of the database manager.
|
||||||
|
// Returns an error if SetupManager has not been called yet.
|
||||||
|
func GetInstance() (Manager, error) {
|
||||||
|
instanceMu.RLock()
|
||||||
|
defer instanceMu.RUnlock()
|
||||||
|
|
||||||
|
if instance == nil {
|
||||||
|
return nil, fmt.Errorf("manager not initialized: call SetupManager first")
|
||||||
|
}
|
||||||
|
|
||||||
|
return instance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetInstance resets the singleton instance (primarily for testing purposes).
|
||||||
|
// WARNING: This should only be used in tests. Calling this in production code
|
||||||
|
// while the manager is in use can lead to undefined behavior.
|
||||||
|
func ResetInstance() {
|
||||||
|
instanceMu.Lock()
|
||||||
|
defer instanceMu.Unlock()
|
||||||
|
|
||||||
|
if instance != nil {
|
||||||
|
_ = instance.Close()
|
||||||
|
}
|
||||||
|
instance = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new database connection manager
|
||||||
|
func NewManager(cfg ManagerConfig) (Manager, error) {
|
||||||
|
// Apply defaults and validate configuration
|
||||||
|
cfg.ApplyDefaults()
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &connectionManager{
|
||||||
|
connections: make(map[string]Connection),
|
||||||
|
config: cfg,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a named connection
|
||||||
|
func (m *connectionManager) Get(name string) (Connection, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
conn, ok := m.connections[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrConnectionNotFound, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefault retrieves the default connection
|
||||||
|
func (m *connectionManager) GetDefault() (Connection, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defaultName := m.config.DefaultConnection
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if defaultName == "" {
|
||||||
|
return nil, ErrNoDefaultConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.Get(defaultName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all connections
|
||||||
|
func (m *connectionManager) GetAll() map[string]Connection {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
// Create a copy to avoid concurrent access issues
|
||||||
|
result := make(map[string]Connection, len(m.connections))
|
||||||
|
for name, conn := range m.connections {
|
||||||
|
result[name] = conn
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultDatabase returns the common.Database interface from the default connection
|
||||||
|
func (m *connectionManager) GetDefaultDatabase() (common.Database, error) {
|
||||||
|
conn, err := m.GetDefault()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := conn.Database()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get database from default connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultDatabase sets the default database connection by name
|
||||||
|
func (m *connectionManager) SetDefaultDatabase(name string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Verify the connection exists
|
||||||
|
if _, ok := m.connections[name]; !ok {
|
||||||
|
return fmt.Errorf("%w: %s", ErrConnectionNotFound, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.config.DefaultConnection = name
|
||||||
|
logger.Info("Default database connection changed: name=%s", name)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes all configured database connections.
// Stops at the first connection that fails to create or connect;
// connections established before the failure remain registered (and will
// be closed by a later Close). The write lock is held for the whole loop,
// including the network dials. NOTE(review): that serializes startup and
// blocks other manager calls while dialing — confirm acceptable.
func (m *connectionManager) Connect(ctx context.Context) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Create connections from configuration
	for name := range m.config.Connections {
		// Get a copy of the connection config
		connCfg := m.config.Connections[name]
		// Apply global defaults to connection config
		connCfg.ApplyDefaults(&m.config)
		connCfg.Name = name

		// Create connection using factory
		conn, err := createConnection(connCfg)
		if err != nil {
			return fmt.Errorf("failed to create connection '%s': %w", name, err)
		}

		// Connect
		if err := conn.Connect(ctx); err != nil {
			return fmt.Errorf("failed to connect '%s': %w", name, err)
		}

		m.connections[name] = conn
		logger.Info("Database connection established: name=%s, type=%s", name, connCfg.Type)
	}

	// Start background health checks when an interval is configured.
	if m.config.HealthCheckInterval > 0 {
		m.startHealthChecker()
		logger.Info("Background health checker started: interval=%v", m.config.HealthCheckInterval)
	}

	logger.Info("Database manager initialized: connections=%d", len(m.connections))
	return nil
}
|
||||||
|
|
||||||
|
// Close closes all database connections
|
||||||
|
func (m *connectionManager) Close() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Stop health checker
|
||||||
|
m.stopHealthChecker()
|
||||||
|
|
||||||
|
// Close all connections
|
||||||
|
var errors []error
|
||||||
|
for name, conn := range m.connections {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
errors = append(errors, fmt.Errorf("failed to close connection '%s': %w", name, err))
|
||||||
|
logger.Error("Failed to close connection", "name", name, "error", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Connection closed: name=%s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.connections = make(map[string]Connection)
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return fmt.Errorf("errors closing connections: %v", errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Database manager closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheck performs health checks on all connections
|
||||||
|
func (m *connectionManager) HealthCheck(ctx context.Context) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
connections := make(map[string]Connection, len(m.connections))
|
||||||
|
for name, conn := range m.connections {
|
||||||
|
connections[name] = conn
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
var errors []error
|
||||||
|
for name, conn := range connections {
|
||||||
|
if err := conn.HealthCheck(ctx); err != nil {
|
||||||
|
errors = append(errors, fmt.Errorf("connection '%s': %w", name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return fmt.Errorf("health check failed for %d connections: %v", len(errors), errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics for all connections
|
||||||
|
func (m *connectionManager) Stats() *ManagerStats {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
stats := &ManagerStats{
|
||||||
|
TotalConnections: len(m.connections),
|
||||||
|
ConnectionStats: make(map[string]*ConnectionStats),
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, conn := range m.connections {
|
||||||
|
connStats := conn.Stats()
|
||||||
|
stats.ConnectionStats[name] = connStats
|
||||||
|
|
||||||
|
if connStats.Connected && connStats.HealthCheckStatus == "healthy" {
|
||||||
|
stats.HealthyCount++
|
||||||
|
} else {
|
||||||
|
stats.UnhealthyCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
|
// startHealthChecker starts background health checking
|
||||||
|
func (m *connectionManager) startHealthChecker() {
|
||||||
|
if m.healthTicker != nil {
|
||||||
|
return // Already running
|
||||||
|
}
|
||||||
|
|
||||||
|
m.healthTicker = time.NewTicker(m.config.HealthCheckInterval)
|
||||||
|
|
||||||
|
m.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer m.wg.Done()
|
||||||
|
logger.Info("Health checker started: interval=%v", m.config.HealthCheckInterval)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.healthTicker.C:
|
||||||
|
m.performHealthCheck()
|
||||||
|
case <-m.stopChan:
|
||||||
|
logger.Info("Health checker stopped")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopHealthChecker stops background health checking
|
||||||
|
func (m *connectionManager) stopHealthChecker() {
|
||||||
|
if m.healthTicker != nil {
|
||||||
|
m.healthTicker.Stop()
|
||||||
|
close(m.stopChan)
|
||||||
|
m.wg.Wait()
|
||||||
|
m.healthTicker = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// performHealthCheck performs a health check on all connections
|
||||||
|
func (m *connectionManager) performHealthCheck() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
connections := make([]struct {
|
||||||
|
name string
|
||||||
|
conn Connection
|
||||||
|
}, 0, len(m.connections))
|
||||||
|
for name, conn := range m.connections {
|
||||||
|
connections = append(connections, struct {
|
||||||
|
name string
|
||||||
|
conn Connection
|
||||||
|
}{name, conn})
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, item := range connections {
|
||||||
|
if err := item.conn.HealthCheck(ctx); err != nil {
|
||||||
|
logger.Warn("Health check failed",
|
||||||
|
"connection", item.name,
|
||||||
|
"error", err)
|
||||||
|
|
||||||
|
// Attempt reconnection if enabled
|
||||||
|
if m.config.EnableAutoReconnect {
|
||||||
|
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||||
|
if err := item.conn.Reconnect(ctx); err != nil {
|
||||||
|
logger.Error("Reconnection failed",
|
||||||
|
"connection", item.name,
|
||||||
|
"error", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
226
pkg/dbmanager/manager_test.go
Normal file
226
pkg/dbmanager/manager_test.go
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackgroundHealthChecker(t *testing.T) {
|
||||||
|
// Create a SQLite in-memory database
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Create manager config with a short health check interval for testing
|
||||||
|
cfg := ManagerConfig{
|
||||||
|
DefaultConnection: "test",
|
||||||
|
Connections: map[string]ConnectionConfig{
|
||||||
|
"test": {
|
||||||
|
Name: "test",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckInterval: 1 * time.Second, // Short interval for testing
|
||||||
|
EnableAutoReconnect: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create manager
|
||||||
|
mgr, err := NewManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect - this should start the background health checker
|
||||||
|
ctx := context.Background()
|
||||||
|
err = mgr.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
// Get the connection to verify it's healthy
|
||||||
|
conn, err := mgr.Get("test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial health check
|
||||||
|
err = conn.HealthCheck(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Initial health check failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for a few health check cycles
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
// Get stats to verify the connection is still healthy
|
||||||
|
stats := conn.Stats()
|
||||||
|
if stats == nil {
|
||||||
|
t.Fatal("Expected stats to be returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stats.Connected {
|
||||||
|
t.Error("Expected connection to still be connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.HealthCheckStatus == "" {
|
||||||
|
t.Error("Expected health check status to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the manager has started the health checker
|
||||||
|
if cm, ok := mgr.(*connectionManager); ok {
|
||||||
|
if cm.healthTicker == nil {
|
||||||
|
t.Error("Expected health ticker to be running")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultHealthCheckInterval(t *testing.T) {
|
||||||
|
// Verify the default health check interval is 15 seconds
|
||||||
|
defaults := DefaultManagerConfig()
|
||||||
|
|
||||||
|
expectedInterval := 15 * time.Second
|
||||||
|
if defaults.HealthCheckInterval != expectedInterval {
|
||||||
|
t.Errorf("Expected default health check interval to be %v, got %v",
|
||||||
|
expectedInterval, defaults.HealthCheckInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !defaults.EnableAutoReconnect {
|
||||||
|
t.Error("Expected EnableAutoReconnect to be true by default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyDefaultsEnablesAutoReconnect(t *testing.T) {
|
||||||
|
// Create a config without setting EnableAutoReconnect
|
||||||
|
cfg := ManagerConfig{
|
||||||
|
Connections: map[string]ConnectionConfig{
|
||||||
|
"test": {
|
||||||
|
Name: "test",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's false initially (Go's zero value for bool)
|
||||||
|
if cfg.EnableAutoReconnect {
|
||||||
|
t.Error("Expected EnableAutoReconnect to be false before ApplyDefaults")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults
|
||||||
|
cfg.ApplyDefaults()
|
||||||
|
|
||||||
|
// Verify it's now true
|
||||||
|
if !cfg.EnableAutoReconnect {
|
||||||
|
t.Error("Expected EnableAutoReconnect to be true after ApplyDefaults")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify health check interval is also set
|
||||||
|
if cfg.HealthCheckInterval != 15*time.Second {
|
||||||
|
t.Errorf("Expected health check interval to be 15s, got %v", cfg.HealthCheckInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerHealthCheck(t *testing.T) {
|
||||||
|
// Create a SQLite in-memory database
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Create manager config
|
||||||
|
cfg := ManagerConfig{
|
||||||
|
DefaultConnection: "test",
|
||||||
|
Connections: map[string]ConnectionConfig{
|
||||||
|
"test": {
|
||||||
|
Name: "test",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckInterval: 15 * time.Second,
|
||||||
|
EnableAutoReconnect: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and connect manager
|
||||||
|
mgr, err := NewManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = mgr.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
// Perform health check on all connections
|
||||||
|
err = mgr.HealthCheck(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Health check failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get stats
|
||||||
|
stats := mgr.Stats()
|
||||||
|
if stats == nil {
|
||||||
|
t.Fatal("Expected stats to be returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.TotalConnections != 1 {
|
||||||
|
t.Errorf("Expected 1 total connection, got %d", stats.TotalConnections)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.HealthyCount != 1 {
|
||||||
|
t.Errorf("Expected 1 healthy connection, got %d", stats.HealthyCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.UnhealthyCount != 0 {
|
||||||
|
t.Errorf("Expected 0 unhealthy connections, got %d", stats.UnhealthyCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerStatsAfterClose(t *testing.T) {
|
||||||
|
cfg := ManagerConfig{
|
||||||
|
DefaultConnection: "test",
|
||||||
|
Connections: map[string]ConnectionConfig{
|
||||||
|
"test": {
|
||||||
|
Name: "test",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HealthCheckInterval: 15 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = mgr.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the manager
|
||||||
|
err = mgr.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to close manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats should show no connections
|
||||||
|
stats := mgr.Stats()
|
||||||
|
if stats.TotalConnections != 0 {
|
||||||
|
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||||
|
}
|
||||||
|
}
|
||||||
136
pkg/dbmanager/metrics.go
Normal file
136
pkg/dbmanager/metrics.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package dbmanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// connectionsTotal tracks the total number of configured database connections
|
||||||
|
connectionsTotal = promauto.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "dbmanager_connections_total",
|
||||||
|
Help: "Total number of configured database connections",
|
||||||
|
},
|
||||||
|
[]string{"type"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// connectionStatus tracks connection health status (1=healthy, 0=unhealthy)
|
||||||
|
connectionStatus = promauto.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "dbmanager_connection_status",
|
||||||
|
Help: "Connection status (1=healthy, 0=unhealthy)",
|
||||||
|
},
|
||||||
|
[]string{"name", "type"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// connectionPoolSize tracks connection pool sizes
|
||||||
|
connectionPoolSize = promauto.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "dbmanager_connection_pool_size",
|
||||||
|
Help: "Current connection pool size",
|
||||||
|
},
|
||||||
|
[]string{"name", "type", "state"}, // state: open, idle, in_use
|
||||||
|
)
|
||||||
|
|
||||||
|
// connectionWaitCount tracks how many times connections had to wait for availability
|
||||||
|
connectionWaitCount = promauto.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "dbmanager_connection_wait_count",
|
||||||
|
Help: "Number of times connections had to wait for availability",
|
||||||
|
},
|
||||||
|
[]string{"name", "type"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// connectionWaitDuration tracks total time connections spent waiting
|
||||||
|
connectionWaitDuration = promauto.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "dbmanager_connection_wait_duration_seconds",
|
||||||
|
Help: "Total time connections spent waiting for availability",
|
||||||
|
},
|
||||||
|
[]string{"name", "type"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// reconnectAttempts tracks reconnection attempts and their outcomes
|
||||||
|
reconnectAttempts = promauto.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "dbmanager_reconnect_attempts_total",
|
||||||
|
Help: "Total number of reconnection attempts",
|
||||||
|
},
|
||||||
|
[]string{"name", "type", "result"}, // result: success, failure
|
||||||
|
)
|
||||||
|
|
||||||
|
// connectionLifetimeClosed tracks connections closed due to max lifetime
|
||||||
|
connectionLifetimeClosed = promauto.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "dbmanager_connection_lifetime_closed_total",
|
||||||
|
Help: "Total connections closed due to exceeding max lifetime",
|
||||||
|
},
|
||||||
|
[]string{"name", "type"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// connectionIdleClosed tracks connections closed due to max idle time
|
||||||
|
connectionIdleClosed = promauto.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "dbmanager_connection_idle_closed_total",
|
||||||
|
Help: "Total connections closed due to exceeding max idle time",
|
||||||
|
},
|
||||||
|
[]string{"name", "type"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
// PublishMetrics publishes current metrics for all connections
|
||||||
|
func (m *connectionManager) PublishMetrics() {
|
||||||
|
stats := m.Stats()
|
||||||
|
|
||||||
|
// Count connections by type
|
||||||
|
typeCount := make(map[DatabaseType]int)
|
||||||
|
for _, connStats := range stats.ConnectionStats {
|
||||||
|
typeCount[connStats.Type]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update total connections gauge
|
||||||
|
for dbType, count := range typeCount {
|
||||||
|
connectionsTotal.WithLabelValues(string(dbType)).Set(float64(count))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update per-connection metrics
|
||||||
|
for name, connStats := range stats.ConnectionStats {
|
||||||
|
labels := prometheus.Labels{
|
||||||
|
"name": name,
|
||||||
|
"type": string(connStats.Type),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection status
|
||||||
|
status := float64(0)
|
||||||
|
if connStats.Connected && connStats.HealthCheckStatus == "healthy" {
|
||||||
|
status = 1
|
||||||
|
}
|
||||||
|
connectionStatus.With(labels).Set(status)
|
||||||
|
|
||||||
|
// Pool size metrics (SQL databases only)
|
||||||
|
if connStats.Type != DatabaseTypeMongoDB {
|
||||||
|
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "open").Set(float64(connStats.OpenConnections))
|
||||||
|
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "idle").Set(float64(connStats.Idle))
|
||||||
|
connectionPoolSize.WithLabelValues(name, string(connStats.Type), "in_use").Set(float64(connStats.InUse))
|
||||||
|
|
||||||
|
// Wait stats
|
||||||
|
connectionWaitCount.With(labels).Set(float64(connStats.WaitCount))
|
||||||
|
connectionWaitDuration.With(labels).Set(connStats.WaitDuration.Seconds())
|
||||||
|
|
||||||
|
// Lifetime/idle closed
|
||||||
|
connectionLifetimeClosed.With(labels).Set(float64(connStats.MaxLifetimeClosed))
|
||||||
|
connectionIdleClosed.With(labels).Set(float64(connStats.MaxIdleClosed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordReconnectAttempt records a reconnection attempt
|
||||||
|
func RecordReconnectAttempt(name string, dbType DatabaseType, success bool) {
|
||||||
|
result := "failure"
|
||||||
|
if success {
|
||||||
|
result = "success"
|
||||||
|
}
|
||||||
|
|
||||||
|
reconnectAttempts.WithLabelValues(name, string(dbType), result).Inc()
|
||||||
|
}
|
||||||
319
pkg/dbmanager/providers/POSTGRES_NOTIFY_LISTEN.md
Normal file
319
pkg/dbmanager/providers/POSTGRES_NOTIFY_LISTEN.md
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
# PostgreSQL NOTIFY/LISTEN Support
|
||||||
|
|
||||||
|
The `dbmanager` package provides built-in support for PostgreSQL's NOTIFY/LISTEN functionality through the `PostgresListener` type.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
PostgreSQL NOTIFY/LISTEN is a simple pub/sub mechanism that allows database clients to:
|
||||||
|
- **LISTEN** on named channels to receive notifications
|
||||||
|
- **NOTIFY** channels to send messages to all listeners
|
||||||
|
- Receive asynchronous notifications without polling
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- ✅ Subscribe to multiple channels simultaneously
|
||||||
|
- ✅ Callback-based notification handling
|
||||||
|
- ✅ Automatic reconnection on connection loss
|
||||||
|
- ✅ Automatic resubscription after reconnection
|
||||||
|
- ✅ Thread-safe operations
|
||||||
|
- ✅ Panic recovery in notification handlers
|
||||||
|
- ✅ Dedicated connection for listening (doesn't interfere with queries)
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Create PostgreSQL provider
|
||||||
|
cfg := &providers.Config{
|
||||||
|
Name: "primary",
|
||||||
|
Type: "postgres",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "password",
|
||||||
|
Database: "myapp",
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := providers.NewPostgresProvider()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer provider.Close()
|
||||||
|
|
||||||
|
// Get listener
|
||||||
|
listener, err := provider.GetListener(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to a channel
|
||||||
|
err = listener.Listen("events", func(channel, payload string) {
|
||||||
|
fmt.Printf("Received on %s: %s\n", channel, payload)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a notification
|
||||||
|
err = listener.Notify(ctx, "events", "Hello, World!")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep the program running
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Channels
|
||||||
|
|
||||||
|
```go
|
||||||
|
listener, _ := provider.GetListener(ctx)
|
||||||
|
|
||||||
|
// Listen to different channels with different handlers
|
||||||
|
listener.Listen("user_events", func(channel, payload string) {
|
||||||
|
fmt.Printf("User event: %s\n", payload)
|
||||||
|
})
|
||||||
|
|
||||||
|
listener.Listen("order_events", func(channel, payload string) {
|
||||||
|
fmt.Printf("Order event: %s\n", payload)
|
||||||
|
})
|
||||||
|
|
||||||
|
listener.Listen("payment_events", func(channel, payload string) {
|
||||||
|
fmt.Printf("Payment event: %s\n", payload)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Unsubscribing
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Stop listening to a specific channel
|
||||||
|
err := listener.Unlisten("user_events")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to unlisten: %v\n", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Checking Active Channels
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get list of channels currently being listened to
|
||||||
|
channels := listener.Channels()
|
||||||
|
fmt.Printf("Listening to: %v\n", channels)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Checking Connection Status
|
||||||
|
|
||||||
|
```go
|
||||||
|
if listener.IsConnected() {
|
||||||
|
fmt.Println("Listener is connected")
|
||||||
|
} else {
|
||||||
|
fmt.Println("Listener is disconnected")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Integration with DBManager
|
||||||
|
|
||||||
|
When using the DBManager, you can access the listener through the PostgreSQL provider:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Initialize DBManager
|
||||||
|
mgr, err := dbmanager.NewManager(dbmanager.FromConfig(cfg.DBManager))
|
||||||
|
mgr.Connect(ctx)
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
// Get PostgreSQL connection
|
||||||
|
conn, err := mgr.Get("primary")
|
||||||
|
|
||||||
|
// Note: You'll need to cast to the underlying provider type
|
||||||
|
// This requires exposing the provider through the Connection interface
|
||||||
|
// or providing a helper method
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Cases
|
||||||
|
|
||||||
|
### Cache Invalidation
|
||||||
|
|
||||||
|
```go
|
||||||
|
listener.Listen("cache_invalidation", func(channel, payload string) {
|
||||||
|
// Parse the payload to determine what to invalidate
|
||||||
|
cache.Invalidate(payload)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Real-time Updates
|
||||||
|
|
||||||
|
```go
|
||||||
|
listener.Listen("data_updates", func(channel, payload string) {
|
||||||
|
// Broadcast update to WebSocket clients
|
||||||
|
websocketBroadcast(payload)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Reload
|
||||||
|
|
||||||
|
```go
|
||||||
|
listener.Listen("config_reload", func(channel, payload string) {
|
||||||
|
// Reload application configuration
|
||||||
|
config.Reload()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Distributed Locking
|
||||||
|
|
||||||
|
```go
|
||||||
|
listener.Listen("lock_released", func(channel, payload string) {
|
||||||
|
// Attempt to acquire the lock
|
||||||
|
tryAcquireLock(payload)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Automatic Reconnection
|
||||||
|
|
||||||
|
The listener automatically handles connection failures:
|
||||||
|
|
||||||
|
1. When a connection error is detected, the listener initiates reconnection
|
||||||
|
2. Once reconnected, it automatically resubscribes to all previous channels
|
||||||
|
3. Notification handlers remain active throughout the reconnection process
|
||||||
|
|
||||||
|
No manual intervention is required for reconnection.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Handler Panics
|
||||||
|
|
||||||
|
If a notification handler panics, the panic is recovered and logged. The listener continues to function normally:
|
||||||
|
|
||||||
|
```go
|
||||||
|
listener.Listen("events", func(channel, payload string) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("Handler panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Your event processing logic
|
||||||
|
processEvent(payload)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection Errors
|
||||||
|
|
||||||
|
Connection errors trigger automatic reconnection. Check logs for reconnection events when `EnableLogging` is true.
|
||||||
|
|
||||||
|
## Thread Safety
|
||||||
|
|
||||||
|
All `PostgresListener` methods are thread-safe and can be called concurrently from multiple goroutines.
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
1. **Dedicated Connection**: The listener uses a dedicated PostgreSQL connection separate from the query connection pool
|
||||||
|
2. **Asynchronous Handlers**: Notification handlers run in separate goroutines to avoid blocking
|
||||||
|
3. **Lightweight**: NOTIFY/LISTEN has minimal overhead compared to polling
|
||||||
|
|
||||||
|
## Comparison with Polling
|
||||||
|
|
||||||
|
| Feature | NOTIFY/LISTEN | Polling |
|
||||||
|
|---------|---------------|---------|
|
||||||
|
| Latency | Low (near real-time) | High (depends on poll interval) |
|
||||||
|
| Database Load | Minimal | High (constant queries) |
|
||||||
|
| Scalability | Excellent | Poor |
|
||||||
|
| Complexity | Simple | Moderate |
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
1. **PostgreSQL Only**: This feature is specific to PostgreSQL and not available for other databases
|
||||||
|
2. **No Message Persistence**: Notifications are not stored; if no listener is connected, the message is lost
|
||||||
|
3. **Payload Limit**: Notification payload is limited to 8000 bytes in PostgreSQL
|
||||||
|
4. **No Guaranteed Delivery**: If a listener disconnects, in-flight notifications may be lost
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Keep Handlers Fast**: Notification handlers should be quick; for heavy processing, send work to a queue
|
||||||
|
2. **Use JSON Payloads**: Encode structured data as JSON for easy parsing
|
||||||
|
3. **Handle Errors Gracefully**: Always recover from panics in handlers
|
||||||
|
4. **Close Properly**: Always close the provider to ensure the listener is properly shut down
|
||||||
|
5. **Monitor Connection Status**: Use `IsConnected()` for health checks
|
||||||
|
|
||||||
|
## Example: Real-World Application
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Subscribe to various application events
|
||||||
|
listener, _ := provider.GetListener(ctx)
|
||||||
|
|
||||||
|
// User registration events
|
||||||
|
listener.Listen("user_registered", func(channel, payload string) {
|
||||||
|
var event UserRegisteredEvent
|
||||||
|
json.Unmarshal([]byte(payload), &event)
|
||||||
|
|
||||||
|
// Send welcome email
|
||||||
|
sendWelcomeEmail(event.UserID)
|
||||||
|
|
||||||
|
// Invalidate user count cache
|
||||||
|
cache.Delete("user_count")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Order placement events
|
||||||
|
listener.Listen("order_placed", func(channel, payload string) {
|
||||||
|
var event OrderPlacedEvent
|
||||||
|
json.Unmarshal([]byte(payload), &event)
|
||||||
|
|
||||||
|
// Notify warehouse system
|
||||||
|
warehouse.ProcessOrder(event.OrderID)
|
||||||
|
|
||||||
|
// Update inventory cache
|
||||||
|
cache.Invalidate("inventory:" + event.ProductID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Configuration changes
|
||||||
|
listener.Listen("config_updated", func(channel, payload string) {
|
||||||
|
// Reload configuration from database
|
||||||
|
appConfig.Reload()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Triggering Notifications from SQL
|
||||||
|
|
||||||
|
You can trigger notifications directly from PostgreSQL triggers or functions:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Example trigger to notify on new user
|
||||||
|
CREATE OR REPLACE FUNCTION notify_user_registered()
|
||||||
|
RETURNS TRIGGER AS $$
|
||||||
|
BEGIN
|
||||||
|
PERFORM pg_notify('user_registered',
|
||||||
|
json_build_object(
|
||||||
|
'user_id', NEW.id,
|
||||||
|
'email', NEW.email,
|
||||||
|
'timestamp', NOW()
|
||||||
|
)::text
|
||||||
|
);
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
CREATE TRIGGER user_registered_trigger
|
||||||
|
AFTER INSERT ON users
|
||||||
|
FOR EACH ROW
|
||||||
|
EXECUTE FUNCTION notify_user_registered();
|
||||||
|
```
|
||||||
|
|
||||||
|
## Additional Resources
|
||||||
|
|
||||||
|
- [PostgreSQL NOTIFY Documentation](https://www.postgresql.org/docs/current/sql-notify.html)
|
||||||
|
- [PostgreSQL LISTEN Documentation](https://www.postgresql.org/docs/current/sql-listen.html)
|
||||||
|
- [pgx Driver Documentation](https://github.com/jackc/pgx)
|
||||||
111
pkg/dbmanager/providers/existing_db.go
Normal file
111
pkg/dbmanager/providers/existing_db.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExistingDBProvider wraps an already-opened *sql.DB so dbmanager
// features can be used with a connection created outside this package.
// It never opens connections itself; it only verifies, reports on, and
// closes the handle it was given.
type ExistingDBProvider struct {
	db   *sql.DB      // wrapped handle; may be nil, methods guard against it
	name string       // logical connection name reported in Stats
	mu   sync.RWMutex // guards db for concurrent method calls
}

// NewExistingDBProvider creates a provider around an existing *sql.DB.
// A nil db is accepted; Connect and HealthCheck will then return errors.
func NewExistingDBProvider(db *sql.DB, name string) *ExistingDBProvider {
	return &ExistingDBProvider{
		db:   db,
		name: name,
	}
}
|
||||||
|
|
||||||
|
// Connect verifies the existing database connection is valid
|
||||||
|
// It does NOT create a new connection, but ensures the existing one works
|
||||||
|
func (p *ExistingDBProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.db == nil {
|
||||||
|
return fmt.Errorf("database connection is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the connection works
|
||||||
|
if err := p.db.PingContext(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to ping existing database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying database connection
|
||||||
|
func (p *ExistingDBProvider) Close() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.db == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheck verifies the connection is alive
|
||||||
|
func (p *ExistingDBProvider) HealthCheck(ctx context.Context) error {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
|
||||||
|
if p.db == nil {
|
||||||
|
return fmt.Errorf("database connection is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.db.PingContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNative returns the wrapped *sql.DB
|
||||||
|
func (p *ExistingDBProvider) GetNative() (*sql.DB, error) {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
|
||||||
|
if p.db == nil {
|
||||||
|
return nil, fmt.Errorf("database connection is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMongo returns an error since this is a SQL database
|
||||||
|
func (p *ExistingDBProvider) GetMongo() (*mongo.Client, error) {
|
||||||
|
return nil, ErrNotMongoDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns connection statistics
|
||||||
|
func (p *ExistingDBProvider) Stats() *ConnectionStats {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
|
||||||
|
stats := &ConnectionStats{
|
||||||
|
Name: p.name,
|
||||||
|
Type: "sql", // Generic since we don't know the specific type
|
||||||
|
Connected: p.db != nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.db != nil {
|
||||||
|
dbStats := p.db.Stats()
|
||||||
|
stats.OpenConnections = dbStats.OpenConnections
|
||||||
|
stats.InUse = dbStats.InUse
|
||||||
|
stats.Idle = dbStats.Idle
|
||||||
|
stats.WaitCount = dbStats.WaitCount
|
||||||
|
stats.WaitDuration = dbStats.WaitDuration
|
||||||
|
stats.MaxIdleClosed = dbStats.MaxIdleClosed
|
||||||
|
stats.MaxLifetimeClosed = dbStats.MaxLifetimeClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
194
pkg/dbmanager/providers/existing_db_test.go
Normal file
194
pkg/dbmanager/providers/existing_db_test.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewExistingDBProvider checks that the constructor wraps the
// supplied *sql.DB and records the connection name.
func TestNewExistingDBProvider(t *testing.T) {
	// Open a SQLite in-memory database
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer db.Close()

	// Create provider
	provider := NewExistingDBProvider(db, "test-db")
	if provider == nil {
		t.Fatal("Expected provider to be created")
	}

	if provider.name != "test-db" {
		t.Errorf("Expected name 'test-db', got '%s'", provider.name)
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_Connect verifies Connect succeeds against a
// live (in-memory SQLite) handle; the config argument is ignored by
// this provider, so nil is passed.
func TestExistingDBProvider_Connect(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer db.Close()

	provider := NewExistingDBProvider(db, "test-db")
	ctx := context.Background()

	// Connect should verify the connection works
	err = provider.Connect(ctx, nil)
	if err != nil {
		t.Errorf("Expected Connect to succeed, got error: %v", err)
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_Connect_NilDB verifies Connect reports an
// error when the provider was constructed with a nil handle.
func TestExistingDBProvider_Connect_NilDB(t *testing.T) {
	provider := NewExistingDBProvider(nil, "test-db")
	ctx := context.Background()

	err := provider.Connect(ctx, nil)
	if err == nil {
		t.Error("Expected Connect to fail with nil database")
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_GetNative verifies GetNative hands back the
// exact *sql.DB instance the provider was built with.
func TestExistingDBProvider_GetNative(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer db.Close()

	provider := NewExistingDBProvider(db, "test-db")

	nativeDB, err := provider.GetNative()
	if err != nil {
		t.Errorf("Expected GetNative to succeed, got error: %v", err)
	}

	if nativeDB != db {
		t.Error("Expected GetNative to return the same database instance")
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_GetNative_NilDB verifies GetNative errors
// when no handle was supplied at construction time.
func TestExistingDBProvider_GetNative_NilDB(t *testing.T) {
	provider := NewExistingDBProvider(nil, "test-db")

	_, err := provider.GetNative()
	if err == nil {
		t.Error("Expected GetNative to fail with nil database")
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_HealthCheck verifies the ping-based health
// check passes on an open in-memory database.
func TestExistingDBProvider_HealthCheck(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer db.Close()

	provider := NewExistingDBProvider(db, "test-db")
	ctx := context.Background()

	err = provider.HealthCheck(ctx)
	if err != nil {
		t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_HealthCheck_ClosedDB verifies the health
// check fails once the underlying handle has been closed out from
// under the provider.
func TestExistingDBProvider_HealthCheck_ClosedDB(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}

	provider := NewExistingDBProvider(db, "test-db")

	// Close the database
	db.Close()

	ctx := context.Background()
	err = provider.HealthCheck(ctx)
	if err == nil {
		t.Error("Expected HealthCheck to fail with closed database")
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_GetMongo verifies the SQL-backed provider
// rejects Mongo access with the sentinel ErrNotMongoDB.
func TestExistingDBProvider_GetMongo(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer db.Close()

	provider := NewExistingDBProvider(db, "test-db")

	_, err = provider.GetMongo()
	if err != ErrNotMongoDB {
		t.Errorf("Expected ErrNotMongoDB, got: %v", err)
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_Stats verifies Stats reports the provider
// name, the generic "sql" type, and a connected state for an open DB.
// Pool counter values are not asserted here (they depend on usage).
func TestExistingDBProvider_Stats(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	defer db.Close()

	// Set some connection pool settings to test stats
	db.SetMaxOpenConns(10)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(time.Hour)

	provider := NewExistingDBProvider(db, "test-db")

	stats := provider.Stats()
	if stats == nil {
		t.Fatal("Expected stats to be returned")
	}

	if stats.Name != "test-db" {
		t.Errorf("Expected stats.Name to be 'test-db', got '%s'", stats.Name)
	}

	if stats.Type != "sql" {
		t.Errorf("Expected stats.Type to be 'sql', got '%s'", stats.Type)
	}

	if !stats.Connected {
		t.Error("Expected stats.Connected to be true")
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_Close verifies Close shuts the wrapped
// handle down (a subsequent Ping on the raw *sql.DB must fail).
func TestExistingDBProvider_Close(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}

	provider := NewExistingDBProvider(db, "test-db")

	err = provider.Close()
	if err != nil {
		t.Errorf("Expected Close to succeed, got error: %v", err)
	}

	// Verify the database is closed
	err = db.Ping()
	if err == nil {
		t.Error("Expected database to be closed")
	}
}
|
||||||
|
|
||||||
|
// TestExistingDBProvider_Close_NilDB verifies Close is a no-op
// (no error) when the provider holds no database handle.
func TestExistingDBProvider_Close_NilDB(t *testing.T) {
	provider := NewExistingDBProvider(nil, "test-db")

	err := provider.Close()
	if err != nil {
		t.Errorf("Expected Close to succeed with nil database, got error: %v", err)
	}
}
|
||||||
214
pkg/dbmanager/providers/mongodb.go
Normal file
214
pkg/dbmanager/providers/mongodb.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
"go.mongodb.org/mongo-driver/mongo/options"
|
||||||
|
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MongoProvider implements Provider for MongoDB databases.
type MongoProvider struct {
	client *mongo.Client    // live driver client; nil until Connect succeeds
	config ConnectionConfig // config captured by Connect, used for logging/stats
}
|
||||||
|
|
||||||
|
// NewMongoProvider creates a new MongoDB provider
|
||||||
|
func NewMongoProvider() *MongoProvider {
|
||||||
|
return &MongoProvider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a MongoDB connection.
// It builds the DSN from cfg, applies pool sizes, timeouts and read
// preference, then dials with up to 3 attempts using exponential
// backoff. On success the client and config are stored on the provider.
func (p *MongoProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
	// Build DSN
	dsn, err := cfg.BuildDSN()
	if err != nil {
		return fmt.Errorf("failed to build DSN: %w", err)
	}

	// Create client options
	clientOpts := options.Client().ApplyURI(dsn)

	// Set connection pool size (MaxOpenConns maps to the driver's max pool,
	// MaxIdleConns to its min pool).
	if cfg.GetMaxOpenConns() != nil {
		maxPoolSize := uint64(*cfg.GetMaxOpenConns())
		clientOpts.SetMaxPoolSize(maxPoolSize)
	}

	if cfg.GetMaxIdleConns() != nil {
		minPoolSize := uint64(*cfg.GetMaxIdleConns())
		clientOpts.SetMinPoolSize(minPoolSize)
	}

	// Set timeouts
	clientOpts.SetConnectTimeout(cfg.GetConnectTimeout())
	if cfg.GetQueryTimeout() > 0 {
		clientOpts.SetTimeout(cfg.GetQueryTimeout())
	}

	// Set read preference if specified
	if cfg.GetReadPreference() != "" {
		rp, err := parseReadPreference(cfg.GetReadPreference())
		if err != nil {
			return fmt.Errorf("invalid read preference: %w", err)
		}
		clientOpts.SetReadPreference(rp)
	}

	// Connect with retry logic. retryAttempts/retryDelay are fixed here;
	// the per-attempt delay grows via calculateBackoff.
	var client *mongo.Client
	var lastErr error

	retryAttempts := 3
	retryDelay := 1 * time.Second

	for attempt := 0; attempt < retryAttempts; attempt++ {
		if attempt > 0 {
			delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
			if cfg.GetEnableLogging() {
				logger.Info("Retrying MongoDB connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
			}

			// Wait out the backoff, but abandon immediately on cancellation.
			select {
			case <-time.After(delay):
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		// Create MongoDB client
		client, err = mongo.Connect(ctx, clientOpts)
		if err != nil {
			lastErr = err
			// NOTE(review): logger.Info above takes a printf-style format,
			// while logger.Warn here is passed key/value pairs — confirm the
			// logger package supports both calling conventions.
			if cfg.GetEnableLogging() {
				logger.Warn("Failed to connect to MongoDB", "error", err)
			}
			continue
		}

		// Ping the database to verify connection
		pingCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
		err = client.Ping(pingCtx, readpref.Primary())
		cancel()

		if err != nil {
			lastErr = err
			// Best-effort teardown of the half-open client before retrying.
			_ = client.Disconnect(ctx)
			if cfg.GetEnableLogging() {
				logger.Warn("Failed to ping MongoDB", "error", err)
			}
			continue
		}

		// Connection successful
		break
	}

	// err is non-nil only when the final attempt failed; lastErr carries
	// that failure for the wrapped message.
	if err != nil {
		return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
	}

	p.client = client
	p.config = cfg

	if cfg.GetEnableLogging() {
		logger.Info("MongoDB connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
	}

	return nil
}
|
||||||
|
|
||||||
|
// Close closes the MongoDB connection.
// It is a no-op when no client is held. Disconnect is bounded by a
// fresh 10s background context (independent of any caller context),
// and the client field is cleared only after a successful disconnect.
func (p *MongoProvider) Close() error {
	if p.client == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	err := p.client.Disconnect(ctx)
	if err != nil {
		return fmt.Errorf("failed to close MongoDB connection: %w", err)
	}

	if p.config.GetEnableLogging() {
		logger.Info("MongoDB connection closed: name=%s", p.config.GetName())
	}

	p.client = nil
	return nil
}
|
||||||
|
|
||||||
|
// HealthCheck verifies the MongoDB connection is alive by pinging the
// primary. The ping is capped at 5 seconds regardless of the caller's
// context so health probes stay cheap.
func (p *MongoProvider) HealthCheck(ctx context.Context) error {
	if p.client == nil {
		return fmt.Errorf("MongoDB client is nil")
	}

	// Use a short timeout for health checks
	healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	if err := p.client.Ping(healthCtx, readpref.Primary()); err != nil {
		return fmt.Errorf("health check failed: %w", err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// GetNative returns an error for MongoDB (not a SQL database).
// It always fails with the sentinel ErrNotSQLDatabase.
func (p *MongoProvider) GetNative() (*sql.DB, error) {
	return nil, ErrNotSQLDatabase
}
|
||||||
|
|
||||||
|
// GetMongo returns the MongoDB client
|
||||||
|
func (p *MongoProvider) GetMongo() (*mongo.Client, error) {
|
||||||
|
if p.client == nil {
|
||||||
|
return nil, fmt.Errorf("MongoDB client is not initialized")
|
||||||
|
}
|
||||||
|
return p.client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns connection statistics for MongoDB
|
||||||
|
func (p *MongoProvider) Stats() *ConnectionStats {
|
||||||
|
if p.client == nil {
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "mongodb",
|
||||||
|
Connected: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MongoDB doesn't expose detailed connection pool stats like sql.DB
|
||||||
|
// We return basic stats
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "mongodb",
|
||||||
|
Connected: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseReadPreference parses a read preference string into a readpref.ReadPref
|
||||||
|
func parseReadPreference(rp string) (*readpref.ReadPref, error) {
|
||||||
|
switch rp {
|
||||||
|
case "primary":
|
||||||
|
return readpref.Primary(), nil
|
||||||
|
case "primaryPreferred":
|
||||||
|
return readpref.PrimaryPreferred(), nil
|
||||||
|
case "secondary":
|
||||||
|
return readpref.Secondary(), nil
|
||||||
|
case "secondaryPreferred":
|
||||||
|
return readpref.SecondaryPreferred(), nil
|
||||||
|
case "nearest":
|
||||||
|
return readpref.Nearest(), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown read preference: %s", rp)
|
||||||
|
}
|
||||||
|
}
|
||||||
184
pkg/dbmanager/providers/mssql.go
Normal file
184
pkg/dbmanager/providers/mssql.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/microsoft/go-mssqldb" // MSSQL driver
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MSSQLProvider implements Provider for Microsoft SQL Server databases.
type MSSQLProvider struct {
	db     *sql.DB          // pooled handle; nil until Connect succeeds
	config ConnectionConfig // config captured by Connect, used for logging/stats
}
|
||||||
|
|
||||||
|
// NewMSSQLProvider creates a new MSSQL provider
|
||||||
|
func NewMSSQLProvider() *MSSQLProvider {
|
||||||
|
return &MSSQLProvider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a MSSQL connection.
// It builds the DSN from cfg, dials via the "sqlserver" driver with up
// to 3 attempts (exponential backoff), then applies the configured
// pool limits and stores the handle on the provider.
func (p *MSSQLProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
	// Build DSN
	dsn, err := cfg.BuildDSN()
	if err != nil {
		return fmt.Errorf("failed to build DSN: %w", err)
	}

	// Connect with retry logic
	var db *sql.DB
	var lastErr error

	retryAttempts := 3 // Default retry attempts
	retryDelay := 1 * time.Second

	for attempt := 0; attempt < retryAttempts; attempt++ {
		if attempt > 0 {
			delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
			if cfg.GetEnableLogging() {
				logger.Info("Retrying MSSQL connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
			}

			// Wait out the backoff, but abandon immediately on cancellation.
			select {
			case <-time.After(delay):
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		// Open database connection (sql.Open only validates the DSN;
		// the PingContext below does the actual dial).
		db, err = sql.Open("sqlserver", dsn)
		if err != nil {
			lastErr = err
			// NOTE(review): logger.Info above takes a printf-style format,
			// while logger.Warn here is passed key/value pairs — confirm the
			// logger package supports both calling conventions.
			if cfg.GetEnableLogging() {
				logger.Warn("Failed to open MSSQL connection", "error", err)
			}
			continue
		}

		// Test the connection with context timeout
		connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
		err = db.PingContext(connectCtx)
		cancel()

		if err != nil {
			lastErr = err
			// Best-effort close of the unusable handle before retrying.
			db.Close()
			if cfg.GetEnableLogging() {
				logger.Warn("Failed to ping MSSQL database", "error", err)
			}
			continue
		}

		// Connection successful
		break
	}

	// err is non-nil only when the final attempt failed; lastErr carries
	// that failure for the wrapped message.
	if err != nil {
		return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
	}

	// Configure connection pool (each limit is optional in the config).
	if cfg.GetMaxOpenConns() != nil {
		db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
	}
	if cfg.GetMaxIdleConns() != nil {
		db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
	}
	if cfg.GetConnMaxLifetime() != nil {
		db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
	}
	if cfg.GetConnMaxIdleTime() != nil {
		db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
	}

	p.db = db
	p.config = cfg

	if cfg.GetEnableLogging() {
		logger.Info("MSSQL connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
	}

	return nil
}
|
||||||
|
|
||||||
|
// Close closes the MSSQL connection.
// It is a no-op when no handle is held; the handle is cleared only
// after a successful close.
func (p *MSSQLProvider) Close() error {
	if p.db == nil {
		return nil
	}

	err := p.db.Close()
	if err != nil {
		return fmt.Errorf("failed to close MSSQL connection: %w", err)
	}

	if p.config.GetEnableLogging() {
		logger.Info("MSSQL connection closed: name=%s", p.config.GetName())
	}

	p.db = nil
	return nil
}
|
||||||
|
|
||||||
|
// HealthCheck verifies the MSSQL connection is alive via PingContext.
// The ping is capped at 5 seconds regardless of the caller's context
// so health probes stay cheap.
func (p *MSSQLProvider) HealthCheck(ctx context.Context) error {
	if p.db == nil {
		return fmt.Errorf("database connection is nil")
	}

	// Use a short timeout for health checks
	healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	if err := p.db.PingContext(healthCtx); err != nil {
		return fmt.Errorf("health check failed: %w", err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// GetNative returns the native *sql.DB connection
|
||||||
|
func (p *MSSQLProvider) GetNative() (*sql.DB, error) {
|
||||||
|
if p.db == nil {
|
||||||
|
return nil, fmt.Errorf("database connection is not initialized")
|
||||||
|
}
|
||||||
|
return p.db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMongo returns an error for MSSQL (not a MongoDB connection).
// It always fails with the sentinel ErrNotMongoDB.
func (p *MSSQLProvider) GetMongo() (*mongo.Client, error) {
	return nil, ErrNotMongoDB
}
|
||||||
|
|
||||||
|
// Stats returns connection pool statistics
|
||||||
|
func (p *MSSQLProvider) Stats() *ConnectionStats {
|
||||||
|
if p.db == nil {
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "mssql",
|
||||||
|
Connected: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := p.db.Stats()
|
||||||
|
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "mssql",
|
||||||
|
Connected: true,
|
||||||
|
OpenConnections: stats.OpenConnections,
|
||||||
|
InUse: stats.InUse,
|
||||||
|
Idle: stats.Idle,
|
||||||
|
WaitCount: stats.WaitCount,
|
||||||
|
WaitDuration: stats.WaitDuration,
|
||||||
|
MaxIdleClosed: stats.MaxIdleClosed,
|
||||||
|
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||||
|
}
|
||||||
|
}
|
||||||
231
pkg/dbmanager/providers/postgres.go
Normal file
231
pkg/dbmanager/providers/postgres.go
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgresProvider implements Provider for PostgreSQL databases.
type PostgresProvider struct {
	db       *sql.DB           // pooled handle; nil until Connect succeeds
	config   ConnectionConfig  // config captured by Connect
	listener *PostgresListener // lazily created by GetListener
	mu       sync.Mutex        // guards listener creation/teardown
}
|
||||||
|
|
||||||
|
// NewPostgresProvider creates a new PostgreSQL provider
|
||||||
|
func NewPostgresProvider() *PostgresProvider {
|
||||||
|
return &PostgresProvider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a PostgreSQL connection.
// It builds the DSN from cfg, dials via the pgx stdlib driver with up
// to 3 attempts (exponential backoff), then applies the configured
// pool limits and stores the handle on the provider.
func (p *PostgresProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
	// Build DSN
	dsn, err := cfg.BuildDSN()
	if err != nil {
		return fmt.Errorf("failed to build DSN: %w", err)
	}

	// Connect with retry logic
	var db *sql.DB
	var lastErr error

	retryAttempts := 3 // Default retry attempts
	retryDelay := 1 * time.Second

	for attempt := 0; attempt < retryAttempts; attempt++ {
		if attempt > 0 {
			delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
			if cfg.GetEnableLogging() {
				logger.Info("Retrying PostgreSQL connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
			}

			// Wait out the backoff, but abandon immediately on cancellation.
			select {
			case <-time.After(delay):
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		// Open database connection (sql.Open only validates the DSN;
		// the PingContext below does the actual dial).
		db, err = sql.Open("pgx", dsn)
		if err != nil {
			lastErr = err
			// NOTE(review): logger.Info above takes a printf-style format,
			// while logger.Warn here is passed key/value pairs — confirm the
			// logger package supports both calling conventions.
			if cfg.GetEnableLogging() {
				logger.Warn("Failed to open PostgreSQL connection", "error", err)
			}
			continue
		}

		// Test the connection with context timeout
		connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
		err = db.PingContext(connectCtx)
		cancel()

		if err != nil {
			lastErr = err
			// Best-effort close of the unusable handle before retrying.
			db.Close()
			if cfg.GetEnableLogging() {
				logger.Warn("Failed to ping PostgreSQL database", "error", err)
			}
			continue
		}

		// Connection successful
		break
	}

	// err is non-nil only when the final attempt failed; lastErr carries
	// that failure for the wrapped message.
	if err != nil {
		return fmt.Errorf("failed to connect after %d attempts: %w", retryAttempts, lastErr)
	}

	// Configure connection pool (each limit is optional in the config).
	if cfg.GetMaxOpenConns() != nil {
		db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
	}
	if cfg.GetMaxIdleConns() != nil {
		db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
	}
	if cfg.GetConnMaxLifetime() != nil {
		db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
	}
	if cfg.GetConnMaxIdleTime() != nil {
		db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
	}

	p.db = db
	p.config = cfg

	if cfg.GetEnableLogging() {
		logger.Info("PostgreSQL connection established: name=%s, host=%s, database=%s", cfg.GetName(), cfg.GetHost(), cfg.GetDatabase())
	}

	return nil
}
|
||||||
|
|
||||||
|
// Close closes the PostgreSQL connection.
// It first tears down the LISTEN/NOTIFY listener (if one was created
// via GetListener) under the mutex, then closes the pooled handle.
// A listener-close failure aborts before the DB handle is touched.
func (p *PostgresProvider) Close() error {
	// Close listener if it exists
	p.mu.Lock()
	if p.listener != nil {
		if err := p.listener.Close(); err != nil {
			p.mu.Unlock()
			return fmt.Errorf("failed to close listener: %w", err)
		}
		p.listener = nil
	}
	p.mu.Unlock()

	if p.db == nil {
		return nil
	}

	err := p.db.Close()
	if err != nil {
		return fmt.Errorf("failed to close PostgreSQL connection: %w", err)
	}

	if p.config.GetEnableLogging() {
		logger.Info("PostgreSQL connection closed: name=%s", p.config.GetName())
	}

	p.db = nil
	return nil
}
|
||||||
|
|
||||||
|
// HealthCheck verifies the PostgreSQL connection is alive via
// PingContext. The ping is capped at 5 seconds regardless of the
// caller's context so health probes stay cheap.
func (p *PostgresProvider) HealthCheck(ctx context.Context) error {
	if p.db == nil {
		return fmt.Errorf("database connection is nil")
	}

	// Use a short timeout for health checks
	healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	if err := p.db.PingContext(healthCtx); err != nil {
		return fmt.Errorf("health check failed: %w", err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// GetNative returns the native *sql.DB connection
|
||||||
|
func (p *PostgresProvider) GetNative() (*sql.DB, error) {
|
||||||
|
if p.db == nil {
|
||||||
|
return nil, fmt.Errorf("database connection is not initialized")
|
||||||
|
}
|
||||||
|
return p.db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMongo returns an error for PostgreSQL (not a MongoDB connection).
// It always fails with the sentinel ErrNotMongoDB.
func (p *PostgresProvider) GetMongo() (*mongo.Client, error) {
	return nil, ErrNotMongoDB
}
|
||||||
|
|
||||||
|
// Stats returns connection pool statistics
|
||||||
|
func (p *PostgresProvider) Stats() *ConnectionStats {
|
||||||
|
if p.db == nil {
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "postgres",
|
||||||
|
Connected: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := p.db.Stats()
|
||||||
|
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "postgres",
|
||||||
|
Connected: true,
|
||||||
|
OpenConnections: stats.OpenConnections,
|
||||||
|
InUse: stats.InUse,
|
||||||
|
Idle: stats.Idle,
|
||||||
|
WaitCount: stats.WaitCount,
|
||||||
|
WaitDuration: stats.WaitDuration,
|
||||||
|
MaxIdleClosed: stats.MaxIdleClosed,
|
||||||
|
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetListener returns a PostgreSQL listener for NOTIFY/LISTEN functionality.
// The listener is lazily initialized on first call and reused for
// subsequent calls; the whole check-create-connect sequence runs under
// the provider mutex so concurrent callers get the same instance.
func (p *PostgresProvider) GetListener(ctx context.Context) (*PostgresListener, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	// Return existing listener if already created
	if p.listener != nil {
		return p.listener, nil
	}

	// Create new listener
	listener := NewPostgresListener(p.config)

	// Connect the listener; on failure nothing is cached, so the next
	// call will retry from scratch.
	if err := listener.Connect(ctx); err != nil {
		return nil, fmt.Errorf("failed to connect listener: %w", err)
	}

	p.listener = listener
	return p.listener, nil
}
|
||||||
|
|
||||||
|
// calculateBackoff calculates exponential backoff delay
|
||||||
|
func calculateBackoff(attempt int, initial, maxDelay time.Duration) time.Duration {
|
||||||
|
delay := initial * time.Duration(math.Pow(2, float64(attempt)))
|
||||||
|
if delay > maxDelay {
|
||||||
|
delay = maxDelay
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
401
pkg/dbmanager/providers/postgres_listener.go
Normal file
401
pkg/dbmanager/providers/postgres_listener.go
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotificationHandler is called when a notification is received.
// channel is the PostgreSQL NOTIFY channel name; payload is the
// notification payload string (possibly empty).
type NotificationHandler func(channel string, payload string)
|
||||||
|
|
||||||
|
// PostgresListener manages PostgreSQL LISTEN/NOTIFY functionality
// over a dedicated (non-pooled) pgx connection.
type PostgresListener struct {
	config ConnectionConfig
	conn   *pgx.Conn // dedicated listening connection, guarded by mu

	// Channel subscriptions: channel name -> handler to invoke.
	channels map[string]NotificationHandler
	mu       sync.RWMutex

	// Lifecycle management.
	// NOTE(review): ctx/cancel are stored on the struct for the
	// listener's own background goroutines, which is why a struct
	// context is used here despite the usual guidance against it.
	ctx        context.Context
	cancel     context.CancelFunc
	closed     bool       // set once Close has run; guarded by closeMu
	closeMu    sync.Mutex // guards closed
	reconnectC chan struct{}
}
|
||||||
|
|
||||||
|
// NewPostgresListener creates a new PostgreSQL listener
|
||||||
|
func NewPostgresListener(cfg ConnectionConfig) *PostgresListener {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &PostgresListener{
|
||||||
|
config: cfg,
|
||||||
|
channels: make(map[string]NotificationHandler),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
reconnectC: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a dedicated connection for listening.
// It dials with up to 3 attempts (exponential backoff), stores the
// connection under the lock, then starts the notification and
// reconnection goroutines.
func (l *PostgresListener) Connect(ctx context.Context) error {
	dsn, err := l.config.BuildDSN()
	if err != nil {
		return fmt.Errorf("failed to build DSN: %w", err)
	}

	// Parse connection config
	connConfig, err := pgx.ParseConfig(dsn)
	if err != nil {
		return fmt.Errorf("failed to parse connection config: %w", err)
	}

	// Connect with retry logic
	var conn *pgx.Conn
	var lastErr error

	retryAttempts := 3
	retryDelay := 1 * time.Second

	for attempt := 0; attempt < retryAttempts; attempt++ {
		if attempt > 0 {
			delay := calculateBackoff(attempt, retryDelay, 10*time.Second)
			if l.config.GetEnableLogging() {
				logger.Info("Retrying PostgreSQL listener connection: attempt=%d/%d, delay=%v", attempt+1, retryAttempts, delay)
			}

			// Wait out the backoff, but abandon immediately on cancellation.
			select {
			case <-time.After(delay):
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		conn, err = pgx.ConnectConfig(ctx, connConfig)
		if err != nil {
			lastErr = err
			// NOTE(review): logger.Info above takes a printf-style format,
			// while logger.Warn here is passed key/value pairs — confirm the
			// logger package supports both calling conventions.
			if l.config.GetEnableLogging() {
				logger.Warn("Failed to connect PostgreSQL listener", "error", err)
			}
			continue
		}

		// Test the connection
		if err = conn.Ping(ctx); err != nil {
			lastErr = err
			conn.Close(ctx)
			if l.config.GetEnableLogging() {
				logger.Warn("Failed to ping PostgreSQL listener", "error", err)
			}
			continue
		}

		// Connection successful
		break
	}

	// err is non-nil only when the final attempt failed; lastErr carries
	// that failure for the wrapped message.
	if err != nil {
		return fmt.Errorf("failed to connect listener after %d attempts: %w", retryAttempts, lastErr)
	}

	l.mu.Lock()
	l.conn = conn
	l.mu.Unlock()

	// Start notification handler
	go l.handleNotifications()

	// Start reconnection handler
	go l.handleReconnection()

	if l.config.GetEnableLogging() {
		logger.Info("PostgreSQL listener connected: name=%s", l.config.GetName())
	}

	return nil
}
|
||||||
|
|
||||||
|
// Listen subscribes to a PostgreSQL notification channel
|
||||||
|
func (l *PostgresListener) Listen(channel string, handler NotificationHandler) error {
|
||||||
|
l.closeMu.Lock()
|
||||||
|
if l.closed {
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
return fmt.Errorf("listener is closed")
|
||||||
|
}
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
if l.conn == nil {
|
||||||
|
return fmt.Errorf("listener connection is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute LISTEN command
|
||||||
|
_, err := l.conn.Exec(l.ctx, fmt.Sprintf("LISTEN %s", pgx.Identifier{channel}.Sanitize()))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to listen on channel %s: %w", channel, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the handler
|
||||||
|
l.channels[channel] = handler
|
||||||
|
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Info("Listening on channel: name=%s, channel=%s", l.config.GetName(), channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlisten unsubscribes from a PostgreSQL notification channel
|
||||||
|
func (l *PostgresListener) Unlisten(channel string) error {
|
||||||
|
l.closeMu.Lock()
|
||||||
|
if l.closed {
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
return fmt.Errorf("listener is closed")
|
||||||
|
}
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
if l.conn == nil {
|
||||||
|
return fmt.Errorf("listener connection is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute UNLISTEN command
|
||||||
|
_, err := l.conn.Exec(l.ctx, fmt.Sprintf("UNLISTEN %s", pgx.Identifier{channel}.Sanitize()))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unlisten from channel %s: %w", channel, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the handler
|
||||||
|
delete(l.channels, channel)
|
||||||
|
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Info("Unlistened from channel: name=%s, channel=%s", l.config.GetName(), channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify sends a notification to a PostgreSQL channel
|
||||||
|
func (l *PostgresListener) Notify(ctx context.Context, channel string, payload string) error {
|
||||||
|
l.closeMu.Lock()
|
||||||
|
if l.closed {
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
return fmt.Errorf("listener is closed")
|
||||||
|
}
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
|
||||||
|
l.mu.RLock()
|
||||||
|
conn := l.conn
|
||||||
|
l.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("listener connection is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute NOTIFY command
|
||||||
|
_, err := conn.Exec(ctx, "SELECT pg_notify($1, $2)", channel, payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to notify channel %s: %w", channel, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the listener and all subscriptions
|
||||||
|
func (l *PostgresListener) Close() error {
|
||||||
|
l.closeMu.Lock()
|
||||||
|
if l.closed {
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
l.closed = true
|
||||||
|
l.closeMu.Unlock()
|
||||||
|
|
||||||
|
// Cancel context to stop background goroutines
|
||||||
|
l.cancel()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
if l.conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlisten from all channels
|
||||||
|
for channel := range l.channels {
|
||||||
|
_, _ = l.conn.Exec(context.Background(), fmt.Sprintf("UNLISTEN %s", pgx.Identifier{channel}.Sanitize()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close connection
|
||||||
|
err := l.conn.Close(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to close listener connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.conn = nil
|
||||||
|
l.channels = make(map[string]NotificationHandler)
|
||||||
|
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Info("PostgreSQL listener closed: name=%s", l.config.GetName())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleNotifications processes incoming notifications
|
||||||
|
func (l *PostgresListener) handleNotifications() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-l.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
l.mu.RLock()
|
||||||
|
conn := l.conn
|
||||||
|
l.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
// Connection not available, wait for reconnection
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for notification with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(l.ctx, 5*time.Second)
|
||||||
|
notification, err := conn.WaitForNotification(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Check if context was cancelled
|
||||||
|
if l.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a connection error
|
||||||
|
if pgconn.Timeout(err) {
|
||||||
|
// Timeout is normal, continue waiting
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection error, trigger reconnection
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Warn("Notification error, triggering reconnection", "error", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case l.reconnectC <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process notification
|
||||||
|
l.mu.RLock()
|
||||||
|
handler, exists := l.channels[notification.Channel]
|
||||||
|
l.mu.RUnlock()
|
||||||
|
|
||||||
|
if exists && handler != nil {
|
||||||
|
// Call handler in a goroutine to avoid blocking
|
||||||
|
go func(ch, payload string) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Error("Notification handler panic: channel=%s, error=%v", ch, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
handler(ch, payload)
|
||||||
|
}(notification.Channel, notification.Payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleReconnection manages automatic reconnection
|
||||||
|
func (l *PostgresListener) handleReconnection() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-l.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-l.reconnectC:
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Info("Attempting to reconnect listener: name=%s", l.config.GetName())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close existing connection
|
||||||
|
l.mu.Lock()
|
||||||
|
if l.conn != nil {
|
||||||
|
l.conn.Close(context.Background())
|
||||||
|
l.conn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save current subscriptions
|
||||||
|
channels := make(map[string]NotificationHandler)
|
||||||
|
for ch, handler := range l.channels {
|
||||||
|
channels[ch] = handler
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
// Attempt reconnection
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
err := l.Connect(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Error("Failed to reconnect listener: name=%s, error=%v", l.config.GetName(), err)
|
||||||
|
}
|
||||||
|
// Retry after delay
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
select {
|
||||||
|
case l.reconnectC <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resubscribe to all channels
|
||||||
|
for channel, handler := range channels {
|
||||||
|
if err := l.Listen(channel, handler); err != nil {
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Error("Failed to resubscribe to channel: name=%s, channel=%s, error=%v", l.config.GetName(), channel, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.config.GetEnableLogging() {
|
||||||
|
logger.Info("Listener reconnected successfully: name=%s", l.config.GetName())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected returns true if the listener is connected
|
||||||
|
func (l *PostgresListener) IsConnected() bool {
|
||||||
|
l.mu.RLock()
|
||||||
|
defer l.mu.RUnlock()
|
||||||
|
return l.conn != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Channels returns the list of channels currently being listened to
|
||||||
|
func (l *PostgresListener) Channels() []string {
|
||||||
|
l.mu.RLock()
|
||||||
|
defer l.mu.RUnlock()
|
||||||
|
|
||||||
|
channels := make([]string, 0, len(l.channels))
|
||||||
|
for ch := range l.channels {
|
||||||
|
channels = append(channels, ch)
|
||||||
|
}
|
||||||
|
return channels
|
||||||
|
}
|
||||||
228
pkg/dbmanager/providers/postgres_listener_example_test.go
Normal file
228
pkg/dbmanager/providers/postgres_listener_example_test.go
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
package providers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExamplePostgresListener_basic demonstrates basic LISTEN/NOTIFY usage
|
||||||
|
func ExamplePostgresListener_basic() {
|
||||||
|
// Create a connection config
|
||||||
|
cfg := &dbmanager.ConnectionConfig{
|
||||||
|
Name: "example",
|
||||||
|
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "password",
|
||||||
|
Database: "testdb",
|
||||||
|
ConnectTimeout: 10 * time.Second,
|
||||||
|
EnableLogging: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and connect PostgreSQL provider
|
||||||
|
provider := providers.NewPostgresProvider()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||||
|
}
|
||||||
|
defer provider.Close()
|
||||||
|
|
||||||
|
// Get listener
|
||||||
|
listener, err := provider.GetListener(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to a channel with a handler
|
||||||
|
err = listener.Listen("user_events", func(channel, payload string) {
|
||||||
|
fmt.Printf("Received notification on %s: %s\n", channel, payload)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to listen: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a notification
|
||||||
|
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to notify: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for notification to be processed
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Unsubscribe from the channel
|
||||||
|
if err := listener.Unlisten("user_events"); err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to unlisten: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExamplePostgresListener_multipleChannels demonstrates listening to multiple channels
|
||||||
|
func ExamplePostgresListener_multipleChannels() {
|
||||||
|
cfg := &dbmanager.ConnectionConfig{
|
||||||
|
Name: "example",
|
||||||
|
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "password",
|
||||||
|
Database: "testdb",
|
||||||
|
ConnectTimeout: 10 * time.Second,
|
||||||
|
EnableLogging: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := providers.NewPostgresProvider()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||||
|
}
|
||||||
|
defer provider.Close()
|
||||||
|
|
||||||
|
listener, err := provider.GetListener(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen to multiple channels
|
||||||
|
channels := []string{"orders", "payments", "notifications"}
|
||||||
|
for _, ch := range channels {
|
||||||
|
channel := ch // Capture for closure
|
||||||
|
err := listener.Listen(channel, func(ch, payload string) {
|
||||||
|
fmt.Printf("[%s] %s\n", ch, payload)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send notifications to different channels
|
||||||
|
listener.Notify(ctx, "orders", "New order #12345")
|
||||||
|
listener.Notify(ctx, "payments", "Payment received $99.99")
|
||||||
|
listener.Notify(ctx, "notifications", "Welcome email sent")
|
||||||
|
|
||||||
|
// Wait for notifications
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check active channels
|
||||||
|
activeChannels := listener.Channels()
|
||||||
|
fmt.Printf("Listening to %d channels: %v\n", len(activeChannels), activeChannels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExamplePostgresListener_withDBManager demonstrates usage with DBManager
|
||||||
|
func ExamplePostgresListener_withDBManager() {
|
||||||
|
// This example shows how to use the listener with the full DBManager
|
||||||
|
|
||||||
|
// Assume we have a DBManager instance and get a connection
|
||||||
|
// conn, _ := dbMgr.Get("primary")
|
||||||
|
|
||||||
|
// Get the underlying provider (this would need to be exposed via the Connection interface)
|
||||||
|
// For now, this is a conceptual example
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create provider directly for demonstration
|
||||||
|
cfg := &dbmanager.ConnectionConfig{
|
||||||
|
Name: "primary",
|
||||||
|
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "password",
|
||||||
|
Database: "myapp",
|
||||||
|
ConnectTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := providers.NewPostgresProvider()
|
||||||
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer provider.Close()
|
||||||
|
|
||||||
|
// Get listener
|
||||||
|
listener, err := provider.GetListener(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to application events
|
||||||
|
listener.Listen("cache_invalidation", func(channel, payload string) {
|
||||||
|
fmt.Printf("Cache invalidation request: %s\n", payload)
|
||||||
|
// Handle cache invalidation logic here
|
||||||
|
})
|
||||||
|
|
||||||
|
listener.Listen("config_reload", func(channel, payload string) {
|
||||||
|
fmt.Printf("Configuration reload request: %s\n", payload)
|
||||||
|
// Handle configuration reload logic here
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate receiving notifications
|
||||||
|
listener.Notify(ctx, "cache_invalidation", "user:123")
|
||||||
|
listener.Notify(ctx, "config_reload", "database")
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExamplePostgresListener_errorHandling demonstrates error handling and reconnection
|
||||||
|
func ExamplePostgresListener_errorHandling() {
|
||||||
|
cfg := &dbmanager.ConnectionConfig{
|
||||||
|
Name: "example",
|
||||||
|
Type: dbmanager.DatabaseTypePostgreSQL,
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "password",
|
||||||
|
Database: "testdb",
|
||||||
|
ConnectTimeout: 10 * time.Second,
|
||||||
|
EnableLogging: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := providers.NewPostgresProvider()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||||
|
}
|
||||||
|
defer provider.Close()
|
||||||
|
|
||||||
|
listener, err := provider.GetListener(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// The listener automatically reconnects if the connection is lost
|
||||||
|
// Subscribe with error handling in the callback
|
||||||
|
err = listener.Listen("critical_events", func(channel, payload string) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
fmt.Printf("Handler panic recovered: %v\n", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Process the event
|
||||||
|
fmt.Printf("Processing critical event: %s\n", payload)
|
||||||
|
|
||||||
|
// If processing fails, the panic will be caught by the defer above
|
||||||
|
// The listener will continue to function normally
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to listen: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if listener is connected
|
||||||
|
if listener.IsConnected() {
|
||||||
|
fmt.Println("Listener is connected and ready")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a notification
|
||||||
|
listener.Notify(ctx, "critical_events", "system_alert")
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
83
pkg/dbmanager/providers/provider.go
Normal file
83
pkg/dbmanager/providers/provider.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common errors
|
||||||
|
var (
|
||||||
|
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||||
|
ErrNotSQLDatabase = errors.New("not a SQL database")
|
||||||
|
|
||||||
|
// ErrNotMongoDB is returned when attempting MongoDB operations on a non-MongoDB connection
|
||||||
|
ErrNotMongoDB = errors.New("not a MongoDB connection")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionStats contains statistics about a database connection
|
||||||
|
type ConnectionStats struct {
|
||||||
|
Name string
|
||||||
|
Type string // Database type as string to avoid circular dependency
|
||||||
|
Connected bool
|
||||||
|
LastHealthCheck time.Time
|
||||||
|
HealthCheckStatus string
|
||||||
|
|
||||||
|
// SQL connection pool stats
|
||||||
|
OpenConnections int
|
||||||
|
InUse int
|
||||||
|
Idle int
|
||||||
|
WaitCount int64
|
||||||
|
WaitDuration time.Duration
|
||||||
|
MaxIdleClosed int64
|
||||||
|
MaxLifetimeClosed int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionConfig is a minimal interface for configuration
|
||||||
|
// The actual implementation is in dbmanager package
|
||||||
|
type ConnectionConfig interface {
|
||||||
|
BuildDSN() (string, error)
|
||||||
|
GetName() string
|
||||||
|
GetType() string
|
||||||
|
GetHost() string
|
||||||
|
GetPort() int
|
||||||
|
GetUser() string
|
||||||
|
GetPassword() string
|
||||||
|
GetDatabase() string
|
||||||
|
GetFilePath() string
|
||||||
|
GetConnectTimeout() time.Duration
|
||||||
|
GetQueryTimeout() time.Duration
|
||||||
|
GetEnableLogging() bool
|
||||||
|
GetEnableMetrics() bool
|
||||||
|
GetMaxOpenConns() *int
|
||||||
|
GetMaxIdleConns() *int
|
||||||
|
GetConnMaxLifetime() *time.Duration
|
||||||
|
GetConnMaxIdleTime() *time.Duration
|
||||||
|
GetReadPreference() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider creates and manages the underlying database connection
|
||||||
|
type Provider interface {
|
||||||
|
// Connect establishes the database connection
|
||||||
|
Connect(ctx context.Context, cfg ConnectionConfig) error
|
||||||
|
|
||||||
|
// Close closes the connection
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// HealthCheck verifies the connection is alive
|
||||||
|
HealthCheck(ctx context.Context) error
|
||||||
|
|
||||||
|
// GetNative returns the native *sql.DB (SQL databases only)
|
||||||
|
// Returns an error for non-SQL databases
|
||||||
|
GetNative() (*sql.DB, error)
|
||||||
|
|
||||||
|
// GetMongo returns the MongoDB client (MongoDB only)
|
||||||
|
// Returns an error for non-MongoDB databases
|
||||||
|
GetMongo() (*mongo.Client, error)
|
||||||
|
|
||||||
|
// Stats returns connection statistics
|
||||||
|
Stats() *ConnectionStats
|
||||||
|
}
|
||||||
181
pkg/dbmanager/providers/sqlite.go
Normal file
181
pkg/dbmanager/providers/sqlite.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLiteProvider implements Provider for SQLite databases
|
||||||
|
type SQLiteProvider struct {
|
||||||
|
db *sql.DB
|
||||||
|
config ConnectionConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSQLiteProvider creates a new SQLite provider
|
||||||
|
func NewSQLiteProvider() *SQLiteProvider {
|
||||||
|
return &SQLiteProvider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a SQLite connection
|
||||||
|
func (p *SQLiteProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||||
|
// Build DSN
|
||||||
|
dsn, err := cfg.BuildDSN()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to build DSN: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open database connection
|
||||||
|
db, err := sql.Open("sqlite", dsn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open SQLite connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the connection with context timeout
|
||||||
|
connectCtx, cancel := context.WithTimeout(ctx, cfg.GetConnectTimeout())
|
||||||
|
err = db.PingContext(connectCtx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
db.Close()
|
||||||
|
return fmt.Errorf("failed to ping SQLite database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure connection pool
|
||||||
|
// Note: SQLite works best with MaxOpenConns=1 for write operations
|
||||||
|
// but can handle multiple readers
|
||||||
|
if cfg.GetMaxOpenConns() != nil {
|
||||||
|
db.SetMaxOpenConns(*cfg.GetMaxOpenConns())
|
||||||
|
} else {
|
||||||
|
// Default to 1 for SQLite to avoid "database is locked" errors
|
||||||
|
db.SetMaxOpenConns(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.GetMaxIdleConns() != nil {
|
||||||
|
db.SetMaxIdleConns(*cfg.GetMaxIdleConns())
|
||||||
|
}
|
||||||
|
if cfg.GetConnMaxLifetime() != nil {
|
||||||
|
db.SetConnMaxLifetime(*cfg.GetConnMaxLifetime())
|
||||||
|
}
|
||||||
|
if cfg.GetConnMaxIdleTime() != nil {
|
||||||
|
db.SetConnMaxIdleTime(*cfg.GetConnMaxIdleTime())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable WAL mode for better concurrent access
|
||||||
|
_, err = db.ExecContext(ctx, "PRAGMA journal_mode=WAL")
|
||||||
|
if err != nil {
|
||||||
|
if cfg.GetEnableLogging() {
|
||||||
|
logger.Warn("Failed to enable WAL mode for SQLite", "error", err)
|
||||||
|
}
|
||||||
|
// Don't fail connection if WAL mode cannot be enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set busy timeout to handle locked database (minimum 2 minutes = 120000ms)
|
||||||
|
busyTimeout := cfg.GetQueryTimeout().Milliseconds()
|
||||||
|
if busyTimeout < 120000 {
|
||||||
|
busyTimeout = 120000 // Enforce minimum of 2 minutes
|
||||||
|
}
|
||||||
|
_, err = db.ExecContext(ctx, fmt.Sprintf("PRAGMA busy_timeout=%d", busyTimeout))
|
||||||
|
if err != nil {
|
||||||
|
if cfg.GetEnableLogging() {
|
||||||
|
logger.Warn("Failed to set busy timeout for SQLite", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.db = db
|
||||||
|
p.config = cfg
|
||||||
|
|
||||||
|
if cfg.GetEnableLogging() {
|
||||||
|
logger.Info("SQLite connection established: name=%s, filepath=%s", cfg.GetName(), cfg.GetFilePath())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the SQLite connection
|
||||||
|
func (p *SQLiteProvider) Close() error {
|
||||||
|
if p.db == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := p.db.Close()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to close SQLite connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.GetEnableLogging() {
|
||||||
|
logger.Info("SQLite connection closed: name=%s", p.config.GetName())
|
||||||
|
}
|
||||||
|
|
||||||
|
p.db = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheck verifies the SQLite connection is alive
|
||||||
|
func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||||
|
if p.db == nil {
|
||||||
|
return fmt.Errorf("database connection is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a short timeout for health checks
|
||||||
|
healthCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Execute a simple query to verify the database is accessible
|
||||||
|
var result int
|
||||||
|
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("health check failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != 1 {
|
||||||
|
return fmt.Errorf("health check returned unexpected result: %d", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNative returns the native *sql.DB connection
|
||||||
|
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
||||||
|
if p.db == nil {
|
||||||
|
return nil, fmt.Errorf("database connection is not initialized")
|
||||||
|
}
|
||||||
|
return p.db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMongo returns an error for SQLite (not a MongoDB connection)
|
||||||
|
func (p *SQLiteProvider) GetMongo() (*mongo.Client, error) {
|
||||||
|
return nil, ErrNotMongoDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns connection pool statistics
|
||||||
|
func (p *SQLiteProvider) Stats() *ConnectionStats {
|
||||||
|
if p.db == nil {
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "sqlite",
|
||||||
|
Connected: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := p.db.Stats()
|
||||||
|
|
||||||
|
return &ConnectionStats{
|
||||||
|
Name: p.config.GetName(),
|
||||||
|
Type: "sqlite",
|
||||||
|
Connected: true,
|
||||||
|
OpenConnections: stats.OpenConnections,
|
||||||
|
InUse: stats.InUse,
|
||||||
|
Idle: stats.Idle,
|
||||||
|
WaitCount: stats.WaitCount,
|
||||||
|
WaitDuration: stats.WaitDuration,
|
||||||
|
MaxIdleClosed: stats.MaxIdleClosed,
|
||||||
|
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
@@ -84,7 +86,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
// Create local copy to avoid modifying the captured parameter across requests
|
// Create local copy to avoid modifying the captured parameter across requests
|
||||||
sqlquery := sqlquery
|
sqlquery := sqlquery
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
var dbobjlist []map[string]interface{}
|
var dbobjlist []map[string]interface{}
|
||||||
@@ -123,27 +125,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
ComplexAPI: complexAPI,
|
ComplexAPI: complexAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute BeforeQueryList hook
|
|
||||||
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
|
|
||||||
logger.Error("BeforeQueryList hook failed: %v", err)
|
|
||||||
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if hook aborted the operation
|
|
||||||
if hookCtx.Abort {
|
|
||||||
if hookCtx.AbortCode == 0 {
|
|
||||||
hookCtx.AbortCode = http.StatusBadRequest
|
|
||||||
}
|
|
||||||
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use potentially modified SQL query and variables from hooks
|
|
||||||
sqlquery = hookCtx.SQLQuery
|
|
||||||
variables = hookCtx.Variables
|
|
||||||
// complexAPI = hookCtx.ComplexAPI
|
|
||||||
|
|
||||||
// Extract input variables from SQL query (placeholders like [variable])
|
// Extract input variables from SQL query (placeholders like [variable])
|
||||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||||
|
|
||||||
@@ -203,6 +184,27 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
|
|
||||||
// Execute query within transaction
|
// Execute query within transaction
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
// Set transaction in hook context for hooks to use
|
||||||
|
hookCtx.Tx = tx
|
||||||
|
|
||||||
|
// Execute BeforeQueryList hook (inside transaction)
|
||||||
|
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeQueryList hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hook aborted the operation
|
||||||
|
if hookCtx.Abort {
|
||||||
|
if hookCtx.AbortCode == 0 {
|
||||||
|
hookCtx.AbortCode = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||||
|
return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified SQL query from hook
|
||||||
|
sqlquery = hookCtx.SQLQuery
|
||||||
sqlqueryCnt := sqlquery
|
sqlqueryCnt := sqlquery
|
||||||
|
|
||||||
// Parse sorting and pagination parameters
|
// Parse sorting and pagination parameters
|
||||||
@@ -286,6 +288,21 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
}
|
}
|
||||||
total = hookCtx.Total
|
total = hookCtx.Total
|
||||||
|
|
||||||
|
// Execute AfterQueryList hook (inside transaction)
|
||||||
|
hookCtx.Result = dbobjlist
|
||||||
|
hookCtx.Total = total
|
||||||
|
hookCtx.Error = nil
|
||||||
|
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterQueryList hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
||||||
|
dbobjlist = modifiedResult
|
||||||
|
}
|
||||||
|
total = hookCtx.Total
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -294,21 +311,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute AfterQueryList hook
|
|
||||||
hookCtx.Result = dbobjlist
|
|
||||||
hookCtx.Total = total
|
|
||||||
hookCtx.Error = err
|
|
||||||
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
|
|
||||||
logger.Error("AfterQueryList hook failed: %v", err)
|
|
||||||
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Use potentially modified result from hook
|
|
||||||
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
|
||||||
dbobjlist = modifiedResult
|
|
||||||
}
|
|
||||||
total = hookCtx.Total
|
|
||||||
|
|
||||||
// Set response headers
|
// Set response headers
|
||||||
respOffset := 0
|
respOffset := 0
|
||||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||||
@@ -423,7 +425,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
// Create local copy to avoid modifying the captured parameter across requests
|
// Create local copy to avoid modifying the captured parameter across requests
|
||||||
sqlquery := sqlquery
|
sqlquery := sqlquery
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
propQry := make(map[string]string)
|
propQry := make(map[string]string)
|
||||||
@@ -459,26 +461,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
ComplexAPI: complexAPI,
|
ComplexAPI: complexAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute BeforeQuery hook
|
|
||||||
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
|
|
||||||
logger.Error("BeforeQuery hook failed: %v", err)
|
|
||||||
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if hook aborted the operation
|
|
||||||
if hookCtx.Abort {
|
|
||||||
if hookCtx.AbortCode == 0 {
|
|
||||||
hookCtx.AbortCode = http.StatusBadRequest
|
|
||||||
}
|
|
||||||
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use potentially modified SQL query and variables from hooks
|
|
||||||
sqlquery = hookCtx.SQLQuery
|
|
||||||
variables = hookCtx.Variables
|
|
||||||
|
|
||||||
// Extract input variables from SQL query
|
// Extract input variables from SQL query
|
||||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||||
|
|
||||||
@@ -522,10 +504,17 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
if strings.HasPrefix(kLower, "x-fieldfilter-") {
|
if strings.HasPrefix(kLower, "x-fieldfilter-") {
|
||||||
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
|
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
|
||||||
if strings.Contains(strings.ToLower(sqlquery), colname) {
|
if strings.Contains(strings.ToLower(sqlquery), colname) {
|
||||||
if val == "" || val == "0" {
|
switch val {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
case "0":
|
||||||
} else {
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
|
||||||
|
case "":
|
||||||
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
|
||||||
|
default:
|
||||||
|
if IsNumeric(val) {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||||
|
} else {
|
||||||
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -547,6 +536,28 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
|
|
||||||
// Execute query within transaction
|
// Execute query within transaction
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
// Set transaction in hook context for hooks to use
|
||||||
|
hookCtx.Tx = tx
|
||||||
|
|
||||||
|
// Execute BeforeQuery hook (inside transaction)
|
||||||
|
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeQuery hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hook aborted the operation
|
||||||
|
if hookCtx.Abort {
|
||||||
|
if hookCtx.AbortCode == 0 {
|
||||||
|
hookCtx.AbortCode = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||||
|
return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified SQL query from hook
|
||||||
|
sqlquery = hookCtx.SQLQuery
|
||||||
|
|
||||||
// Execute BeforeSQLExec hook
|
// Execute BeforeSQLExec hook
|
||||||
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
|
||||||
logger.Error("BeforeSQLExec hook failed: %v", err)
|
logger.Error("BeforeSQLExec hook failed: %v", err)
|
||||||
@@ -579,6 +590,19 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
dbobj = modifiedResult
|
dbobj = modifiedResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterQuery hook (inside transaction)
|
||||||
|
hookCtx.Result = dbobj
|
||||||
|
hookCtx.Error = nil
|
||||||
|
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterQuery hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
||||||
|
dbobj = modifiedResult
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -587,19 +611,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute AfterQuery hook
|
|
||||||
hookCtx.Result = dbobj
|
|
||||||
hookCtx.Error = err
|
|
||||||
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
|
|
||||||
logger.Error("AfterQuery hook failed: %v", err)
|
|
||||||
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Use potentially modified result from hook
|
|
||||||
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
|
||||||
dbobj = modifiedResult
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute BeforeResponse hook
|
// Execute BeforeResponse hook
|
||||||
hookCtx.Result = dbobj
|
hookCtx.Result = dbobj
|
||||||
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
|
||||||
@@ -662,7 +673,10 @@ func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables ma
|
|||||||
for k, v := range pathVars {
|
for k, v := range pathVars {
|
||||||
kword := fmt.Sprintf("[%s]", k)
|
kword := fmt.Sprintf("[%s]", k)
|
||||||
if strings.Contains(sqlquery, kword) {
|
if strings.Contains(sqlquery, kword) {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kword, fmt.Sprintf("%v", v))
|
// Sanitize the value before replacing
|
||||||
|
vStr := fmt.Sprintf("%v", v)
|
||||||
|
sanitized := ValidSQL(vStr, "colvalue")
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
|
||||||
}
|
}
|
||||||
variables[k] = v
|
variables[k] = v
|
||||||
|
|
||||||
@@ -690,7 +704,9 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
|
|||||||
// Replace in SQL if placeholder exists
|
// Replace in SQL if placeholder exists
|
||||||
if strings.Contains(sqlquery, kword) && len(val) > 0 {
|
if strings.Contains(sqlquery, kword) && len(val) > 0 {
|
||||||
if strings.HasPrefix(parmk, "p-") {
|
if strings.HasPrefix(parmk, "p-") {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
// Sanitize the parameter value before replacing
|
||||||
|
sanitized := ValidSQL(val, "colvalue")
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -702,15 +718,36 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
|
|||||||
// Apply filters if allowed
|
// Apply filters if allowed
|
||||||
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
|
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
|
||||||
if len(parmv) > 1 {
|
if len(parmv) > 1 {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ",")))
|
// Sanitize each value in the IN clause with appropriate quoting
|
||||||
|
sanitizedValues := make([]string, len(parmv))
|
||||||
|
for i, v := range parmv {
|
||||||
|
if IsNumeric(v) {
|
||||||
|
// Numeric values don't need quotes
|
||||||
|
sanitizedValues[i] = ValidSQL(v, "colvalue")
|
||||||
|
} else {
|
||||||
|
// String values need quotes
|
||||||
|
sanitized := ValidSQL(v, "colvalue")
|
||||||
|
sanitizedValues[i] = fmt.Sprintf("'%s'", sanitized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ",")))
|
||||||
} else {
|
} else {
|
||||||
if strings.Contains(val, "match=") {
|
if strings.Contains(val, "match=") {
|
||||||
colval := strings.ReplaceAll(val, "match=", "")
|
colval := strings.ReplaceAll(val, "match=", "")
|
||||||
|
// Escape single quotes and backslashes for LIKE patterns
|
||||||
|
// But don't escape wildcards % and _ which are intentional
|
||||||
|
colval = strings.ReplaceAll(colval, "\\", "\\\\")
|
||||||
|
colval = strings.ReplaceAll(colval, "'", "''")
|
||||||
if colval != "*" {
|
if colval != "*" {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
|
||||||
}
|
}
|
||||||
} else if val == "" || val == "0" {
|
} else if val == "" || val == "0" {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
// For empty/zero values, treat as literal 0 or empty string with quotes
|
||||||
|
if val == "0" {
|
||||||
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = 0 OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
|
||||||
|
} else {
|
||||||
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if IsNumeric(val) {
|
if IsNumeric(val) {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
||||||
@@ -743,16 +780,25 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
|
|||||||
|
|
||||||
kword := fmt.Sprintf("[%s]", k)
|
kword := fmt.Sprintf("[%s]", k)
|
||||||
if strings.Contains(sqlquery, kword) {
|
if strings.Contains(sqlquery, kword) {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
// Sanitize the header value before replacing
|
||||||
|
sanitized := ValidSQL(val, "colvalue")
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle special headers
|
// Handle special headers
|
||||||
if strings.Contains(k, "x-fieldfilter-") {
|
if strings.Contains(k, "x-fieldfilter-") {
|
||||||
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
|
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
|
||||||
if val == "" || val == "0" {
|
switch val {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
case "0":
|
||||||
} else {
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
|
||||||
|
case "":
|
||||||
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
|
||||||
|
default:
|
||||||
|
if IsNumeric(val) {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||||
|
} else {
|
||||||
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -782,12 +828,15 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
|
|||||||
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
|
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
|
||||||
if strings.Contains(sqlquery, "[p_meta_default]") {
|
if strings.Contains(sqlquery, "[p_meta_default]") {
|
||||||
data, _ := json.Marshal(metainfo)
|
data, _ := json.Marshal(metainfo)
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
dataStr := strings.ReplaceAll(string(data), "$META$", "/*META*/")
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("$META$%s$META$::jsonb", dataStr))
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[json_variables]") {
|
if strings.Contains(sqlquery, "[json_variables]") {
|
||||||
data, _ := json.Marshal(variables)
|
data, _ := json.Marshal(variables)
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
dataStr := strings.ReplaceAll(string(data), "$VAR$", "/*VAR*/")
|
||||||
|
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("$VAR$%s$VAR$::jsonb", dataStr))
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[rid_user]") {
|
if strings.Contains(sqlquery, "[rid_user]") {
|
||||||
@@ -795,7 +844,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[user]") {
|
if strings.Contains(sqlquery, "[user]") {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName))
|
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("$USR$%s$USR$", strings.ReplaceAll(userCtx.UserName, "$USR$", "/*USR*/")))
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[rid_session]") {
|
if strings.Contains(sqlquery, "[rid_session]") {
|
||||||
@@ -806,7 +855,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[method]") {
|
if strings.Contains(sqlquery, "[method]") {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method)
|
sqlquery = strings.ReplaceAll(sqlquery, "[method]", fmt.Sprintf("$M$%s$M$", strings.ReplaceAll(r.Method, "$M$", "/*M*/")))
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[post_body]") {
|
if strings.Contains(sqlquery, "[post_body]") {
|
||||||
@@ -819,7 +868,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr))
|
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("$PBODY$%s$PBODY$", strings.ReplaceAll(bodystr, "$PBODY$", "/*PBODY*/")))
|
||||||
}
|
}
|
||||||
|
|
||||||
return sqlquery
|
return sqlquery
|
||||||
@@ -859,19 +908,23 @@ func ValidSQL(input, mode string) string {
|
|||||||
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
|
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
|
||||||
return reg.ReplaceAllString(input, "")
|
return reg.ReplaceAllString(input, "")
|
||||||
case "colvalue":
|
case "colvalue":
|
||||||
// For column values, escape single quotes
|
// For column values, escape single quotes and backslashes
|
||||||
return strings.ReplaceAll(input, "'", "''")
|
// Note: Backslashes must be escaped first, then single quotes
|
||||||
|
result := strings.ReplaceAll(input, "\\", "\\\\")
|
||||||
|
result = strings.ReplaceAll(result, "'", "''")
|
||||||
|
return result
|
||||||
case "select":
|
case "select":
|
||||||
// For SELECT clauses, be more permissive but still safe
|
// For SELECT clauses, be more permissive but still safe
|
||||||
// Remove semicolons and common SQL injection patterns
|
// Remove semicolons and common SQL injection patterns (case-insensitive)
|
||||||
dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "}
|
dangerous := []string{
|
||||||
result := input
|
";", "--", "/\\*", "\\*/", "xp_", "sp_",
|
||||||
for _, d := range dangerous {
|
"drop ", "delete ", "truncate ", "update ", "insert ",
|
||||||
result = strings.ReplaceAll(result, d, "")
|
"exec ", "execute ", "union ", "declare ", "alter ", "create ",
|
||||||
result = strings.ReplaceAll(result, strings.ToLower(d), "")
|
|
||||||
result = strings.ReplaceAll(result, strings.ToUpper(d), "")
|
|
||||||
}
|
}
|
||||||
return result
|
// Build a single regex pattern with all dangerous keywords
|
||||||
|
pattern := "(?i)(" + strings.Join(dangerous, "|") + ")"
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
return re.ReplaceAllString(input, "")
|
||||||
default:
|
default:
|
||||||
return input
|
return input
|
||||||
}
|
}
|
||||||
@@ -1048,9 +1101,25 @@ func normalizePostgresValue(value interface{}) interface{} {
|
|||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
// Recursively normalize nested maps
|
// Recursively normalize nested maps
|
||||||
return normalizePostgresTypes(v)
|
return normalizePostgresTypes(v)
|
||||||
|
case string:
|
||||||
|
var jsonObj interface{}
|
||||||
|
if err := json.Unmarshal([]byte(v), &jsonObj); err == nil {
|
||||||
|
// It's valid JSON, return as json.RawMessage so it's not double-encoded
|
||||||
|
return json.RawMessage(v)
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
case uuid.UUID:
|
||||||
|
return v.String()
|
||||||
|
case time.Time:
|
||||||
|
return v.Format(time.RFC3339)
|
||||||
|
case bool, int, int8, int16, int32, int64, float32, float64, uint, uint8, uint16, uint32, uint64:
|
||||||
|
return v
|
||||||
default:
|
default:
|
||||||
// For other types (int, float, string, bool, etc.), return as-is
|
// For other types (int, float, bool, etc.), return as-is
|
||||||
|
// Check stringers
|
||||||
|
if str, ok := v.(fmt.Stringer); ok {
|
||||||
|
return str.String()
|
||||||
|
}
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
)
|
)
|
||||||
@@ -46,6 +47,10 @@ type HookContext struct {
|
|||||||
// User context
|
// User context
|
||||||
UserContext *security.UserContext
|
UserContext *security.UserContext
|
||||||
|
|
||||||
|
// Tx provides access to the database/transaction for executing additional SQL
|
||||||
|
// This allows hooks to run custom queries in addition to the main Query chain
|
||||||
|
Tx common.Database
|
||||||
|
|
||||||
// Pagination and filtering (for list queries)
|
// Pagination and filtering (for list queries)
|
||||||
SortColumns string
|
SortColumns string
|
||||||
Limit int
|
Limit int
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ A pluggable metrics collection system with Prometheus implementation.
|
|||||||
```go
|
```go
|
||||||
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
|
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
|
||||||
// Initialize Prometheus provider
|
// Initialize Prometheus provider with default config
|
||||||
provider := metrics.NewPrometheusProvider()
|
provider := metrics.NewPrometheusProvider(nil)
|
||||||
metrics.SetProvider(provider)
|
metrics.SetProvider(provider)
|
||||||
|
|
||||||
// Apply middleware to your router
|
// Apply middleware to your router
|
||||||
@@ -18,6 +18,59 @@ router.Use(provider.Middleware)
|
|||||||
http.Handle("/metrics", provider.Handler())
|
http.Handle("/metrics", provider.Handler())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
You can customize the metrics provider using a configuration struct:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
|
||||||
|
// Create custom configuration
|
||||||
|
config := &metrics.Config{
|
||||||
|
Enabled: true,
|
||||||
|
Provider: "prometheus",
|
||||||
|
Namespace: "myapp", // Prefix all metrics with "myapp_"
|
||||||
|
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5},
|
||||||
|
DBQueryBuckets: []float64{0.001, 0.01, 0.05, 0.1, 0.5, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize with custom config
|
||||||
|
provider := metrics.NewPrometheusProvider(config)
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Options
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `Enabled` | `bool` | `true` | Enable/disable metrics collection |
|
||||||
|
| `Provider` | `string` | `"prometheus"` | Metrics provider type |
|
||||||
|
| `Namespace` | `string` | `""` | Prefix for all metric names |
|
||||||
|
| `HTTPRequestBuckets` | `[]float64` | See below | Histogram buckets for HTTP duration (seconds) |
|
||||||
|
| `DBQueryBuckets` | `[]float64` | See below | Histogram buckets for DB query duration (seconds) |
|
||||||
|
|
||||||
|
**Default HTTP Request Buckets:** `[0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]`
|
||||||
|
|
||||||
|
**Default DB Query Buckets:** `[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]`
|
||||||
|
|
||||||
|
### Pushgateway Configuration (Optional)
|
||||||
|
|
||||||
|
For batch jobs, cron tasks, or short-lived processes, you can push metrics to Prometheus Pushgateway:
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `PushgatewayURL` | `string` | `""` | URL of Pushgateway (e.g., "http://pushgateway:9091") |
|
||||||
|
| `PushgatewayJobName` | `string` | `"resolvespec"` | Job name for pushed metrics |
|
||||||
|
| `PushgatewayInterval` | `int` | `0` | Auto-push interval in seconds (0 = disabled) |
|
||||||
|
|
||||||
|
```go
|
||||||
|
config := &metrics.Config{
|
||||||
|
PushgatewayURL: "http://pushgateway:9091",
|
||||||
|
PushgatewayJobName: "batch-job",
|
||||||
|
PushgatewayInterval: 30, // Push every 30 seconds
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Provider Interface
|
## Provider Interface
|
||||||
|
|
||||||
The package uses a provider interface, allowing you to plug in different metric systems:
|
The package uses a provider interface, allowing you to plug in different metric systems:
|
||||||
@@ -87,6 +140,13 @@ When using `PrometheusProvider`, the following metrics are available:
|
|||||||
| `cache_hits_total` | Counter | provider | Total cache hits |
|
| `cache_hits_total` | Counter | provider | Total cache hits |
|
||||||
| `cache_misses_total` | Counter | provider | Total cache misses |
|
| `cache_misses_total` | Counter | provider | Total cache misses |
|
||||||
| `cache_size_items` | Gauge | provider | Current cache size |
|
| `cache_size_items` | Gauge | provider | Current cache size |
|
||||||
|
| `events_published_total` | Counter | source, event_type | Total events published |
|
||||||
|
| `events_processed_total` | Counter | source, event_type, status | Total events processed |
|
||||||
|
| `event_processing_duration_seconds` | Histogram | source, event_type | Event processing duration |
|
||||||
|
| `event_queue_size` | Gauge | - | Current event queue size |
|
||||||
|
| `panics_total` | Counter | method | Total panics recovered |
|
||||||
|
|
||||||
|
**Note:** If a custom `Namespace` is configured, all metric names will be prefixed with `{namespace}_`.
|
||||||
|
|
||||||
## Prometheus Queries
|
## Prometheus Queries
|
||||||
|
|
||||||
@@ -146,8 +206,126 @@ func (c *CustomProvider) Handler() http.Handler {
|
|||||||
metrics.SetProvider(&CustomProvider{})
|
metrics.SetProvider(&CustomProvider{})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Pushgateway Usage
|
||||||
|
|
||||||
|
### Automatic Push (Batch Jobs)
|
||||||
|
|
||||||
|
For jobs that run periodically, use automatic pushing:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Configure with automatic pushing every 30 seconds
|
||||||
|
config := &metrics.Config{
|
||||||
|
Enabled: true,
|
||||||
|
Provider: "prometheus",
|
||||||
|
Namespace: "batch_job",
|
||||||
|
PushgatewayURL: "http://pushgateway:9091",
|
||||||
|
PushgatewayJobName: "data-processor",
|
||||||
|
PushgatewayInterval: 30, // Push every 30 seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := metrics.NewPrometheusProvider(config)
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
|
||||||
|
// Ensure cleanup on exit
|
||||||
|
defer provider.StopAutoPush()
|
||||||
|
|
||||||
|
// Your batch job logic here
|
||||||
|
processBatchData()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manual Push (Short-lived Processes)
|
||||||
|
|
||||||
|
For one-time jobs or when you want manual control:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Configure without automatic pushing
|
||||||
|
config := &metrics.Config{
|
||||||
|
Enabled: true,
|
||||||
|
Provider: "prometheus",
|
||||||
|
PushgatewayURL: "http://pushgateway:9091",
|
||||||
|
PushgatewayJobName: "migration-job",
|
||||||
|
// PushgatewayInterval: 0 (default - no auto-push)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := metrics.NewPrometheusProvider(config)
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
|
||||||
|
// Run your job
|
||||||
|
err := runMigration()
|
||||||
|
|
||||||
|
// Push metrics at the end
|
||||||
|
if pushErr := provider.Push(); pushErr != nil {
|
||||||
|
log.Printf("Failed to push metrics: %v", pushErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Docker Compose with Pushgateway
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
version: '3'
|
||||||
|
services:
|
||||||
|
batch-job:
|
||||||
|
build: .
|
||||||
|
environment:
|
||||||
|
PUSHGATEWAY_URL: "http://pushgateway:9091"
|
||||||
|
|
||||||
|
pushgateway:
|
||||||
|
image: prom/pushgateway
|
||||||
|
ports:
|
||||||
|
- "9091:9091"
|
||||||
|
|
||||||
|
prometheus:
|
||||||
|
image: prom/prometheus
|
||||||
|
ports:
|
||||||
|
- "9090:9090"
|
||||||
|
volumes:
|
||||||
|
- ./prometheus.yml:/etc/prometheus/prometheus.yml
|
||||||
|
command:
|
||||||
|
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||||
|
```
|
||||||
|
|
||||||
|
**prometheus.yml for Pushgateway:**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
global:
|
||||||
|
scrape_interval: 15s
|
||||||
|
|
||||||
|
scrape_configs:
|
||||||
|
# Scrape the pushgateway
|
||||||
|
- job_name: 'pushgateway'
|
||||||
|
honor_labels: true # Important: preserve job labels from pushed metrics
|
||||||
|
static_configs:
|
||||||
|
- targets: ['pushgateway:9091']
|
||||||
|
```
|
||||||
|
|
||||||
## Complete Example
|
## Complete Example
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
|
|
||||||
@@ -162,8 +340,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Initialize metrics
|
// Initialize metrics with default config
|
||||||
provider := metrics.NewPrometheusProvider()
|
provider := metrics.NewPrometheusProvider(nil)
|
||||||
metrics.SetProvider(provider)
|
metrics.SetProvider(provider)
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
@@ -198,6 +376,42 @@ func getUsersHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### With Custom Configuration
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Custom metrics configuration
|
||||||
|
metricsConfig := &metrics.Config{
|
||||||
|
Enabled: true,
|
||||||
|
Provider: "prometheus",
|
||||||
|
Namespace: "myapp",
|
||||||
|
// Custom buckets optimized for your application
|
||||||
|
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5, 10},
|
||||||
|
DBQueryBuckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize with custom config
|
||||||
|
provider := metrics.NewPrometheusProvider(metricsConfig)
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.Use(provider.Middleware)
|
||||||
|
router.Handle("/metrics", provider.Handler())
|
||||||
|
|
||||||
|
log.Fatal(http.ListenAndServe(":8080", router))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Docker Compose Example
|
## Docker Compose Example
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -257,3 +471,8 @@ scrape_configs:
|
|||||||
4. **Performance**: Metrics collection is lock-free and highly performant
|
4. **Performance**: Metrics collection is lock-free and highly performant
|
||||||
- Safe for high-throughput applications
|
- Safe for high-throughput applications
|
||||||
- Minimal overhead (<1% in most cases)
|
- Minimal overhead (<1% in most cases)
|
||||||
|
|
||||||
|
5. **Pull vs Push**:
|
||||||
|
- **Use Pull (default)**: Long-running services, web servers, microservices
|
||||||
|
- **Use Push (Pushgateway)**: Batch jobs, cron tasks, short-lived processes, serverless functions
|
||||||
|
- Pull is preferred for most applications as it allows Prometheus to detect if your service is down
|
||||||
|
|||||||
64
pkg/metrics/config.go
Normal file
64
pkg/metrics/config.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
// Config holds configuration for the metrics provider
|
||||||
|
type Config struct {
|
||||||
|
// Enabled determines whether metrics collection is enabled
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
|
||||||
|
// Provider specifies which metrics provider to use (prometheus, noop)
|
||||||
|
Provider string `mapstructure:"provider"`
|
||||||
|
|
||||||
|
// Namespace is an optional prefix for all metric names
|
||||||
|
Namespace string `mapstructure:"namespace"`
|
||||||
|
|
||||||
|
// HTTPRequestBuckets defines histogram buckets for HTTP request duration (in seconds)
|
||||||
|
// Default: [0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10]
|
||||||
|
HTTPRequestBuckets []float64 `mapstructure:"http_request_buckets"`
|
||||||
|
|
||||||
|
// DBQueryBuckets defines histogram buckets for database query duration (in seconds)
|
||||||
|
// Default: [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5]
|
||||||
|
DBQueryBuckets []float64 `mapstructure:"db_query_buckets"`
|
||||||
|
|
||||||
|
// PushgatewayURL is the URL of the Prometheus Pushgateway (optional)
|
||||||
|
// If set, metrics will be pushed to this gateway instead of only being scraped
|
||||||
|
// Example: "http://pushgateway:9091"
|
||||||
|
PushgatewayURL string `mapstructure:"pushgateway_url"`
|
||||||
|
|
||||||
|
// PushgatewayJobName is the job name to use when pushing metrics to Pushgateway
|
||||||
|
// Default: "resolvespec"
|
||||||
|
PushgatewayJobName string `mapstructure:"pushgateway_job_name"`
|
||||||
|
|
||||||
|
// PushgatewayInterval is the interval at which to push metrics to Pushgateway
|
||||||
|
// Only used if PushgatewayURL is set. If 0, automatic pushing is disabled.
|
||||||
|
// Default: 0 (no automatic pushing)
|
||||||
|
PushgatewayInterval int `mapstructure:"pushgateway_interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns a Config with sensible defaults
|
||||||
|
func DefaultConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
Enabled: true,
|
||||||
|
Provider: "prometheus",
|
||||||
|
// HTTP requests typically take longer than DB queries
|
||||||
|
HTTPRequestBuckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||||
|
// DB queries are usually faster
|
||||||
|
DBQueryBuckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDefaults fills in any missing values with defaults
|
||||||
|
func (c *Config) ApplyDefaults() {
|
||||||
|
if c.Provider == "" {
|
||||||
|
c.Provider = "prometheus"
|
||||||
|
}
|
||||||
|
if len(c.HTTPRequestBuckets) == 0 {
|
||||||
|
c.HTTPRequestBuckets = []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
|
||||||
|
}
|
||||||
|
if len(c.DBQueryBuckets) == 0 {
|
||||||
|
c.DBQueryBuckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5}
|
||||||
|
}
|
||||||
|
// Set default job name if pushgateway is configured but job name is empty
|
||||||
|
if c.PushgatewayURL != "" && c.PushgatewayJobName == "" {
|
||||||
|
c.PushgatewayJobName = "resolvespec"
|
||||||
|
}
|
||||||
|
}
|
||||||
64
pkg/metrics/example_test.go
Normal file
64
pkg/metrics/example_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package metrics_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExampleNewPrometheusProvider_default demonstrates using default configuration
|
||||||
|
func ExampleNewPrometheusProvider_default() {
|
||||||
|
// Initialize with default configuration
|
||||||
|
provider := metrics.NewPrometheusProvider(nil)
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
|
||||||
|
fmt.Println("Provider initialized with defaults")
|
||||||
|
// Output: Provider initialized with defaults
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleNewPrometheusProvider_custom demonstrates using custom configuration
|
||||||
|
func ExampleNewPrometheusProvider_custom() {
|
||||||
|
// Create custom configuration
|
||||||
|
config := &metrics.Config{
|
||||||
|
Enabled: true,
|
||||||
|
Provider: "prometheus",
|
||||||
|
Namespace: "myapp",
|
||||||
|
HTTPRequestBuckets: []float64{0.01, 0.05, 0.1, 0.5, 1, 2, 5},
|
||||||
|
DBQueryBuckets: []float64{0.001, 0.01, 0.05, 0.1, 0.5, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize with custom configuration
|
||||||
|
provider := metrics.NewPrometheusProvider(config)
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
|
||||||
|
fmt.Println("Provider initialized with custom config")
|
||||||
|
// Output: Provider initialized with custom config
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleDefaultConfig demonstrates getting default configuration
|
||||||
|
func ExampleDefaultConfig() {
|
||||||
|
config := metrics.DefaultConfig()
|
||||||
|
fmt.Printf("Default provider: %s\n", config.Provider)
|
||||||
|
fmt.Printf("Default enabled: %v\n", config.Enabled)
|
||||||
|
// Output:
|
||||||
|
// Default provider: prometheus
|
||||||
|
// Default enabled: true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleConfig_ApplyDefaults demonstrates applying defaults to partial config
|
||||||
|
func ExampleConfig_ApplyDefaults() {
|
||||||
|
// Create partial configuration
|
||||||
|
config := &metrics.Config{
|
||||||
|
Namespace: "myapp",
|
||||||
|
// Other fields will be filled with defaults
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults
|
||||||
|
config.ApplyDefaults()
|
||||||
|
|
||||||
|
fmt.Printf("Provider: %s\n", config.Provider)
|
||||||
|
fmt.Printf("Namespace: %s\n", config.Namespace)
|
||||||
|
// Output:
|
||||||
|
// Provider: prometheus
|
||||||
|
// Namespace: myapp
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/push"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PrometheusProvider implements the Provider interface using Prometheus
|
// PrometheusProvider implements the Provider interface using Prometheus
|
||||||
@@ -20,23 +21,51 @@ type PrometheusProvider struct {
|
|||||||
cacheHits *prometheus.CounterVec
|
cacheHits *prometheus.CounterVec
|
||||||
cacheMisses *prometheus.CounterVec
|
cacheMisses *prometheus.CounterVec
|
||||||
cacheSize *prometheus.GaugeVec
|
cacheSize *prometheus.GaugeVec
|
||||||
|
eventPublished *prometheus.CounterVec
|
||||||
|
eventProcessed *prometheus.CounterVec
|
||||||
|
eventDuration *prometheus.HistogramVec
|
||||||
|
eventQueueSize prometheus.Gauge
|
||||||
panicsTotal *prometheus.CounterVec
|
panicsTotal *prometheus.CounterVec
|
||||||
|
|
||||||
|
// Pushgateway fields (optional)
|
||||||
|
pushgatewayURL string
|
||||||
|
pushgatewayJobName string
|
||||||
|
pusher *push.Pusher
|
||||||
|
pushTicker *time.Ticker
|
||||||
|
pushStop chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||||
func NewPrometheusProvider() *PrometheusProvider {
|
// If cfg is nil, default configuration will be used
|
||||||
return &PrometheusProvider{
|
func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
|
||||||
|
// Use default config if none provided
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = DefaultConfig()
|
||||||
|
} else {
|
||||||
|
// Apply defaults for any missing values
|
||||||
|
cfg.ApplyDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to add namespace prefix if configured
|
||||||
|
metricName := func(name string) string {
|
||||||
|
if cfg.Namespace != "" {
|
||||||
|
return cfg.Namespace + "_" + name
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &PrometheusProvider{
|
||||||
requestDuration: promauto.NewHistogramVec(
|
requestDuration: promauto.NewHistogramVec(
|
||||||
prometheus.HistogramOpts{
|
prometheus.HistogramOpts{
|
||||||
Name: "http_request_duration_seconds",
|
Name: metricName("http_request_duration_seconds"),
|
||||||
Help: "HTTP request duration in seconds",
|
Help: "HTTP request duration in seconds",
|
||||||
Buckets: prometheus.DefBuckets,
|
Buckets: cfg.HTTPRequestBuckets,
|
||||||
},
|
},
|
||||||
[]string{"method", "path", "status"},
|
[]string{"method", "path", "status"},
|
||||||
),
|
),
|
||||||
requestTotal: promauto.NewCounterVec(
|
requestTotal: promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "http_requests_total",
|
Name: metricName("http_requests_total"),
|
||||||
Help: "Total number of HTTP requests",
|
Help: "Total number of HTTP requests",
|
||||||
},
|
},
|
||||||
[]string{"method", "path", "status"},
|
[]string{"method", "path", "status"},
|
||||||
@@ -44,54 +73,100 @@ func NewPrometheusProvider() *PrometheusProvider {
|
|||||||
|
|
||||||
requestsInFlight: promauto.NewGauge(
|
requestsInFlight: promauto.NewGauge(
|
||||||
prometheus.GaugeOpts{
|
prometheus.GaugeOpts{
|
||||||
Name: "http_requests_in_flight",
|
Name: metricName("http_requests_in_flight"),
|
||||||
Help: "Current number of HTTP requests being processed",
|
Help: "Current number of HTTP requests being processed",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
dbQueryDuration: promauto.NewHistogramVec(
|
dbQueryDuration: promauto.NewHistogramVec(
|
||||||
prometheus.HistogramOpts{
|
prometheus.HistogramOpts{
|
||||||
Name: "db_query_duration_seconds",
|
Name: metricName("db_query_duration_seconds"),
|
||||||
Help: "Database query duration in seconds",
|
Help: "Database query duration in seconds",
|
||||||
Buckets: prometheus.DefBuckets,
|
Buckets: cfg.DBQueryBuckets,
|
||||||
},
|
},
|
||||||
[]string{"operation", "table"},
|
[]string{"operation", "table"},
|
||||||
),
|
),
|
||||||
dbQueryTotal: promauto.NewCounterVec(
|
dbQueryTotal: promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "db_queries_total",
|
Name: metricName("db_queries_total"),
|
||||||
Help: "Total number of database queries",
|
Help: "Total number of database queries",
|
||||||
},
|
},
|
||||||
[]string{"operation", "table", "status"},
|
[]string{"operation", "table", "status"},
|
||||||
),
|
),
|
||||||
cacheHits: promauto.NewCounterVec(
|
cacheHits: promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "cache_hits_total",
|
Name: metricName("cache_hits_total"),
|
||||||
Help: "Total number of cache hits",
|
Help: "Total number of cache hits",
|
||||||
},
|
},
|
||||||
[]string{"provider"},
|
[]string{"provider"},
|
||||||
),
|
),
|
||||||
cacheMisses: promauto.NewCounterVec(
|
cacheMisses: promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "cache_misses_total",
|
Name: metricName("cache_misses_total"),
|
||||||
Help: "Total number of cache misses",
|
Help: "Total number of cache misses",
|
||||||
},
|
},
|
||||||
[]string{"provider"},
|
[]string{"provider"},
|
||||||
),
|
),
|
||||||
cacheSize: promauto.NewGaugeVec(
|
cacheSize: promauto.NewGaugeVec(
|
||||||
prometheus.GaugeOpts{
|
prometheus.GaugeOpts{
|
||||||
Name: "cache_size_items",
|
Name: metricName("cache_size_items"),
|
||||||
Help: "Number of items in cache",
|
Help: "Number of items in cache",
|
||||||
},
|
},
|
||||||
[]string{"provider"},
|
[]string{"provider"},
|
||||||
),
|
),
|
||||||
|
eventPublished: promauto.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: metricName("events_published_total"),
|
||||||
|
Help: "Total number of events published",
|
||||||
|
},
|
||||||
|
[]string{"source", "event_type"},
|
||||||
|
),
|
||||||
|
eventProcessed: promauto.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: metricName("events_processed_total"),
|
||||||
|
Help: "Total number of events processed",
|
||||||
|
},
|
||||||
|
[]string{"source", "event_type", "status"},
|
||||||
|
),
|
||||||
|
eventDuration: promauto.NewHistogramVec(
|
||||||
|
prometheus.HistogramOpts{
|
||||||
|
Name: metricName("event_processing_duration_seconds"),
|
||||||
|
Help: "Event processing duration in seconds",
|
||||||
|
Buckets: cfg.DBQueryBuckets, // Events are typically fast like DB queries
|
||||||
|
},
|
||||||
|
[]string{"source", "event_type"},
|
||||||
|
),
|
||||||
|
eventQueueSize: promauto.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: metricName("event_queue_size"),
|
||||||
|
Help: "Current number of events in queue",
|
||||||
|
},
|
||||||
|
),
|
||||||
panicsTotal: promauto.NewCounterVec(
|
panicsTotal: promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "panics_total",
|
Name: metricName("panics_total"),
|
||||||
Help: "Total number of panics",
|
Help: "Total number of panics",
|
||||||
},
|
},
|
||||||
[]string{"method"},
|
[]string{"method"},
|
||||||
),
|
),
|
||||||
|
|
||||||
|
pushgatewayURL: cfg.PushgatewayURL,
|
||||||
|
pushgatewayJobName: cfg.PushgatewayJobName,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize pushgateway if configured
|
||||||
|
if cfg.PushgatewayURL != "" {
|
||||||
|
p.pusher = push.New(cfg.PushgatewayURL, cfg.PushgatewayJobName).
|
||||||
|
Gatherer(prometheus.DefaultGatherer)
|
||||||
|
|
||||||
|
// Start automatic pushing if interval is configured
|
||||||
|
if cfg.PushgatewayInterval > 0 {
|
||||||
|
p.pushStop = make(chan bool)
|
||||||
|
p.pushTicker = time.NewTicker(time.Duration(cfg.PushgatewayInterval) * time.Second)
|
||||||
|
go p.startAutoPush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResponseWriter wraps http.ResponseWriter to capture status code
|
// ResponseWriter wraps http.ResponseWriter to capture status code
|
||||||
@@ -153,6 +228,22 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
|
|||||||
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordEventPublished implements Provider interface
|
||||||
|
func (p *PrometheusProvider) RecordEventPublished(source, eventType string) {
|
||||||
|
p.eventPublished.WithLabelValues(source, eventType).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordEventProcessed implements Provider interface
|
||||||
|
func (p *PrometheusProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||||
|
p.eventProcessed.WithLabelValues(source, eventType, status).Inc()
|
||||||
|
p.eventDuration.WithLabelValues(source, eventType).Observe(duration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateEventQueueSize implements Provider interface
|
||||||
|
func (p *PrometheusProvider) UpdateEventQueueSize(size int64) {
|
||||||
|
p.eventQueueSize.Set(float64(size))
|
||||||
|
}
|
||||||
|
|
||||||
// RecordPanic implements the Provider interface
|
// RecordPanic implements the Provider interface
|
||||||
func (p *PrometheusProvider) RecordPanic(methodName string) {
|
func (p *PrometheusProvider) RecordPanic(methodName string) {
|
||||||
p.panicsTotal.WithLabelValues(methodName).Inc()
|
p.panicsTotal.WithLabelValues(methodName).Inc()
|
||||||
@@ -185,3 +276,37 @@ func (p *PrometheusProvider) Middleware(next http.Handler) http.Handler {
|
|||||||
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
|
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Push manually pushes metrics to the configured Pushgateway
|
||||||
|
// Returns an error if pushing fails or if Pushgateway is not configured
|
||||||
|
func (p *PrometheusProvider) Push() error {
|
||||||
|
if p.pusher == nil {
|
||||||
|
return nil // Pushgateway not configured, silently skip
|
||||||
|
}
|
||||||
|
return p.pusher.Push()
|
||||||
|
}
|
||||||
|
|
||||||
|
// startAutoPush runs in a goroutine and periodically pushes metrics to Pushgateway
|
||||||
|
func (p *PrometheusProvider) startAutoPush() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.pushTicker.C:
|
||||||
|
if err := p.Push(); err != nil {
|
||||||
|
// Log error but continue pushing
|
||||||
|
// Note: In production, you might want to use a proper logger
|
||||||
|
_ = err
|
||||||
|
}
|
||||||
|
case <-p.pushStop:
|
||||||
|
p.pushTicker.Stop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopAutoPush stops the automatic push goroutine
|
||||||
|
// This should be called when shutting down the application
|
||||||
|
func (p *PrometheusProvider) StopAutoPush() {
|
||||||
|
if p.pushStop != nil {
|
||||||
|
close(p.pushStop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package reflection
|
package reflection
|
||||||
|
|
||||||
import "reflect"
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
func Len(v any) int {
|
func Len(v any) int {
|
||||||
val := reflect.ValueOf(v)
|
val := reflect.ValueOf(v)
|
||||||
@@ -47,3 +50,58 @@ func ExtractTableNameOnly(fullName string) string {
|
|||||||
|
|
||||||
return fullName[startIndex:]
|
return fullName[startIndex:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPointerElement returns the element type if the provided reflect.Type is a pointer.
|
||||||
|
// If the type is a slice of pointers, it returns the element type of the pointer within the slice.
|
||||||
|
// If neither condition is met, it returns the original type.
|
||||||
|
func GetPointerElement(v reflect.Type) reflect.Type {
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
|
return v.Elem()
|
||||||
|
}
|
||||||
|
if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr {
|
||||||
|
subElem := v.Elem()
|
||||||
|
if subElem.Elem().Kind() == reflect.Ptr {
|
||||||
|
return subElem.Elem().Elem()
|
||||||
|
}
|
||||||
|
return v.Elem()
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJSONNameForField gets the JSON tag name for a struct field.
|
||||||
|
// Returns the JSON field name from the json struct tag, or an empty string if not found.
|
||||||
|
// Handles the "json" tag format: "name", "name,omitempty", etc.
|
||||||
|
func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
|
||||||
|
if modelType == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pointer types
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field
|
||||||
|
field, found := modelType.FieldByName(fieldName)
|
||||||
|
if !found {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the JSON tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the tag (format: "name,omitempty" or just "name")
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -584,12 +584,24 @@ func ExtractSourceColumn(colName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ToSnakeCase converts a string from CamelCase to snake_case
|
// ToSnakeCase converts a string from CamelCase to snake_case
|
||||||
|
// Handles consecutive uppercase letters (acronyms) correctly:
|
||||||
|
// "HTTPServer" -> "http_server", "UserID" -> "user_id", "MyHTTPServer" -> "my_http_server"
|
||||||
func ToSnakeCase(s string) string {
|
func ToSnakeCase(s string) string {
|
||||||
var result strings.Builder
|
var result strings.Builder
|
||||||
for i, r := range s {
|
runes := []rune(s)
|
||||||
|
|
||||||
|
for i, r := range runes {
|
||||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||||
|
// Add underscore if:
|
||||||
|
// 1. Previous character is lowercase, OR
|
||||||
|
// 2. Next character is lowercase (transition from acronym to word)
|
||||||
|
prevIsLower := runes[i-1] >= 'a' && runes[i-1] <= 'z'
|
||||||
|
nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z'
|
||||||
|
|
||||||
|
if prevIsLower || nextIsLower {
|
||||||
result.WriteRune('_')
|
result.WriteRune('_')
|
||||||
}
|
}
|
||||||
|
}
|
||||||
result.WriteRune(r)
|
result.WriteRune(r)
|
||||||
}
|
}
|
||||||
return strings.ToLower(result.String())
|
return strings.ToLower(result.String())
|
||||||
@@ -936,32 +948,38 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
|||||||
// Build list of possible column names for this field
|
// Build list of possible column names for this field
|
||||||
var columnNames []string
|
var columnNames []string
|
||||||
|
|
||||||
// 1. Bun tag
|
// 1. JSON tag (primary - most common)
|
||||||
|
jsonFound := false
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
columnNames = append(columnNames, parts[0])
|
||||||
|
jsonFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Bun tag (fallback if no JSON tag)
|
||||||
|
if !jsonFound {
|
||||||
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||||
columnNames = append(columnNames, colName)
|
columnNames = append(columnNames, colName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 2. Gorm tag
|
// 3. Gorm tag (fallback if no JSON tag)
|
||||||
|
if !jsonFound {
|
||||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||||
columnNames = append(columnNames, colName)
|
columnNames = append(columnNames, colName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. JSON tag
|
|
||||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
|
||||||
parts := strings.Split(jsonTag, ",")
|
|
||||||
if len(parts) > 0 && parts[0] != "" {
|
|
||||||
columnNames = append(columnNames, parts[0])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Field name variations
|
// 4. Field name variations (last resort)
|
||||||
columnNames = append(columnNames, field.Name)
|
columnNames = append(columnNames, field.Name)
|
||||||
columnNames = append(columnNames, strings.ToLower(field.Name))
|
columnNames = append(columnNames, strings.ToLower(field.Name))
|
||||||
columnNames = append(columnNames, ToSnakeCase(field.Name))
|
// columnNames = append(columnNames, ToSnakeCase(field.Name))
|
||||||
|
|
||||||
// Map all column name variations to this field index
|
// Map all column name variations to this field index
|
||||||
for _, colName := range columnNames {
|
for _, colName := range columnNames {
|
||||||
@@ -1067,7 +1085,7 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
|||||||
case string:
|
case string:
|
||||||
field.SetBytes([]byte(v))
|
field.SetBytes([]byte(v))
|
||||||
return nil
|
return nil
|
||||||
case map[string]interface{}, []interface{}:
|
case map[string]interface{}, []interface{}, []*any, map[string]*any:
|
||||||
// Marshal complex types to JSON for SqlJSONB fields
|
// Marshal complex types to JSON for SqlJSONB fields
|
||||||
jsonBytes, err := json.Marshal(v)
|
jsonBytes, err := json.Marshal(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1077,6 +1095,17 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle slice-to-slice conversions (e.g., []interface{} to []*SomeModel)
|
||||||
|
if valueReflect.Kind() == reflect.Slice {
|
||||||
|
return convertSlice(field, valueReflect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can convert the type, do it
|
||||||
|
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||||
|
field.Set(valueReflect.Convert(field.Type()))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
|
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
|
||||||
@@ -1090,9 +1119,9 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
|||||||
// Call the Scan method with the value
|
// Call the Scan method with the value
|
||||||
results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)})
|
results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)})
|
||||||
if len(results) > 0 {
|
if len(results) > 0 {
|
||||||
// Check if there was an error
|
// The Scan method returns error - check if it's nil
|
||||||
if err, ok := results[0].Interface().(error); ok && err != nil {
|
if !results[0].IsNil() {
|
||||||
return err
|
return results[0].Interface().(error)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1147,13 +1176,93 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we can convert the type, do it
|
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
||||||
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
}
|
||||||
field.Set(valueReflect.Convert(field.Type()))
|
|
||||||
return nil
|
// convertSlice converts a source slice to a target slice type, handling element-wise conversions
|
||||||
|
// Supports converting []interface{} to slices of structs or pointers to structs
|
||||||
|
func convertSlice(targetSlice reflect.Value, sourceSlice reflect.Value) error {
|
||||||
|
if sourceSlice.Kind() != reflect.Slice || targetSlice.Kind() != reflect.Slice {
|
||||||
|
return fmt.Errorf("both source and target must be slices")
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
// Get the element type of the target slice
|
||||||
|
targetElemType := targetSlice.Type().Elem()
|
||||||
|
sourceLen := sourceSlice.Len()
|
||||||
|
|
||||||
|
// Create a new slice with the same length as the source
|
||||||
|
newSlice := reflect.MakeSlice(targetSlice.Type(), sourceLen, sourceLen)
|
||||||
|
|
||||||
|
// Convert each element
|
||||||
|
for i := 0; i < sourceLen; i++ {
|
||||||
|
sourceElem := sourceSlice.Index(i)
|
||||||
|
targetElem := newSlice.Index(i)
|
||||||
|
|
||||||
|
// Get the actual value from the source element
|
||||||
|
var sourceValue interface{}
|
||||||
|
if sourceElem.CanInterface() {
|
||||||
|
sourceValue = sourceElem.Interface()
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle nil elements
|
||||||
|
if sourceValue == nil {
|
||||||
|
// For pointer types, nil is valid
|
||||||
|
if targetElemType.Kind() == reflect.Ptr {
|
||||||
|
targetElem.Set(reflect.Zero(targetElemType))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If target element type is a pointer to struct, we need to create new instances
|
||||||
|
if targetElemType.Kind() == reflect.Ptr {
|
||||||
|
// Create a new instance of the pointed-to type
|
||||||
|
newElemPtr := reflect.New(targetElemType.Elem())
|
||||||
|
|
||||||
|
// Convert the source value to the struct
|
||||||
|
switch sv := sourceValue.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
// Source is a map, use MapToStruct to populate the new instance
|
||||||
|
if err := MapToStruct(sv, newElemPtr.Interface()); err != nil {
|
||||||
|
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Try direct conversion or setFieldValue
|
||||||
|
if err := setFieldValue(newElemPtr.Elem(), sourceValue); err != nil {
|
||||||
|
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
targetElem.Set(newElemPtr)
|
||||||
|
} else if targetElemType.Kind() == reflect.Struct {
|
||||||
|
// Target element is a struct (not a pointer)
|
||||||
|
switch sv := sourceValue.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
// Use MapToStruct to populate the element
|
||||||
|
elemPtr := targetElem.Addr()
|
||||||
|
if elemPtr.CanInterface() {
|
||||||
|
if err := MapToStruct(sv, elemPtr.Interface()); err != nil {
|
||||||
|
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Try direct conversion
|
||||||
|
if err := setFieldValue(targetElem, sourceValue); err != nil {
|
||||||
|
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For other types, use setFieldValue
|
||||||
|
if err := setFieldValue(targetElem, sourceValue); err != nil {
|
||||||
|
return fmt.Errorf("failed to convert element %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the converted slice to the target field
|
||||||
|
targetSlice.Set(newSlice)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertToInt64 attempts to convert various types to int64
|
// convertToInt64 attempts to convert various types to int64
|
||||||
@@ -1261,6 +1370,63 @@ func convertToFloat64(value interface{}) (float64, bool) {
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidJSONFieldNames returns a map of valid JSON field names for a model
|
||||||
|
// This can be used to validate input data against a model's structure
|
||||||
|
// The map keys are the JSON field names (from json tags) that exist in the model
|
||||||
|
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
||||||
|
validFields := make(map[string]bool)
|
||||||
|
|
||||||
|
// Unwrap pointers to get to the base struct type
|
||||||
|
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return validFields
|
||||||
|
}
|
||||||
|
|
||||||
|
collectValidFieldNames(modelType, validFields)
|
||||||
|
return validFields
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectValidFieldNames recursively collects valid JSON field names from a struct type
|
||||||
|
func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Skip unexported fields
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for embedded structs
|
||||||
|
if field.Anonymous {
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
// Recursively add fields from embedded struct
|
||||||
|
collectValidFieldNames(fieldType, validFields)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the JSON tag name for this field (same logic as MapToStruct)
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
// Extract the field name from the JSON tag (before any options like omitempty)
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
validFields[parts[0]] = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If no JSON tag, use the field name in lowercase as a fallback
|
||||||
|
validFields[strings.ToLower(field.Name)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||||
|
|||||||
120
pkg/reflection/model_utils_stdlib_sqltypes_test.go
Normal file
120
pkg/reflection/model_utils_stdlib_sqltypes_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package reflection_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMapToStruct_StandardSqlNullTypes(t *testing.T) {
|
||||||
|
// Test model with standard library sql.Null* types
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Age sql.NullInt64 `bun:"age" json:"age"`
|
||||||
|
Name sql.NullString `bun:"name" json:"name"`
|
||||||
|
Score sql.NullFloat64 `bun:"score" json:"score"`
|
||||||
|
Active sql.NullBool `bun:"active" json:"active"`
|
||||||
|
UpdatedAt sql.NullTime `bun:"updated_at" json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
dataMap := map[string]any{
|
||||||
|
"id": int64(100),
|
||||||
|
"age": int64(25),
|
||||||
|
"name": "John Doe",
|
||||||
|
"score": 95.5,
|
||||||
|
"active": true,
|
||||||
|
"updated_at": now,
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify ID
|
||||||
|
if result.ID != 100 {
|
||||||
|
t.Errorf("ID = %v, want 100", result.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify Age (sql.NullInt64)
|
||||||
|
if !result.Age.Valid {
|
||||||
|
t.Error("Age.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if result.Age.Int64 != 25 {
|
||||||
|
t.Errorf("Age.Int64 = %v, want 25", result.Age.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify Name (sql.NullString)
|
||||||
|
if !result.Name.Valid {
|
||||||
|
t.Error("Name.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if result.Name.String != "John Doe" {
|
||||||
|
t.Errorf("Name.String = %v, want 'John Doe'", result.Name.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify Score (sql.NullFloat64)
|
||||||
|
if !result.Score.Valid {
|
||||||
|
t.Error("Score.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if result.Score.Float64 != 95.5 {
|
||||||
|
t.Errorf("Score.Float64 = %v, want 95.5", result.Score.Float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify Active (sql.NullBool)
|
||||||
|
if !result.Active.Valid {
|
||||||
|
t.Error("Active.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.Active.Bool {
|
||||||
|
t.Error("Active.Bool = false, want true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify UpdatedAt (sql.NullTime)
|
||||||
|
if !result.UpdatedAt.Valid {
|
||||||
|
t.Error("UpdatedAt.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.UpdatedAt.Time.Equal(now) {
|
||||||
|
t.Errorf("UpdatedAt.Time = %v, want %v", result.UpdatedAt.Time, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("All standard library sql.Null* types handled correctly!")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_StandardSqlNullTypes_WithNil(t *testing.T) {
|
||||||
|
// Test nil handling for standard library sql.Null* types
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Age sql.NullInt64 `bun:"age" json:"age"`
|
||||||
|
Name sql.NullString `bun:"name" json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dataMap := map[string]any{
|
||||||
|
"id": int64(200),
|
||||||
|
"age": int64(30),
|
||||||
|
"name": nil, // Explicitly nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Age should be valid
|
||||||
|
if !result.Age.Valid {
|
||||||
|
t.Error("Age.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if result.Age.Int64 != 30 {
|
||||||
|
t.Errorf("Age.Int64 = %v, want 30", result.Age.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name should be invalid (null)
|
||||||
|
if result.Name.Valid {
|
||||||
|
t.Error("Name.Valid = true, want false (null)")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Nil handling for sql.Null* types works correctly!")
|
||||||
|
}
|
||||||
364
pkg/reflection/spectypes_integration_test.go
Normal file
364
pkg/reflection/spectypes_integration_test.go
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
package reflection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestModel contains all spectypes custom types
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Name spectypes.SqlString `bun:"name" json:"name"`
|
||||||
|
Age spectypes.SqlInt64 `bun:"age" json:"age"`
|
||||||
|
Score spectypes.SqlFloat64 `bun:"score" json:"score"`
|
||||||
|
Active spectypes.SqlBool `bun:"active" json:"active"`
|
||||||
|
UUID spectypes.SqlUUID `bun:"uuid" json:"uuid"`
|
||||||
|
CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"`
|
||||||
|
BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"`
|
||||||
|
StartTime spectypes.SqlTime `bun:"start_time" json:"start_time"`
|
||||||
|
Metadata spectypes.SqlJSONB `bun:"metadata" json:"metadata"`
|
||||||
|
Count16 spectypes.SqlInt16 `bun:"count16" json:"count16"`
|
||||||
|
Count32 spectypes.SqlInt32 `bun:"count32" json:"count32"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMapToStruct_AllSpectypes verifies that MapToStruct can convert all spectypes correctly
|
||||||
|
func TestMapToStruct_AllSpectypes(t *testing.T) {
|
||||||
|
testUUID := uuid.New()
|
||||||
|
testTime := time.Now()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataMap map[string]interface{}
|
||||||
|
validator func(*testing.T, *TestModel)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SqlString from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.Name.Valid || m.Name.String() != "John Doe" {
|
||||||
|
t.Errorf("expected name='John Doe', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlInt64 from int64",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"age": int64(42),
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.Age.Valid || m.Age.Int64() != 42 {
|
||||||
|
t.Errorf("expected age=42, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlInt64 from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"age": "99",
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.Age.Valid || m.Age.Int64() != 99 {
|
||||||
|
t.Errorf("expected age=99, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlFloat64 from float64",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"score": float64(98.5),
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.Score.Valid || m.Score.Float64() != 98.5 {
|
||||||
|
t.Errorf("expected score=98.5, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlBool from bool",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"active": true,
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.Active.Valid || !m.Active.Bool() {
|
||||||
|
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlUUID from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"uuid": testUUID.String(),
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.UUID.Valid || m.UUID.UUID() != testUUID {
|
||||||
|
t.Errorf("expected uuid=%s, got valid=%v, value=%s", testUUID.String(), m.UUID.Valid, m.UUID.UUID().String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlTimeStamp from time.Time",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"created_at": testTime,
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.CreatedAt.Valid {
|
||||||
|
t.Errorf("expected created_at to be valid")
|
||||||
|
}
|
||||||
|
// Check if times are close enough (within a second)
|
||||||
|
diff := m.CreatedAt.Time().Sub(testTime)
|
||||||
|
if diff < -time.Second || diff > time.Second {
|
||||||
|
t.Errorf("time difference too large: %v", diff)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlTimeStamp from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"created_at": "2024-01-15T10:30:00",
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.CreatedAt.Valid {
|
||||||
|
t.Errorf("expected created_at to be valid")
|
||||||
|
}
|
||||||
|
expected := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||||
|
if m.CreatedAt.Time().Year() != expected.Year() ||
|
||||||
|
m.CreatedAt.Time().Month() != expected.Month() ||
|
||||||
|
m.CreatedAt.Time().Day() != expected.Day() {
|
||||||
|
t.Errorf("expected date 2024-01-15, got %v", m.CreatedAt.Time())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlDate from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"birth_date": "2000-05-20",
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.BirthDate.Valid {
|
||||||
|
t.Errorf("expected birth_date to be valid")
|
||||||
|
}
|
||||||
|
expected := "2000-05-20"
|
||||||
|
if m.BirthDate.String() != expected {
|
||||||
|
t.Errorf("expected date=%s, got %s", expected, m.BirthDate.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlTime from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"start_time": "14:30:00",
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.StartTime.Valid {
|
||||||
|
t.Errorf("expected start_time to be valid")
|
||||||
|
}
|
||||||
|
if m.StartTime.String() != "14:30:00" {
|
||||||
|
t.Errorf("expected time=14:30:00, got %s", m.StartTime.String())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlJSONB from map",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"metadata": map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": 123,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if len(m.Metadata) == 0 {
|
||||||
|
t.Errorf("expected metadata to have data")
|
||||||
|
}
|
||||||
|
asMap, err := m.Metadata.AsMap()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to convert metadata to map: %v", err)
|
||||||
|
}
|
||||||
|
if asMap["key1"] != "value1" {
|
||||||
|
t.Errorf("expected key1=value1, got %v", asMap["key1"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlJSONB from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"metadata": `{"test":"data"}`,
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if len(m.Metadata) == 0 {
|
||||||
|
t.Errorf("expected metadata to have data")
|
||||||
|
}
|
||||||
|
asMap, err := m.Metadata.AsMap()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to convert metadata to map: %v", err)
|
||||||
|
}
|
||||||
|
if asMap["test"] != "data" {
|
||||||
|
t.Errorf("expected test=data, got %v", asMap["test"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlJSONB from []byte",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"metadata": []byte(`{"byte":"array"}`),
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if len(m.Metadata) == 0 {
|
||||||
|
t.Errorf("expected metadata to have data")
|
||||||
|
}
|
||||||
|
if string(m.Metadata) != `{"byte":"array"}` {
|
||||||
|
t.Errorf("expected {\"byte\":\"array\"}, got %s", string(m.Metadata))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlInt16 from int16",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"count16": int16(100),
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.Count16.Valid || m.Count16.Int64() != 100 {
|
||||||
|
t.Errorf("expected count16=100, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SqlInt32 from int32",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"count32": int32(5000),
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if !m.Count32.Valid || m.Count32.Int64() != 5000 {
|
||||||
|
t.Errorf("expected count32=5000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil values create invalid nulls",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"name": nil,
|
||||||
|
"age": nil,
|
||||||
|
"active": nil,
|
||||||
|
"created_at": nil,
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if m.Name.Valid {
|
||||||
|
t.Error("expected name to be invalid for nil value")
|
||||||
|
}
|
||||||
|
if m.Age.Valid {
|
||||||
|
t.Error("expected age to be invalid for nil value")
|
||||||
|
}
|
||||||
|
if m.Active.Valid {
|
||||||
|
t.Error("expected active to be invalid for nil value")
|
||||||
|
}
|
||||||
|
if m.CreatedAt.Valid {
|
||||||
|
t.Error("expected created_at to be invalid for nil value")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all types together",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(1),
|
||||||
|
"name": "Test User",
|
||||||
|
"age": int64(30),
|
||||||
|
"score": float64(95.7),
|
||||||
|
"active": true,
|
||||||
|
"uuid": testUUID.String(),
|
||||||
|
"created_at": "2024-01-15T10:30:00",
|
||||||
|
"birth_date": "1994-06-15",
|
||||||
|
"start_time": "09:00:00",
|
||||||
|
"metadata": map[string]interface{}{"role": "admin"},
|
||||||
|
"count16": int16(50),
|
||||||
|
"count32": int32(1000),
|
||||||
|
},
|
||||||
|
validator: func(t *testing.T, m *TestModel) {
|
||||||
|
if m.ID != 1 {
|
||||||
|
t.Errorf("expected id=1, got %d", m.ID)
|
||||||
|
}
|
||||||
|
if !m.Name.Valid || m.Name.String() != "Test User" {
|
||||||
|
t.Errorf("expected name='Test User', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
|
||||||
|
}
|
||||||
|
if !m.Age.Valid || m.Age.Int64() != 30 {
|
||||||
|
t.Errorf("expected age=30, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||||
|
}
|
||||||
|
if !m.Score.Valid || m.Score.Float64() != 95.7 {
|
||||||
|
t.Errorf("expected score=95.7, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
|
||||||
|
}
|
||||||
|
if !m.Active.Valid || !m.Active.Bool() {
|
||||||
|
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
|
||||||
|
}
|
||||||
|
if !m.UUID.Valid {
|
||||||
|
t.Error("expected uuid to be valid")
|
||||||
|
}
|
||||||
|
if !m.CreatedAt.Valid {
|
||||||
|
t.Error("expected created_at to be valid")
|
||||||
|
}
|
||||||
|
if !m.BirthDate.Valid || m.BirthDate.String() != "1994-06-15" {
|
||||||
|
t.Errorf("expected birth_date=1994-06-15, got valid=%v, value=%s", m.BirthDate.Valid, m.BirthDate.String())
|
||||||
|
}
|
||||||
|
if !m.StartTime.Valid || m.StartTime.String() != "09:00:00" {
|
||||||
|
t.Errorf("expected start_time=09:00:00, got valid=%v, value=%s", m.StartTime.Valid, m.StartTime.String())
|
||||||
|
}
|
||||||
|
if len(m.Metadata) == 0 {
|
||||||
|
t.Error("expected metadata to have data")
|
||||||
|
}
|
||||||
|
if !m.Count16.Valid || m.Count16.Int64() != 50 {
|
||||||
|
t.Errorf("expected count16=50, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
|
||||||
|
}
|
||||||
|
if !m.Count32.Valid || m.Count32.Int64() != 1000 {
|
||||||
|
t.Errorf("expected count32=1000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
model := &TestModel{}
|
||||||
|
if err := MapToStruct(tt.dataMap, model); err != nil {
|
||||||
|
t.Fatalf("MapToStruct failed: %v", err)
|
||||||
|
}
|
||||||
|
tt.validator(t, model)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMapToStruct_PartialUpdate tests that partial updates preserve unset fields
|
||||||
|
func TestMapToStruct_PartialUpdate(t *testing.T) {
|
||||||
|
// Create initial model with some values
|
||||||
|
initial := &TestModel{
|
||||||
|
ID: 1,
|
||||||
|
Name: spectypes.NewSqlString("Original Name"),
|
||||||
|
Age: spectypes.NewSqlInt64(25),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update only the age field
|
||||||
|
partialUpdate := map[string]interface{}{
|
||||||
|
"age": int64(30),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply partial update
|
||||||
|
if err := MapToStruct(partialUpdate, initial); err != nil {
|
||||||
|
t.Fatalf("MapToStruct failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify age was updated
|
||||||
|
if !initial.Age.Valid || initial.Age.Int64() != 30 {
|
||||||
|
t.Errorf("expected age=30, got valid=%v, value=%d", initial.Age.Valid, initial.Age.Int64())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify name was preserved (not overwritten with zero value)
|
||||||
|
if !initial.Name.Valid || initial.Name.String() != "Original Name" {
|
||||||
|
t.Errorf("expected name='Original Name' to be preserved, got valid=%v, value=%s", initial.Name.Valid, initial.Name.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify ID was preserved
|
||||||
|
if initial.ID != 1 {
|
||||||
|
t.Errorf("expected id=1 to be preserved, got %d", initial.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
703
pkg/resolvespec/README.md
Normal file
703
pkg/resolvespec/README.md
Normal file
@@ -0,0 +1,703 @@
|
|||||||
|
# ResolveSpec - Body-Based REST API
|
||||||
|
|
||||||
|
ResolveSpec provides a REST API where query options are passed in the JSON request body. This approach offers GraphQL-like flexibility while maintaining RESTful principles, making it ideal for complex queries and operations.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
* **Body-Based Querying**: All query options passed via JSON request body
|
||||||
|
* **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
|
||||||
|
* **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
|
||||||
|
* **Offset Pagination**: Traditional limit/offset pagination support
|
||||||
|
* **Advanced Filtering**: Multiple operators, AND/OR logic, and custom SQL
|
||||||
|
* **Relationship Preloading**: Load related entities with custom column selection and filters
|
||||||
|
* **Recursive CRUD**: Automatically handle nested object graphs with foreign key resolution
|
||||||
|
* **Computed Columns**: Define virtual columns with SQL expressions
|
||||||
|
* **Database-Agnostic**: Works with GORM, Bun, or custom database adapters
|
||||||
|
* **Router-Agnostic**: Integrates with any HTTP router through standard interfaces
|
||||||
|
* **Type-Safe**: Strong type validation and conversion
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Setup with GORM
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||||
|
import "github.com/gorilla/mux"
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := resolvespec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// IMPORTANT: Register models BEFORE setting up routes
|
||||||
|
handler.registry.RegisterModel("core.users", &User{})
|
||||||
|
handler.registry.RegisterModel("core.posts", &Post{})
|
||||||
|
|
||||||
|
// Setup routes
|
||||||
|
router := mux.NewRouter()
|
||||||
|
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Setup with Bun ORM
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||||
|
import "github.com/uptrace/bun"
|
||||||
|
|
||||||
|
// Create handler with Bun
|
||||||
|
handler := resolvespec.NewHandlerWithBun(bunDB)
|
||||||
|
|
||||||
|
// Register models
|
||||||
|
handler.registry.RegisterModel("core.users", &User{})
|
||||||
|
|
||||||
|
// Setup routes (same as GORM)
|
||||||
|
router := mux.NewRouter()
|
||||||
|
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple Read Request
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /core/users HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### With Preloading
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /core/users HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Request Structure
|
||||||
|
|
||||||
|
### Request Format
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read|create|update|delete",
|
||||||
|
"data": {
|
||||||
|
// For create/update operations
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"columns": [...],
|
||||||
|
"preload": [...],
|
||||||
|
"filters": [...],
|
||||||
|
"sort": [...],
|
||||||
|
"limit": number,
|
||||||
|
"offset": number,
|
||||||
|
"cursor_forward": "string",
|
||||||
|
"cursor_backward": "string",
|
||||||
|
"customOperators": [...],
|
||||||
|
"computedColumns": [...]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Operations
|
||||||
|
|
||||||
|
| Operation | Description | Requires Data | Requires ID |
|
||||||
|
|-----------|-------------|---------------|-------------|
|
||||||
|
| `read` | Fetch records | No | Optional (single record) |
|
||||||
|
| `create` | Create new record(s) | Yes | No |
|
||||||
|
| `update` | Update existing record(s) | Yes | Yes (in URL) |
|
||||||
|
| `delete` | Delete record(s) | No | Yes (in URL) |
|
||||||
|
|
||||||
|
### Options Fields
|
||||||
|
|
||||||
|
| Field | Type | Description | Example |
|
||||||
|
|-------|------|-------------|---------|
|
||||||
|
| `columns` | `[]string` | Columns to select | `["id", "name", "email"]` |
|
||||||
|
| `preload` | `[]PreloadConfig` | Relations to load | See [Preloading](#preloading) |
|
||||||
|
| `filters` | `[]Filter` | Filter conditions | See [Filtering](#filtering) |
|
||||||
|
| `sort` | `[]Sort` | Sort criteria | `[{"column": "created_at", "direction": "desc"}]` |
|
||||||
|
| `limit` | `int` | Max records to return | `50` |
|
||||||
|
| `offset` | `int` | Number of records to skip | `100` |
|
||||||
|
| `cursor_forward` | `string` | Cursor for next page | `"12345"` |
|
||||||
|
| `cursor_backward` | `string` | Cursor for previous page | `"12300"` |
|
||||||
|
| `customOperators` | `[]CustomOperator` | Custom SQL conditions | See [Custom Operators](#custom-operators) |
|
||||||
|
| `computedColumns` | `[]ComputedColumn` | Virtual columns | See [Computed Columns](#computed-columns) |
|
||||||
|
|
||||||
|
## Filtering
|
||||||
|
|
||||||
|
### Available Operators
|
||||||
|
|
||||||
|
| Operator | Description | Example |
|
||||||
|
|----------|-------------|---------|
|
||||||
|
| `eq` | Equal | `{"column": "status", "operator": "eq", "value": "active"}` |
|
||||||
|
| `neq` | Not Equal | `{"column": "status", "operator": "neq", "value": "deleted"}` |
|
||||||
|
| `gt` | Greater Than | `{"column": "age", "operator": "gt", "value": 18}` |
|
||||||
|
| `gte` | Greater Than or Equal | `{"column": "age", "operator": "gte", "value": 18}` |
|
||||||
|
| `lt` | Less Than | `{"column": "price", "operator": "lt", "value": 100}` |
|
||||||
|
| `lte` | Less Than or Equal | `{"column": "price", "operator": "lte", "value": 100}` |
|
||||||
|
| `like` | LIKE pattern | `{"column": "name", "operator": "like", "value": "%john%"}` |
|
||||||
|
| `ilike` | Case-insensitive LIKE | `{"column": "email", "operator": "ilike", "value": "%@example.com"}` |
|
||||||
|
| `in` | IN clause | `{"column": "status", "operator": "in", "value": ["active", "pending"]}` |
|
||||||
|
| `contains` | Contains string | `{"column": "description", "operator": "contains", "value": "important"}` |
|
||||||
|
| `startswith` | Starts with string | `{"column": "name", "operator": "startswith", "value": "John"}` |
|
||||||
|
| `endswith` | Ends with string | `{"column": "email", "operator": "endswith", "value": "@example.com"}` |
|
||||||
|
| `between` | Between (exclusive) | `{"column": "age", "operator": "between", "value": [18, 65]}` |
|
||||||
|
| `betweeninclusive` | Between (inclusive) | `{"column": "price", "operator": "betweeninclusive", "value": [10, 100]}` |
|
||||||
|
| `empty` | IS NULL or empty | `{"column": "deleted_at", "operator": "empty"}` |
|
||||||
|
| `notempty` | IS NOT NULL | `{"column": "email", "operator": "notempty"}` |
|
||||||
|
|
||||||
|
### Complex Filtering Example
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "age",
|
||||||
|
"operator": "gte",
|
||||||
|
"value": 18
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "email",
|
||||||
|
"operator": "ilike",
|
||||||
|
"value": "%@company.com"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Preloading
|
||||||
|
|
||||||
|
Load related entities with custom configuration:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"columns": ["id", "name", "email"],
|
||||||
|
"preload": [
|
||||||
|
{
|
||||||
|
"relation": "posts",
|
||||||
|
"columns": ["id", "title", "created_at"],
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "published"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "created_at",
|
||||||
|
"direction": "desc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"limit": 5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"relation": "profile",
|
||||||
|
"columns": ["bio", "website"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Cursor Pagination
|
||||||
|
|
||||||
|
Efficient pagination for large datasets:
|
||||||
|
|
||||||
|
### First Request (No Cursor)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "created_at",
|
||||||
|
"direction": "desc"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "id",
|
||||||
|
"direction": "asc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"limit": 50
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Next Page (Forward Cursor)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "created_at",
|
||||||
|
"direction": "desc"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "id",
|
||||||
|
"direction": "asc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"limit": 50,
|
||||||
|
"cursor_forward": "12345"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Previous Page (Backward Cursor)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "created_at",
|
||||||
|
"direction": "desc"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "id",
|
||||||
|
"direction": "asc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"limit": 50,
|
||||||
|
"cursor_backward": "12300"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits over offset pagination**:
|
||||||
|
* Consistent results when data changes
|
||||||
|
* Better performance for large offsets
|
||||||
|
* Prevents "skipped" or duplicate records
|
||||||
|
* Works with complex sort expressions
|
||||||
|
|
||||||
|
## Recursive CRUD Operations
|
||||||
|
|
||||||
|
Automatically handle nested object graphs with intelligent foreign key resolution.
|
||||||
|
|
||||||
|
### Creating Nested Objects
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "create",
|
||||||
|
"data": {
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
"posts": [
|
||||||
|
{
|
||||||
|
"title": "My First Post",
|
||||||
|
"content": "Hello World",
|
||||||
|
"tags": [
|
||||||
|
{"name": "tech"},
|
||||||
|
{"name": "programming"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Second Post",
|
||||||
|
"content": "More content"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"profile": {
|
||||||
|
"bio": "Software Developer",
|
||||||
|
"website": "https://example.com"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Per-Record Operation Control with `_request`
|
||||||
|
|
||||||
|
Control individual operations for each nested record:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "update",
|
||||||
|
"data": {
|
||||||
|
"name": "John Updated",
|
||||||
|
"posts": [
|
||||||
|
{
|
||||||
|
"_request": "insert",
|
||||||
|
"title": "New Post",
|
||||||
|
"content": "Fresh content"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_request": "update",
|
||||||
|
"id": 456,
|
||||||
|
"title": "Updated Post Title"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_request": "delete",
|
||||||
|
"id": 789
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported `_request` values**:
|
||||||
|
* `insert` - Create a new related record
|
||||||
|
* `update` - Update an existing related record
|
||||||
|
* `delete` - Delete a related record
|
||||||
|
* `upsert` - Create if doesn't exist, update if exists
|
||||||
|
|
||||||
|
**How It Works**:
|
||||||
|
1. Automatic foreign key resolution - parent IDs propagate to children
|
||||||
|
2. Recursive processing - handles nested relationships at any depth
|
||||||
|
3. Transaction safety - all operations execute atomically
|
||||||
|
4. Relationship detection - automatically detects belongsTo, hasMany, hasOne, many2many
|
||||||
|
5. Flexible operations - mix create, update, and delete in one request
|
||||||
|
|
||||||
|
## Computed Columns
|
||||||
|
|
||||||
|
Define virtual columns using SQL expressions:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"columns": ["id", "first_name", "last_name"],
|
||||||
|
"computedColumns": [
|
||||||
|
{
|
||||||
|
"name": "full_name",
|
||||||
|
"expression": "CONCAT(first_name, ' ', last_name)"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "age_years",
|
||||||
|
"expression": "EXTRACT(YEAR FROM AGE(birth_date))"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom Operators
|
||||||
|
|
||||||
|
Add custom SQL conditions when needed:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"condition": "LOWER(email) LIKE ?",
|
||||||
|
"values": ["%@example.com"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"condition": "created_at > NOW() - INTERVAL '7 days'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Lifecycle Hooks
|
||||||
|
|
||||||
|
Register hooks for all CRUD operations:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := resolvespec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Register a before-read hook (e.g., for authorization)
|
||||||
|
handler.Hooks().Register(resolvespec.BeforeRead, func(ctx *resolvespec.HookContext) error {
|
||||||
|
// Check permissions
|
||||||
|
if !userHasPermission(ctx.Context, ctx.Entity) {
|
||||||
|
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify query options
|
||||||
|
if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
|
||||||
|
ctx.Options.Limit = ptr(100) // Enforce max limit
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register an after-read hook (e.g., for data transformation)
|
||||||
|
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
|
||||||
|
// Transform or filter results
|
||||||
|
if users, ok := ctx.Result.([]User); ok {
|
||||||
|
for i := range users {
|
||||||
|
users[i].Email = maskEmail(users[i].Email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a before-create hook (e.g., for validation)
|
||||||
|
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
|
||||||
|
// Validate data
|
||||||
|
if user, ok := ctx.Data.(*User); ok {
|
||||||
|
if user.Email == "" {
|
||||||
|
return fmt.Errorf("email is required")
|
||||||
|
}
|
||||||
|
// Add timestamps
|
||||||
|
user.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
**Available Hook Types**:
|
||||||
|
* `BeforeRead`, `AfterRead`
|
||||||
|
* `BeforeCreate`, `AfterCreate`
|
||||||
|
* `BeforeUpdate`, `AfterUpdate`
|
||||||
|
* `BeforeDelete`, `AfterDelete`
|
||||||
|
|
||||||
|
**HookContext** provides:
|
||||||
|
* `Context`: Request context
|
||||||
|
* `Handler`: Access to handler, database, and registry
|
||||||
|
* `Schema`, `Entity`, `TableName`: Request info
|
||||||
|
* `Model`: The registered model type
|
||||||
|
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||||
|
* `ID`: Record ID (for single-record operations)
|
||||||
|
* `Data`: Request data (for create/update)
|
||||||
|
* `Result`: Operation result (for after hooks)
|
||||||
|
* `Writer`: Response writer (allows hooks to modify response)
|
||||||
|
|
||||||
|
## Model Registration
|
||||||
|
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
ID uint `json:"id" gorm:"primaryKey"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
|
||||||
|
Profile *Profile `json:"profile,omitempty" gorm:"foreignKey:UserID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
ID uint `json:"id" gorm:"primaryKey"`
|
||||||
|
UserID uint `json:"user_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema.Table format
|
||||||
|
handler.registry.RegisterModel("core.users", &User{})
|
||||||
|
handler.registry.RegisterModel("core.posts", &Post{})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID uint `json:"id" gorm:"primaryKey"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
ID uint `json:"id" gorm:"primaryKey"`
|
||||||
|
UserID uint `json:"user_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Connect to database
|
||||||
|
db, err := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := resolvespec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Register models
|
||||||
|
handler.registry.RegisterModel("core.users", &User{})
|
||||||
|
handler.registry.RegisterModel("core.posts", &Post{})
|
||||||
|
|
||||||
|
// Add hooks
|
||||||
|
handler.Hooks().Register(resolvespec.BeforeRead, func(ctx *resolvespec.HookContext) error {
|
||||||
|
log.Printf("Reading %s", ctx.Entity)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup routes
|
||||||
|
router := mux.NewRouter()
|
||||||
|
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
log.Println("Server starting on :8080")
|
||||||
|
log.Fatal(http.ListenAndServe(":8080", router))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
ResolveSpec is designed for testability:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserRead(t *testing.T) {
|
||||||
|
handler := resolvespec.NewHandlerWithGORM(testDB)
|
||||||
|
handler.registry.RegisterModel("core.users", &User{})
|
||||||
|
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"operation": "read",
|
||||||
|
"options": map[string]interface{}{
|
||||||
|
"columns": []string{"id", "name"},
|
||||||
|
"limit": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := json.Marshal(reqBody)
|
||||||
|
req := httptest.NewRequest("POST", "/core/users", bytes.NewReader(body))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Test your handler...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Router Integration
|
||||||
|
|
||||||
|
### Gorilla Mux
|
||||||
|
|
||||||
|
```go
|
||||||
|
router := mux.NewRouter()
|
||||||
|
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
```
|
||||||
|
|
||||||
|
### BunRouter
|
||||||
|
|
||||||
|
```go
|
||||||
|
router := bunrouter.New()
|
||||||
|
resolvespec.SetupBunRouterWithResolveSpec(router, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Routers
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Implement custom integration using common.Request and common.ResponseWriter
|
||||||
|
router.POST("/:schema/:entity", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
params := extractParams(r) // Your param extraction logic
|
||||||
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Response Format
|
||||||
|
|
||||||
|
### Success Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [...],
|
||||||
|
"metadata": {
|
||||||
|
"total": 100,
|
||||||
|
"filtered": 50,
|
||||||
|
"limit": 10,
|
||||||
|
"offset": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Error Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": false,
|
||||||
|
"error": {
|
||||||
|
"code": "validation_error",
|
||||||
|
"message": "Invalid request",
|
||||||
|
"details": "..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## See Also
|
||||||
|
|
||||||
|
* [Main README](../../README.md) - ResolveSpec overview
|
||||||
|
* [RestHeadSpec Package](../restheadspec/README.md) - Header-based API
|
||||||
|
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This package is part of ResolveSpec and is licensed under the MIT License.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Response Format
|
||||||
|
|
||||||
|
### Success Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [...],
|
||||||
|
"metadata": {
|
||||||
|
"total": 100,
|
||||||
|
"filtered": 50,
|
||||||
|
"limit": 10,
|
||||||
|
"offset": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Error Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": false,
|
||||||
|
"error": {
|
||||||
|
"code": "validation_error",
|
||||||
|
"message": "Invalid request",
|
||||||
|
"details": "..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## See Also
|
||||||
|
|
||||||
|
* [Main README](../../README.md) - ResolveSpec overview
|
||||||
|
* [RestHeadSpec Package](../restheadspec/README.md) - Header-based API
|
||||||
|
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This package is part of ResolveSpec and is licensed under the MIT License.
|
||||||
@@ -318,6 +318,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if cursorFilter != "" {
|
if cursorFilter != "" {
|
||||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||||
|
// Ensure outer parentheses to prevent OR logic from escaping
|
||||||
|
sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor)
|
||||||
if sanitizedCursor != "" {
|
if sanitizedCursor != "" {
|
||||||
query = query.Where(sanitizedCursor)
|
query = query.Where(sanitizedCursor)
|
||||||
}
|
}
|
||||||
@@ -698,37 +700,133 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard processing without nested relations
|
// Standard processing without nested relations
|
||||||
query := h.db.NewUpdate().Table(tableName).SetMap(updates)
|
// Get the primary key name
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
// Apply conditions
|
// Wrap in transaction to ensure BeforeUpdate hook is inside transaction
|
||||||
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
// First, read the existing record from the database
|
||||||
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
|
||||||
|
|
||||||
|
// Apply conditions to select
|
||||||
if urlID != "" {
|
if urlID != "" {
|
||||||
logger.Debug("Updating by URL ID: %s", urlID)
|
logger.Debug("Updating by URL ID: %s", urlID)
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), urlID)
|
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||||
} else if reqID != nil {
|
} else if reqID != nil {
|
||||||
switch id := reqID.(type) {
|
switch id := reqID.(type) {
|
||||||
case string:
|
case string:
|
||||||
logger.Debug("Updating by request ID: %s", id)
|
logger.Debug("Updating by request ID: %s", id)
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
case []string:
|
case []string:
|
||||||
|
if len(id) > 0 {
|
||||||
logger.Debug("Updating by multiple IDs: %v", id)
|
logger.Debug("Updating by multiple IDs: %v", id)
|
||||||
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("no records found to update")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error fetching existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert existing record to map
|
||||||
|
existingMap := make(map[string]interface{})
|
||||||
|
jsonData, err := json.Marshal(existingRecord)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||||
|
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeUpdate hooks inside transaction
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
ID: urlID,
|
||||||
|
Data: updates,
|
||||||
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data from hook context
|
||||||
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||||
|
updates = modifiedData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge only non-null and non-empty values from the incoming request into the existing record
|
||||||
|
for key, newValue := range updates {
|
||||||
|
// Skip if the value is nil
|
||||||
|
if newValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if the value is an empty string
|
||||||
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the existing map with the new value
|
||||||
|
existingMap[key] = newValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build update query with merged data
|
||||||
|
query := tx.NewUpdate().Table(tableName).SetMap(existingMap)
|
||||||
|
|
||||||
|
// Apply conditions
|
||||||
|
if urlID != "" {
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||||
|
} else if reqID != nil {
|
||||||
|
switch id := reqID.(type) {
|
||||||
|
case string:
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
case []string:
|
||||||
|
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Update error: %v", err)
|
return fmt.Errorf("error updating record(s): %w", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.RowsAffected() == 0 {
|
if result.RowsAffected() == 0 {
|
||||||
logger.Warn("No records found to update")
|
return fmt.Errorf("no records found to update")
|
||||||
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil)
|
}
|
||||||
|
|
||||||
|
// Execute AfterUpdate hooks inside transaction
|
||||||
|
hookCtx.Result = updates
|
||||||
|
hookCtx.Error = nil
|
||||||
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("AfterUpdate hook failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Update error: %v", err)
|
||||||
|
if err.Error() == "no records found to update" {
|
||||||
|
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", err)
|
||||||
|
} else {
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated %d records", result.RowsAffected())
|
logger.Info("Successfully updated record(s)")
|
||||||
// Invalidate cache for this table
|
// Invalidate cache for this table
|
||||||
cacheTags := buildCacheTags(schema, tableName)
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
@@ -782,14 +880,77 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard batch update without nested relations
|
// Standard batch update without nested relations
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range updates {
|
for _, item := range updates {
|
||||||
if itemID, ok := item["id"]; ok {
|
if itemID, ok := item["id"]; ok {
|
||||||
|
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||||
|
|
||||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
// First, read the existing record
|
||||||
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
continue // Skip if record not found
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to fetch existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert existing record to map
|
||||||
|
existingMap := make(map[string]interface{})
|
||||||
|
jsonData, err := json.Marshal(existingRecord)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal existing record: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeUpdate hooks inside transaction
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
ID: itemIDStr,
|
||||||
|
Data: item,
|
||||||
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data from hook context
|
||||||
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||||
|
item = modifiedData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge only non-null and non-empty values
|
||||||
|
for key, newValue := range item {
|
||||||
|
if newValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingMap[key] = newValue
|
||||||
|
}
|
||||||
|
|
||||||
|
txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterUpdate hooks inside transaction
|
||||||
|
hookCtx.Result = item
|
||||||
|
hookCtx.Error = nil
|
||||||
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -857,16 +1018,80 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard batch update without nested relations
|
// Standard batch update without nested relations
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
list := make([]interface{}, 0)
|
list := make([]interface{}, 0)
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range updates {
|
for _, item := range updates {
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
if itemID, ok := itemMap["id"]; ok {
|
if itemID, ok := itemMap["id"]; ok {
|
||||||
|
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||||
|
|
||||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
// First, read the existing record
|
||||||
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
continue // Skip if record not found
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to fetch existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert existing record to map
|
||||||
|
existingMap := make(map[string]interface{})
|
||||||
|
jsonData, err := json.Marshal(existingRecord)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal existing record: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeUpdate hooks inside transaction
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
ID: itemIDStr,
|
||||||
|
Data: itemMap,
|
||||||
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data from hook context
|
||||||
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||||
|
itemMap = modifiedData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge only non-null and non-empty values
|
||||||
|
for key, newValue := range itemMap {
|
||||||
|
if newValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingMap[key] = newValue
|
||||||
|
}
|
||||||
|
|
||||||
|
txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterUpdate hooks inside transaction
|
||||||
|
hookCtx.Result = itemMap
|
||||||
|
hookCtx.Error = nil
|
||||||
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
||||||
|
}
|
||||||
|
|
||||||
list = append(list, item)
|
list = append(list, item)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1328,30 +1553,7 @@ func isNullable(field reflect.StructField) bool {
|
|||||||
|
|
||||||
// GetRelationshipInfo implements common.RelationshipInfoProvider interface
|
// GetRelationshipInfo implements common.RelationshipInfoProvider interface
|
||||||
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
||||||
info := h.getRelationshipInfo(modelType, relationName)
|
return common.GetRelationshipInfo(modelType, relationName)
|
||||||
if info == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Convert internal type to common type
|
|
||||||
return &common.RelationshipInfo{
|
|
||||||
FieldName: info.fieldName,
|
|
||||||
JSONName: info.jsonName,
|
|
||||||
RelationType: info.relationType,
|
|
||||||
ForeignKey: info.foreignKey,
|
|
||||||
References: info.references,
|
|
||||||
JoinTable: info.joinTable,
|
|
||||||
RelatedModel: info.relatedModel,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type relationshipInfo struct {
|
|
||||||
fieldName string
|
|
||||||
jsonName string
|
|
||||||
relationType string // "belongsTo", "hasMany", "hasOne", "many2many"
|
|
||||||
foreignKey string
|
|
||||||
references string
|
|
||||||
joinTable string
|
|
||||||
relatedModel interface{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||||
@@ -1371,7 +1573,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
for idx := range preloads {
|
for idx := range preloads {
|
||||||
preload := preloads[idx]
|
preload := preloads[idx]
|
||||||
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
||||||
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
relInfo := common.GetRelationshipInfo(modelType, preload.Relation)
|
||||||
if relInfo == nil {
|
if relInfo == nil {
|
||||||
logger.Warn("Relation %s not found in model", preload.Relation)
|
logger.Warn("Relation %s not found in model", preload.Relation)
|
||||||
continue
|
continue
|
||||||
@@ -1379,7 +1581,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
|
|
||||||
// Use the field name (capitalized) for ORM preloading
|
// Use the field name (capitalized) for ORM preloading
|
||||||
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
||||||
relationFieldName := relInfo.fieldName
|
relationFieldName := relInfo.FieldName
|
||||||
|
|
||||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
@@ -1422,13 +1624,13 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
copy(columns, preload.Columns)
|
copy(columns, preload.Columns)
|
||||||
|
|
||||||
// Add foreign key if not already present
|
// Add foreign key if not already present
|
||||||
if relInfo.foreignKey != "" {
|
if relInfo.ForeignKey != "" {
|
||||||
// Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id)
|
// Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id)
|
||||||
foreignKeyColumn := toSnakeCase(relInfo.foreignKey)
|
foreignKeyColumn := toSnakeCase(relInfo.ForeignKey)
|
||||||
|
|
||||||
hasForeignKey := false
|
hasForeignKey := false
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
if col == foreignKeyColumn || col == relInfo.foreignKey {
|
if col == foreignKeyColumn || col == relInfo.ForeignKey {
|
||||||
hasForeignKey = true
|
hasForeignKey = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -1456,6 +1658,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||||
preloadOpts := &common.RequestOptions{Preload: preloads}
|
preloadOpts := &common.RequestOptions{Preload: preloads}
|
||||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||||
|
// Ensure outer parentheses to prevent OR logic from escaping
|
||||||
|
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
|
||||||
if len(sanitizedWhere) > 0 {
|
if len(sanitizedWhere) > 0 {
|
||||||
sq = sq.Where(sanitizedWhere)
|
sq = sq.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -1474,58 +1678,6 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
return query, nil
|
return query, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
|
||||||
// Ensure we have a struct type
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
|
||||||
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
|
||||||
field := modelType.Field(i)
|
|
||||||
jsonTag := field.Tag.Get("json")
|
|
||||||
jsonName := strings.Split(jsonTag, ",")[0]
|
|
||||||
|
|
||||||
if jsonName == relationName {
|
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
info := &relationshipInfo{
|
|
||||||
fieldName: field.Name,
|
|
||||||
jsonName: jsonName,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse GORM tag to determine relationship type and keys
|
|
||||||
if strings.Contains(gormTag, "foreignKey") {
|
|
||||||
info.foreignKey = h.extractTagValue(gormTag, "foreignKey")
|
|
||||||
info.references = h.extractTagValue(gormTag, "references")
|
|
||||||
|
|
||||||
// Determine if it's belongsTo or hasMany/hasOne
|
|
||||||
if field.Type.Kind() == reflect.Slice {
|
|
||||||
info.relationType = "hasMany"
|
|
||||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
|
||||||
info.relationType = "belongsTo"
|
|
||||||
}
|
|
||||||
} else if strings.Contains(gormTag, "many2many") {
|
|
||||||
info.relationType = "many2many"
|
|
||||||
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) extractTagValue(tag, key string) string {
|
|
||||||
parts := strings.Split(tag, ";")
|
|
||||||
for _, part := range parts {
|
|
||||||
part = strings.TrimSpace(part)
|
|
||||||
if strings.HasPrefix(part, key+":") {
|
|
||||||
return strings.TrimPrefix(part, key+":")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// toSnakeCase converts a PascalCase or camelCase string to snake_case
|
// toSnakeCase converts a PascalCase or camelCase string to snake_case
|
||||||
func toSnakeCase(s string) string {
|
func toSnakeCase(s string) string {
|
||||||
var result strings.Builder
|
var result strings.Builder
|
||||||
|
|||||||
@@ -269,8 +269,6 @@ func TestToSnakeCase(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractTagValue(t *testing.T) {
|
func TestExtractTagValue(t *testing.T) {
|
||||||
handler := NewHandler(nil, nil)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tag string
|
tag string
|
||||||
@@ -311,9 +309,9 @@ func TestExtractTagValue(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := handler.extractTagValue(tt.tag, tt.key)
|
result := common.ExtractTagValue(tt.tag, tt.key)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("extractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected)
|
t.Errorf("ExtractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,8 +50,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
|||||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||||
})
|
})
|
||||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||||
@@ -98,7 +99,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
|||||||
// Set CORS headers
|
// Set CORS headers
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
vars := make(map[string]string)
|
vars := make(map[string]string)
|
||||||
vars["schema"] = schema
|
vars["schema"] = schema
|
||||||
@@ -106,7 +108,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
|||||||
if idParam != "" {
|
if idParam != "" {
|
||||||
vars["id"] = mux.Vars(r)[idParam]
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, vars)
|
handler.Handle(respAdapter, reqAdapter, vars)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -117,7 +119,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
|||||||
// Set CORS headers
|
// Set CORS headers
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
vars := make(map[string]string)
|
vars := make(map[string]string)
|
||||||
vars["schema"] = schema
|
vars["schema"] = schema
|
||||||
@@ -125,7 +128,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
|||||||
if idParam != "" {
|
if idParam != "" {
|
||||||
vars["id"] = mux.Vars(r)[idParam]
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,13 +140,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
|
|||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
corsConfig.AllowedMethods = allowedMethods
|
corsConfig.AllowedMethods = allowedMethods
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
// Return metadata in the OPTIONS response body
|
// Return metadata in the OPTIONS response body
|
||||||
vars := make(map[string]string)
|
vars := make(map[string]string)
|
||||||
vars["schema"] = schema
|
vars["schema"] = schema
|
||||||
vars["entity"] = entity
|
vars["entity"] = entity
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -207,9 +211,14 @@ func ExampleWithBun(bunDB *bun.DB) {
|
|||||||
SetupMuxRoutes(muxRouter, handler, nil)
|
SetupMuxRoutes(muxRouter, handler, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BunRouterHandler is an interface that both bunrouter.Router and bunrouter.Group implement
|
||||||
|
type BunRouterHandler interface {
|
||||||
|
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||||
|
}
|
||||||
|
|
||||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||||
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
// Accepts bunrouter.Router or bunrouter.Group
|
||||||
r := bunRouter.GetBunRouter()
|
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||||
|
|
||||||
// CORS config
|
// CORS config
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
@@ -217,15 +226,16 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// Add global /openapi route
|
// Add global /openapi route
|
||||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -248,12 +258,13 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// POST route without ID
|
// POST route without ID
|
||||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -261,13 +272,14 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// POST route with ID
|
// POST route with ID
|
||||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
"id": req.Param("id"),
|
"id": req.Param("id"),
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -275,12 +287,13 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// GET route without ID
|
// GET route without ID
|
||||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -288,13 +301,14 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// GET route with ID
|
// GET route with ID
|
||||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
"id": req.Param("id"),
|
"id": req.Param("id"),
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -302,14 +316,15 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// OPTIONS route without ID (returns metadata)
|
// OPTIONS route without ID (returns metadata)
|
||||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
optionsCorsConfig := corsConfig
|
optionsCorsConfig := corsConfig
|
||||||
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -317,14 +332,15 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// OPTIONS route with ID (returns metadata)
|
// OPTIONS route with ID (returns metadata)
|
||||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
optionsCorsConfig := corsConfig
|
optionsCorsConfig := corsConfig
|
||||||
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
|
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
|
||||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -337,13 +353,13 @@ func ExampleWithBunRouter(bunDB *bun.DB) {
|
|||||||
handler := NewHandlerWithBun(bunDB)
|
handler := NewHandlerWithBun(bunDB)
|
||||||
|
|
||||||
// Create bunrouter
|
// Create bunrouter
|
||||||
bunRouter := router.NewStandardBunRouterAdapter()
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
// Setup ResolveSpec routes with bunrouter
|
// Setup ResolveSpec routes with bunrouter
|
||||||
SetupBunRouterRoutes(bunRouter, handler)
|
SetupBunRouterRoutes(bunRouter, handler)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
// http.ListenAndServe(":8080", bunRouter.GetBunRouter())
|
// http.ListenAndServe(":8080", bunRouter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExampleBunRouterWithBunDB shows the full uptrace stack (bunrouter + Bun ORM)
|
// ExampleBunRouterWithBunDB shows the full uptrace stack (bunrouter + Bun ORM)
|
||||||
@@ -359,11 +375,29 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
|||||||
handler := NewHandler(dbAdapter, registry)
|
handler := NewHandler(dbAdapter, registry)
|
||||||
|
|
||||||
// Create bunrouter
|
// Create bunrouter
|
||||||
bunRouter := router.NewStandardBunRouterAdapter()
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
// Setup ResolveSpec routes
|
// Setup ResolveSpec routes
|
||||||
SetupBunRouterRoutes(bunRouter, handler)
|
SetupBunRouterRoutes(bunRouter, handler)
|
||||||
|
|
||||||
// This gives you the full uptrace stack: bunrouter + Bun ORM
|
// This gives you the full uptrace stack: bunrouter + Bun ORM
|
||||||
// http.ListenAndServe(":8080", bunRouter.GetBunRouter())
|
// http.ListenAndServe(":8080", bunRouter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleBunRouterWithGroup shows how to use SetupBunRouterRoutes with a bunrouter.Group
|
||||||
|
func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||||
|
// Create handler with Bun adapter
|
||||||
|
handler := NewHandlerWithBun(bunDB)
|
||||||
|
|
||||||
|
// Create bunrouter
|
||||||
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
|
// Create a route group with a prefix
|
||||||
|
apiGroup := bunRouter.NewGroup("/api")
|
||||||
|
|
||||||
|
// Setup ResolveSpec routes on the group - routes will be under /api
|
||||||
|
SetupBunRouterRoutes(apiGroup, handler)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
// http.ListenAndServe(":8080", bunRouter)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -214,14 +214,46 @@ x-expand: department:id,name,code
|
|||||||
**Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation.
|
**Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation.
|
||||||
|
|
||||||
#### `x-custom-sql-join`
|
#### `x-custom-sql-join`
|
||||||
Raw SQL JOIN statement.
|
Custom SQL JOIN clauses for joining tables in queries.
|
||||||
|
|
||||||
**Format:** SQL JOIN clause
|
**Format:** SQL JOIN clause or multiple clauses separated by `|`
|
||||||
|
|
||||||
|
**Single JOIN:**
|
||||||
```
|
```
|
||||||
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
|
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
|
||||||
```
|
```
|
||||||
|
|
||||||
⚠️ **Note:** Not yet fully implemented.
|
**Multiple JOINs:**
|
||||||
|
```
|
||||||
|
x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id
|
||||||
|
```
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Supports any type of JOIN (INNER, LEFT, RIGHT, FULL, CROSS)
|
||||||
|
- Multiple JOINs can be specified using the pipe `|` separator
|
||||||
|
- JOINs are sanitized for security
|
||||||
|
- Can be specified via headers or query parameters
|
||||||
|
- **Table aliases are automatically extracted and allowed for filtering and sorting**
|
||||||
|
|
||||||
|
**Using Join Aliases in Filters and Sorts:**
|
||||||
|
|
||||||
|
When you specify a custom SQL join with an alias, you can use that alias in your filter and sort parameters:
|
||||||
|
|
||||||
|
```
|
||||||
|
# Join with alias
|
||||||
|
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
|
||||||
|
|
||||||
|
# Sort by joined table column
|
||||||
|
x-sort: d.name,employees.id
|
||||||
|
|
||||||
|
# Filter by joined table column
|
||||||
|
x-searchop-eq-d.name: Engineering
|
||||||
|
```
|
||||||
|
|
||||||
|
The system automatically:
|
||||||
|
1. Extracts the alias from the JOIN clause (e.g., `d` from `departments d`)
|
||||||
|
2. Validates that prefixed columns (like `d.name`) refer to valid join aliases
|
||||||
|
3. Allows these prefixed columns in filters and sorts
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
445
pkg/restheadspec/README.md
Normal file
445
pkg/restheadspec/README.md
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
# RestHeadSpec - Header-Based REST API
|
||||||
|
|
||||||
|
RestHeadSpec provides a REST API where all query options are passed via HTTP headers instead of the request body. This provides cleaner separation between data and metadata, making it ideal for GET requests and RESTful architectures.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
* **Header-Based Querying**: All query options via HTTP headers
|
||||||
|
* **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
|
||||||
|
* **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
|
||||||
|
* **Advanced Filtering**: Field filters, search operators, AND/OR logic
|
||||||
|
* **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
|
||||||
|
* **Single Record as Object**: Automatically return single-element arrays as objects (default)
|
||||||
|
* **Base64 Support**: Base64-encoded header values for complex queries
|
||||||
|
* **Type-Aware Filtering**: Automatic type detection and conversion
|
||||||
|
* **CORS Support**: Comprehensive CORS headers for cross-origin requests
|
||||||
|
* **OPTIONS Method**: Full OPTIONS support for CORS preflight
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Setup with GORM
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
import "github.com/gorilla/mux"
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// IMPORTANT: Register models BEFORE setting up routes
|
||||||
|
handler.Registry.RegisterModel("public.users", &User{})
|
||||||
|
handler.Registry.RegisterModel("public.posts", &Post{})
|
||||||
|
|
||||||
|
// Setup routes
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Setup with Bun ORM
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
import "github.com/uptrace/bun"
|
||||||
|
|
||||||
|
// Create handler with Bun
|
||||||
|
handler := restheadspec.NewHandlerWithBun(bunDB)
|
||||||
|
|
||||||
|
// Register models
|
||||||
|
handler.Registry.RegisterModel("public.users", &User{})
|
||||||
|
|
||||||
|
// Setup routes (same as GORM)
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple GET Request
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users HTTP/1.1
|
||||||
|
Host: api.example.com
|
||||||
|
X-Select-Fields: id,name,email
|
||||||
|
X-FieldFilter-Status: active
|
||||||
|
X-Sort: -created_at
|
||||||
|
X-Limit: 50
|
||||||
|
```
|
||||||
|
|
||||||
|
### With Preloading
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users HTTP/1.1
|
||||||
|
X-Select-Fields: id,name,email,department_id
|
||||||
|
X-Preload: department:id,name
|
||||||
|
X-FieldFilter-Status: active
|
||||||
|
X-Limit: 50
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Headers
|
||||||
|
|
||||||
|
| Header | Description | Example |
|
||||||
|
|--------|-------------|---------|
|
||||||
|
| `X-Select-Fields` | Columns to include | `id,name,email` |
|
||||||
|
| `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` |
|
||||||
|
| `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` |
|
||||||
|
| `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` |
|
||||||
|
| `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` |
|
||||||
|
| `X-Preload` | Preload relations | `posts:id,title` |
|
||||||
|
| `X-Sort` | Sort columns | `-created_at,+name` |
|
||||||
|
| `X-Limit` | Limit results | `50` |
|
||||||
|
| `X-Offset` | Offset for pagination | `100` |
|
||||||
|
| `X-Clean-JSON` | Remove null/empty fields | `true` |
|
||||||
|
| `X-Single-Record-As-Object` | Return single records as objects | `false` |
|
||||||
|
|
||||||
|
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
|
||||||
|
|
||||||
|
For complete header documentation, see [HEADERS.md](HEADERS.md).
|
||||||
|
|
||||||
|
## Lifecycle Hooks
|
||||||
|
|
||||||
|
RestHeadSpec supports lifecycle hooks for all CRUD operations:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Register a before-read hook (e.g., for authorization)
|
||||||
|
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
|
||||||
|
// Check permissions
|
||||||
|
if !userHasPermission(ctx.Context, ctx.Entity) {
|
||||||
|
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify query options
|
||||||
|
ctx.Options.Limit = ptr(100) // Enforce max limit
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register an after-read hook (e.g., for data transformation)
|
||||||
|
handler.Hooks.Register(restheadspec.AfterRead, func(ctx *restheadspec.HookContext) error {
|
||||||
|
// Transform or filter results
|
||||||
|
if users, ok := ctx.Result.([]User); ok {
|
||||||
|
for i := range users {
|
||||||
|
users[i].Email = maskEmail(users[i].Email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a before-create hook (e.g., for validation)
|
||||||
|
handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookContext) error {
|
||||||
|
// Validate data
|
||||||
|
if user, ok := ctx.Data.(*User); ok {
|
||||||
|
if user.Email == "" {
|
||||||
|
return fmt.Errorf("email is required")
|
||||||
|
}
|
||||||
|
// Add timestamps
|
||||||
|
user.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
**Available Hook Types**:
|
||||||
|
* `BeforeRead`, `AfterRead`
|
||||||
|
* `BeforeCreate`, `AfterCreate`
|
||||||
|
* `BeforeUpdate`, `AfterUpdate`
|
||||||
|
* `BeforeDelete`, `AfterDelete`
|
||||||
|
|
||||||
|
**HookContext** provides:
|
||||||
|
* `Context`: Request context
|
||||||
|
* `Handler`: Access to handler, database, and registry
|
||||||
|
* `Schema`, `Entity`, `TableName`: Request info
|
||||||
|
* `Model`: The registered model type
|
||||||
|
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||||
|
* `ID`: Record ID (for single-record operations)
|
||||||
|
* `Data`: Request data (for create/update)
|
||||||
|
* `Result`: Operation result (for after hooks)
|
||||||
|
* `Writer`: Response writer (allows hooks to modify response)
|
||||||
|
|
||||||
|
## Cursor Pagination
|
||||||
|
|
||||||
|
RestHeadSpec supports efficient cursor-based pagination for large datasets:
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/posts HTTP/1.1
|
||||||
|
X-Sort: -created_at,+id
|
||||||
|
X-Limit: 50
|
||||||
|
X-Cursor-Forward: <cursor_token>
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it works**:
|
||||||
|
1. First request returns results + cursor token in response
|
||||||
|
2. Subsequent requests use `X-Cursor-Forward` or `X-Cursor-Backward`
|
||||||
|
3. Cursor maintains consistent ordering even with data changes
|
||||||
|
4. Supports complex multi-column sorting
|
||||||
|
|
||||||
|
**Benefits over offset pagination**:
|
||||||
|
* Consistent results when data changes
|
||||||
|
* Better performance for large offsets
|
||||||
|
* Prevents "skipped" or duplicate records
|
||||||
|
* Works with complex sort expressions
|
||||||
|
|
||||||
|
**Example with hooks**:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Enable cursor pagination in a hook
|
||||||
|
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
|
||||||
|
// For large tables, enforce cursor pagination
|
||||||
|
if ctx.Entity == "posts" && ctx.Options.Offset != nil && *ctx.Options.Offset > 1000 {
|
||||||
|
return fmt.Errorf("use cursor pagination for large offsets")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Response Formats
|
||||||
|
|
||||||
|
RestHeadSpec supports multiple response formats:
|
||||||
|
|
||||||
|
**1. Simple Format** (`X-SimpleApi: true`):
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{ "id": 1, "name": "John" },
|
||||||
|
{ "id": 2, "name": "Jane" }
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Detail Format** (`X-DetailApi: true`, default):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [...],
|
||||||
|
"metadata": {
|
||||||
|
"total": 100,
|
||||||
|
"filtered": 100,
|
||||||
|
"limit": 50,
|
||||||
|
"offset": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**3. Syncfusion Format** (`X-Syncfusion: true`):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"result": [...],
|
||||||
|
"count": 100
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Single Record as Object (Default Behavior)
|
||||||
|
|
||||||
|
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses.
|
||||||
|
|
||||||
|
**Default behavior (enabled)**:
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users/123
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": { "id": 123, "name": "John", "email": "john@example.com" }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**To disable** (force arrays):
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users/123
|
||||||
|
X-Single-Record-As-Object: false
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it works**:
|
||||||
|
* When a query returns exactly **one record**, it's returned as an object
|
||||||
|
* When a query returns **multiple records**, they're returned as an array
|
||||||
|
* Set `X-Single-Record-As-Object: false` to always receive arrays
|
||||||
|
* Works with all response formats (simple, detail, syncfusion)
|
||||||
|
* Applies to both read operations and create/update returning clauses
|
||||||
|
|
||||||
|
## CORS & OPTIONS Support
|
||||||
|
|
||||||
|
RestHeadSpec includes comprehensive CORS support for cross-origin requests:
|
||||||
|
|
||||||
|
**OPTIONS Method**:
|
||||||
|
|
||||||
|
```http
|
||||||
|
OPTIONS /public/users HTTP/1.1
|
||||||
|
```
|
||||||
|
|
||||||
|
Returns metadata with appropriate CORS headers:
|
||||||
|
|
||||||
|
```http
|
||||||
|
Access-Control-Allow-Origin: *
|
||||||
|
Access-Control-Allow-Methods: GET, POST, OPTIONS
|
||||||
|
Access-Control-Allow-Headers: Content-Type, Authorization, X-Select-Fields, X-FieldFilter-*, ...
|
||||||
|
Access-Control-Max-Age: 86400
|
||||||
|
Access-Control-Allow-Credentials: true
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Features**:
|
||||||
|
* OPTIONS returns model metadata (same as GET metadata endpoint)
|
||||||
|
* All HTTP methods include CORS headers automatically
|
||||||
|
* OPTIONS requests don't require authentication (CORS preflight)
|
||||||
|
* Supports all HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.)
|
||||||
|
* 24-hour max age to reduce preflight requests
|
||||||
|
|
||||||
|
**Configuration**:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
|
||||||
|
// Get default CORS config
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
|
|
||||||
|
// Customize if needed
|
||||||
|
corsConfig.AllowedOrigins = []string{"https://example.com"}
|
||||||
|
corsConfig.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Features
|
||||||
|
|
||||||
|
### Base64 Encoding
|
||||||
|
|
||||||
|
For complex header values, use base64 encoding:
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users HTTP/1.1
|
||||||
|
X-Select-Fields-Base64: aWQsbmFtZSxlbWFpbA==
|
||||||
|
```
|
||||||
|
|
||||||
|
### AND/OR Logic
|
||||||
|
|
||||||
|
Combine multiple filters with AND/OR logic:
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users HTTP/1.1
|
||||||
|
X-FieldFilter-Status: active
|
||||||
|
X-SearchOp-Gte-Age: 18
|
||||||
|
X-Filter-Logic: AND
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complex Preloading
|
||||||
|
|
||||||
|
Load nested relationships:
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users HTTP/1.1
|
||||||
|
X-Preload: posts:id,title,comments:id,text,author:name
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Registration
|
||||||
|
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
ID uint `json:"id" gorm:"primaryKey"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema.Table format
|
||||||
|
handler.Registry.RegisterModel("public.users", &User{})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID uint `json:"id" gorm:"primaryKey"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Connect to database
|
||||||
|
db, err := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Register models
|
||||||
|
handler.Registry.RegisterModel("public.users", &User{})
|
||||||
|
|
||||||
|
// Add hooks
|
||||||
|
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
|
||||||
|
log.Printf("Reading %s", ctx.Entity)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup routes
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
log.Println("Server starting on :8080")
|
||||||
|
log.Fatal(http.ListenAndServe(":8080", router))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
RestHeadSpec is designed for testability:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserRead(t *testing.T) {
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(testDB)
|
||||||
|
handler.Registry.RegisterModel("public.users", &User{})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/public/users", nil)
|
||||||
|
req.Header.Set("X-Select-Fields", "id,name")
|
||||||
|
req.Header.Set("X-Limit", "10")
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
// Test your handler...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## See Also
|
||||||
|
|
||||||
|
* [HEADERS.md](HEADERS.md) - Complete header reference
|
||||||
|
* [Main README](../../README.md) - ResolveSpec overview
|
||||||
|
* [ResolveSpec Package](../resolvespec/README.md) - Body-based API
|
||||||
|
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This package is part of ResolveSpec and is licensed under the MIT License.
|
||||||
@@ -26,6 +26,7 @@ type queryCacheKey struct {
|
|||||||
Sort []common.SortOption `json:"sort"`
|
Sort []common.SortOption `json:"sort"`
|
||||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||||
|
CustomSQLJoin []string `json:"custom_sql_join,omitempty"`
|
||||||
Expand []expandOptionKey `json:"expand,omitempty"`
|
Expand []expandOptionKey `json:"expand,omitempty"`
|
||||||
Distinct bool `json:"distinct,omitempty"`
|
Distinct bool `json:"distinct,omitempty"`
|
||||||
CursorForward string `json:"cursor_forward,omitempty"`
|
CursorForward string `json:"cursor_forward,omitempty"`
|
||||||
@@ -40,7 +41,7 @@ type cachedTotal struct {
|
|||||||
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||||
// Includes expand, distinct, and cursor pagination options
|
// Includes expand, distinct, and cursor pagination options
|
||||||
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||||
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
customWhere, customOr string, customJoin []string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||||
|
|
||||||
key := queryCacheKey{
|
key := queryCacheKey{
|
||||||
TableName: tableName,
|
TableName: tableName,
|
||||||
@@ -48,6 +49,7 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
|||||||
Sort: sort,
|
Sort: sort,
|
||||||
CustomSQLWhere: customWhere,
|
CustomSQLWhere: customWhere,
|
||||||
CustomSQLOr: customOr,
|
CustomSQLOr: customOr,
|
||||||
|
CustomSQLJoin: customJoin,
|
||||||
Distinct: distinct,
|
Distinct: distinct,
|
||||||
CursorForward: cursorFwd,
|
CursorForward: cursorFwd,
|
||||||
CursorBackward: cursorBwd,
|
CursorBackward: cursorBwd,
|
||||||
@@ -75,8 +77,8 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
|||||||
jsonData, err := json.Marshal(key)
|
jsonData, err := json.Marshal(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Fallback to simple string concatenation if JSON fails
|
// Fallback to simple string concatenation if JSON fails
|
||||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
|
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%v_%s_%s",
|
||||||
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
|
tableName, filters, sort, customWhere, customOr, customJoin, expandOpts, distinct, cursorFwd, cursorBwd))
|
||||||
}
|
}
|
||||||
|
|
||||||
return hashString(string(jsonData))
|
return hashString(string(jsonData))
|
||||||
|
|||||||
193
pkg/restheadspec/empty_result_test.go
Normal file
193
pkg/restheadspec/empty_result_test.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test that normalizeResultArray returns empty array when no records found without ID
|
||||||
|
func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
shouldBeEmptyArr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil should return empty array",
|
||||||
|
input: nil,
|
||||||
|
shouldBeEmptyArr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty slice should return empty array",
|
||||||
|
input: []*EmptyTestModel{},
|
||||||
|
shouldBeEmptyArr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single element should return the element",
|
||||||
|
input: []*EmptyTestModel{{ID: 1, Name: "test"}},
|
||||||
|
shouldBeEmptyArr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple elements should return the slice",
|
||||||
|
input: []*EmptyTestModel{
|
||||||
|
{ID: 1, Name: "test1"},
|
||||||
|
{ID: 2, Name: "test2"},
|
||||||
|
},
|
||||||
|
shouldBeEmptyArr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.normalizeResultArray(tt.input)
|
||||||
|
|
||||||
|
// For cases that should return empty array
|
||||||
|
if tt.shouldBeEmptyArr {
|
||||||
|
emptyArr, ok := result.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Expected empty array []interface{}{}, got %T: %v", result, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(emptyArr) != 0 {
|
||||||
|
t.Errorf("Expected empty array with length 0, got length %d", len(emptyArr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it serializes to [] and not null
|
||||||
|
jsonBytes, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to marshal result: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if string(jsonBytes) != "[]" {
|
||||||
|
t.Errorf("Expected JSON '[]', got '%s'", string(jsonBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that sendFormattedResponse adds X-No-Data-Found header
|
||||||
|
func TestSendFormattedResponse_NoDataFoundHeader(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Mock ResponseWriter
|
||||||
|
mockWriter := &MockTestResponseWriter{
|
||||||
|
headers: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := &common.Metadata{
|
||||||
|
Total: 0,
|
||||||
|
Count: 0,
|
||||||
|
Filtered: 0,
|
||||||
|
Limit: 10,
|
||||||
|
Offset: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
options := ExtendedRequestOptions{
|
||||||
|
RequestOptions: common.RequestOptions{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with empty data
|
||||||
|
emptyData := []interface{}{}
|
||||||
|
handler.sendFormattedResponse(mockWriter, emptyData, metadata, options)
|
||||||
|
|
||||||
|
// Check if X-No-Data-Found header was set
|
||||||
|
if mockWriter.headers["X-No-Data-Found"] != "true" {
|
||||||
|
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the body is an empty array
|
||||||
|
if mockWriter.body == nil {
|
||||||
|
t.Error("Expected body to be set, got nil")
|
||||||
|
} else {
|
||||||
|
bodyBytes, err := json.Marshal(mockWriter.body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to marshal body: %v", err)
|
||||||
|
}
|
||||||
|
// The body should be wrapped in a Response object with "data" field
|
||||||
|
bodyStr := string(bodyBytes)
|
||||||
|
if !strings.Contains(bodyStr, `"data":[]`) && !strings.Contains(bodyStr, `"result":[]`) {
|
||||||
|
t.Errorf("Expected body to contain empty array, got: %s", bodyStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that sendResponseWithOptions adds X-No-Data-Found header
|
||||||
|
func TestSendResponseWithOptions_NoDataFoundHeader(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Mock ResponseWriter
|
||||||
|
mockWriter := &MockTestResponseWriter{
|
||||||
|
headers: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := &common.Metadata{}
|
||||||
|
options := &ExtendedRequestOptions{}
|
||||||
|
|
||||||
|
// Test with nil data
|
||||||
|
handler.sendResponseWithOptions(mockWriter, nil, metadata, options)
|
||||||
|
|
||||||
|
// Check if X-No-Data-Found header was set
|
||||||
|
if mockWriter.headers["X-No-Data-Found"] != "true" {
|
||||||
|
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check status code is 200
|
||||||
|
if mockWriter.statusCode != 200 {
|
||||||
|
t.Errorf("Expected status code 200, got %d", mockWriter.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the body is an empty array
|
||||||
|
if mockWriter.body == nil {
|
||||||
|
t.Error("Expected body to be set, got nil")
|
||||||
|
} else {
|
||||||
|
bodyBytes, err := json.Marshal(mockWriter.body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to marshal body: %v", err)
|
||||||
|
}
|
||||||
|
bodyStr := string(bodyBytes)
|
||||||
|
if bodyStr != "[]" {
|
||||||
|
t.Errorf("Expected body to be '[]', got: %s", bodyStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockTestResponseWriter for testing
|
||||||
|
type MockTestResponseWriter struct {
|
||||||
|
headers map[string]string
|
||||||
|
statusCode int
|
||||||
|
body interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockTestResponseWriter) SetHeader(key, value string) {
|
||||||
|
m.headers[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockTestResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
m.statusCode = statusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockTestResponseWriter) Write(data []byte) (int, error) {
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockTestResponseWriter) WriteJSON(data interface{}) error {
|
||||||
|
m.body = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockTestResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmptyTestModel for testing
|
||||||
|
type EmptyTestModel struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
@@ -435,9 +435,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply preloading
|
// Apply preloading
|
||||||
|
logger.Debug("Total preloads to apply: %d", len(options.Preload))
|
||||||
for idx := range options.Preload {
|
for idx := range options.Preload {
|
||||||
preload := options.Preload[idx]
|
preload := options.Preload[idx]
|
||||||
logger.Debug("Applying preload: %s", preload.Relation)
|
logger.Debug("Applying preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, Where=%s",
|
||||||
|
idx, preload.Relation, preload.Recursive, preload.RelatedKey, preload.Where)
|
||||||
|
|
||||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
@@ -463,7 +465,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply filters - validate and adjust for column types first
|
// Apply filters - validate and adjust for column types first
|
||||||
for i := range options.Filters {
|
// Group consecutive OR filters together to prevent OR logic from escaping
|
||||||
|
for i := 0; i < len(options.Filters); {
|
||||||
filter := &options.Filters[i]
|
filter := &options.Filters[i]
|
||||||
|
|
||||||
// Validate and adjust filter based on column type
|
// Validate and adjust filter based on column type
|
||||||
@@ -475,8 +478,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
logicOp = "AND"
|
logicOp = "AND"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if this is the start of an OR group
|
||||||
|
if logicOp == "OR" {
|
||||||
|
// Collect all consecutive OR filters
|
||||||
|
orFilters := []*common.FilterOption{filter}
|
||||||
|
orCastInfo := []ColumnCastInfo{castInfo}
|
||||||
|
|
||||||
|
j := i + 1
|
||||||
|
for j < len(options.Filters) {
|
||||||
|
nextFilter := &options.Filters[j]
|
||||||
|
nextLogicOp := nextFilter.LogicOperator
|
||||||
|
if nextLogicOp == "" {
|
||||||
|
nextLogicOp = "AND"
|
||||||
|
}
|
||||||
|
if nextLogicOp == "OR" {
|
||||||
|
nextCastInfo := h.ValidateAndAdjustFilterForColumnType(nextFilter, model)
|
||||||
|
orFilters = append(orFilters, nextFilter)
|
||||||
|
orCastInfo = append(orCastInfo, nextCastInfo)
|
||||||
|
j++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the OR group as a single grouped condition
|
||||||
|
logger.Debug("Applying OR filter group with %d conditions", len(orFilters))
|
||||||
|
query = h.applyOrFilterGroup(query, orFilters, orCastInfo, tableName)
|
||||||
|
i = j
|
||||||
|
} else {
|
||||||
|
// Single AND filter - apply normally
|
||||||
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
|
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
|
||||||
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
|
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
|
||||||
|
i++
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply custom SQL WHERE clause (AND condition)
|
// Apply custom SQL WHERE clause (AND condition)
|
||||||
@@ -486,6 +520,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||||
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
|
// Ensure outer parentheses to prevent OR logic from escaping
|
||||||
|
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
|
||||||
if sanitizedWhere != "" {
|
if sanitizedWhere != "" {
|
||||||
query = query.Where(sanitizedWhere)
|
query = query.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -497,11 +533,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
|
// Ensure outer parentheses to prevent OR logic from escaping
|
||||||
|
sanitizedOr = common.EnsureOuterParentheses(sanitizedOr)
|
||||||
if sanitizedOr != "" {
|
if sanitizedOr != "" {
|
||||||
query = query.WhereOr(sanitizedOr)
|
query = query.WhereOr(sanitizedOr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply custom SQL JOIN clauses
|
||||||
|
if len(options.CustomSQLJoin) > 0 {
|
||||||
|
for _, joinClause := range options.CustomSQLJoin {
|
||||||
|
logger.Debug("Applying custom SQL JOIN: %s", joinClause)
|
||||||
|
// Joins are already sanitized during parsing, so we can apply them directly
|
||||||
|
query = query.Join(joinClause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If ID is provided, filter by ID
|
// If ID is provided, filter by ID
|
||||||
if id != "" {
|
if id != "" {
|
||||||
pkName := reflection.GetPrimaryKeyName(model)
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
@@ -552,6 +599,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
options.Sort,
|
options.Sort,
|
||||||
options.CustomSQLWhere,
|
options.CustomSQLWhere,
|
||||||
options.CustomSQLOr,
|
options.CustomSQLOr,
|
||||||
|
options.CustomSQLJoin,
|
||||||
expandOpts,
|
expandOpts,
|
||||||
options.Distinct,
|
options.Distinct,
|
||||||
options.CursorForward,
|
options.CursorForward,
|
||||||
@@ -766,7 +814,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
// Apply ComputedQL fields if any
|
// Apply ComputedQL fields if any
|
||||||
if len(preload.ComputedQL) > 0 {
|
if len(preload.ComputedQL) > 0 {
|
||||||
// Get the base table name from the related model
|
// Get the base table name from the related model
|
||||||
baseTableName := getTableNameFromModel(relatedModel)
|
baseTableName := common.GetTableNameFromModel(relatedModel)
|
||||||
|
|
||||||
// Convert the preload relation path to the appropriate alias format
|
// Convert the preload relation path to the appropriate alias format
|
||||||
// This is ORM-specific. Currently we only support Bun's format.
|
// This is ORM-specific. Currently we only support Bun's format.
|
||||||
@@ -777,7 +825,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB())
|
underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB())
|
||||||
if strings.Contains(underlyingType, "bun.DB") {
|
if strings.Contains(underlyingType, "bun.DB") {
|
||||||
// Use Bun's alias format: lowercase with double underscores
|
// Use Bun's alias format: lowercase with double underscores
|
||||||
preloadAlias = relationPathToBunAlias(preload.Relation)
|
preloadAlias = common.RelationPathToBunAlias(preload.Relation)
|
||||||
}
|
}
|
||||||
// For GORM: GORM doesn't use the same alias format, and this fix
|
// For GORM: GORM doesn't use the same alias format, and this fix
|
||||||
// may not be needed since GORM handles preloads differently
|
// may not be needed since GORM handles preloads differently
|
||||||
@@ -792,7 +840,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
// levels of recursive/nested preloads
|
// levels of recursive/nested preloads
|
||||||
adjustedExpr := colExpr
|
adjustedExpr := colExpr
|
||||||
if baseTableName != "" && preloadAlias != "" {
|
if baseTableName != "" && preloadAlias != "" {
|
||||||
adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
adjustedExpr = common.ReplaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
||||||
if adjustedExpr != colExpr {
|
if adjustedExpr != colExpr {
|
||||||
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
|
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
|
||||||
colName, colExpr, adjustedExpr)
|
colName, colExpr, adjustedExpr)
|
||||||
@@ -836,6 +884,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply custom SQL joins from XFiles
|
||||||
|
if len(preload.SqlJoins) > 0 {
|
||||||
|
logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation)
|
||||||
|
for _, joinClause := range preload.SqlJoins {
|
||||||
|
sq = sq.Join(joinClause)
|
||||||
|
logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply filters
|
// Apply filters
|
||||||
if len(preload.Filters) > 0 {
|
if len(preload.Filters) > 0 {
|
||||||
for _, filter := range preload.Filters {
|
for _, filter := range preload.Filters {
|
||||||
@@ -861,10 +918,25 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||||
preloadOpts := &common.RequestOptions{Preload: allPreloads}
|
preloadOpts := &common.RequestOptions{Preload: allPreloads}
|
||||||
// First add table prefixes to unqualified columns
|
|
||||||
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
// Determine the table name to use for WHERE clause processing
|
||||||
// Then sanitize and allow preload table prefixes
|
// Prefer the explicit TableName field (set by XFiles), otherwise extract from relation name
|
||||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
tableName := preload.TableName
|
||||||
|
if tableName == "" {
|
||||||
|
tableName = reflection.ExtractTableNameOnly(preload.Relation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// In Bun's Relation context, table prefixes are only needed when there are JOINs
|
||||||
|
// Without JOINs, Bun already knows which table is being queried
|
||||||
|
whereClause := preload.Where
|
||||||
|
if len(preload.SqlJoins) > 0 {
|
||||||
|
// Has JOINs: add table prefixes to disambiguate columns
|
||||||
|
whereClause = common.AddTablePrefixToColumns(preload.Where, tableName)
|
||||||
|
logger.Debug("Added table prefix for preload with joins: '%s' -> '%s'", preload.Where, whereClause)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the WHERE clause and allow preload table prefixes
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(whereClause, tableName, preloadOpts)
|
||||||
if len(sanitizedWhere) > 0 {
|
if len(sanitizedWhere) > 0 {
|
||||||
sq = sq.Where(sanitizedWhere)
|
sq = sq.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -883,93 +955,87 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Handle recursive preloading
|
// Handle recursive preloading
|
||||||
if preload.Recursive && depth < 5 {
|
if preload.Recursive && depth < 8 {
|
||||||
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
|
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
|
||||||
|
|
||||||
// For recursive relationships, we need to get the last part of the relation path
|
|
||||||
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
|
|
||||||
relationParts := strings.Split(preload.Relation, ".")
|
relationParts := strings.Split(preload.Relation, ".")
|
||||||
lastRelationName := relationParts[len(relationParts)-1]
|
lastRelationName := relationParts[len(relationParts)-1]
|
||||||
|
|
||||||
// Create a recursive preload with the same configuration
|
// Generate FK-based relation name for children
|
||||||
// but with the relation path extended
|
// Use RecursiveChildKey if available, otherwise fall back to RelatedKey
|
||||||
recursivePreload := preload
|
recursiveFK := preload.RecursiveChildKey
|
||||||
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
if recursiveFK == "" {
|
||||||
|
recursiveFK = preload.RelatedKey
|
||||||
|
}
|
||||||
|
|
||||||
// Recursively apply preload until we reach depth 5
|
recursiveRelationName := lastRelationName
|
||||||
|
if recursiveFK != "" {
|
||||||
|
// Check if the last relation name already contains the FK suffix
|
||||||
|
// (this happens when XFiles already generated the FK-based name)
|
||||||
|
fkUpper := strings.ToUpper(recursiveFK)
|
||||||
|
expectedSuffix := "_" + fkUpper
|
||||||
|
|
||||||
|
if strings.HasSuffix(lastRelationName, expectedSuffix) {
|
||||||
|
// Already has FK suffix, just reuse the same name
|
||||||
|
recursiveRelationName = lastRelationName
|
||||||
|
logger.Debug("Reusing FK-based relation name for recursion: %s", recursiveRelationName)
|
||||||
|
} else {
|
||||||
|
// Generate FK-based name
|
||||||
|
recursiveRelationName = lastRelationName + expectedSuffix
|
||||||
|
keySource := "RelatedKey"
|
||||||
|
if preload.RecursiveChildKey != "" {
|
||||||
|
keySource = "RecursiveChildKey"
|
||||||
|
}
|
||||||
|
logger.Debug("Generated recursive relation name from %s: %s (from %s)",
|
||||||
|
keySource, recursiveRelationName, recursiveFK)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Warn("Recursive preload for %s has no RecursiveChildKey or RelatedKey, falling back to %s.%s",
|
||||||
|
preload.Relation, preload.Relation, lastRelationName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create recursive preload
|
||||||
|
recursivePreload := preload
|
||||||
|
recursivePreload.Relation = preload.Relation + "." + recursiveRelationName
|
||||||
|
recursivePreload.Recursive = false // Prevent infinite recursion at this level
|
||||||
|
|
||||||
|
// Use the recursive FK for child relations, not the parent's RelatedKey
|
||||||
|
if preload.RecursiveChildKey != "" {
|
||||||
|
recursivePreload.RelatedKey = preload.RecursiveChildKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal
|
||||||
|
recursivePreload.Where = ""
|
||||||
|
recursivePreload.Filters = []common.FilterOption{}
|
||||||
|
logger.Debug("Cleared WHERE clause for recursive preload %s at depth %d",
|
||||||
|
recursivePreload.Relation, depth+1)
|
||||||
|
|
||||||
|
// Apply recursively up to depth 8
|
||||||
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
|
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
|
||||||
|
|
||||||
|
// ALSO: Extend any child relations (like DEF) to recursive levels
|
||||||
|
baseRelation := preload.Relation + "."
|
||||||
|
for i := range allPreloads {
|
||||||
|
relatedPreload := allPreloads[i]
|
||||||
|
if strings.HasPrefix(relatedPreload.Relation, baseRelation) &&
|
||||||
|
!strings.Contains(strings.TrimPrefix(relatedPreload.Relation, baseRelation), ".") {
|
||||||
|
childRelationName := strings.TrimPrefix(relatedPreload.Relation, baseRelation)
|
||||||
|
|
||||||
|
extendedChildPreload := relatedPreload
|
||||||
|
extendedChildPreload.Relation = recursivePreload.Relation + "." + childRelationName
|
||||||
|
extendedChildPreload.Recursive = false
|
||||||
|
|
||||||
|
logger.Debug("Extending related preload '%s' to '%s' at recursive depth %d",
|
||||||
|
relatedPreload.Relation, extendedChildPreload.Relation, depth+1)
|
||||||
|
|
||||||
|
query = h.applyPreloadWithRecursion(query, extendedChildPreload, allPreloads, model, depth+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def"
|
|
||||||
// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores
|
|
||||||
func relationPathToBunAlias(relationPath string) string {
|
|
||||||
if relationPath == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
// Convert to lowercase and replace dots with double underscores
|
|
||||||
alias := strings.ToLower(relationPath)
|
|
||||||
alias = strings.ReplaceAll(alias, ".", "__")
|
|
||||||
return alias
|
|
||||||
}
|
|
||||||
|
|
||||||
// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression
|
|
||||||
// with the appropriate alias for the current preload level
|
|
||||||
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
|
|
||||||
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
|
|
||||||
func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
|
|
||||||
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
|
|
||||||
return sqlExpr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace both quoted and unquoted table references
|
|
||||||
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
|
|
||||||
|
|
||||||
// Pattern 1: tablename.column (unquoted)
|
|
||||||
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
|
|
||||||
|
|
||||||
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
|
|
||||||
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTableNameFromModel extracts the table name from a model
|
|
||||||
// It checks the bun tag first, then falls back to converting the struct name to snake_case
|
|
||||||
func getTableNameFromModel(model interface{}) string {
|
|
||||||
if model == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
|
||||||
|
|
||||||
// Unwrap pointers
|
|
||||||
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
|
||||||
modelType = modelType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look for bun tag on embedded BaseModel
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
|
||||||
field := modelType.Field(i)
|
|
||||||
if field.Anonymous {
|
|
||||||
bunTag := field.Tag.Get("bun")
|
|
||||||
if strings.HasPrefix(bunTag, "table:") {
|
|
||||||
return strings.TrimPrefix(bunTag, "table:")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: convert struct name to lowercase (simple heuristic)
|
|
||||||
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
|
||||||
return strings.ToLower(modelType.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1177,30 +1243,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
logger.Info("Updating record in %s.%s", schema, entity)
|
logger.Info("Updating record in %s.%s", schema, entity)
|
||||||
|
|
||||||
// Execute BeforeUpdate hooks
|
|
||||||
hookCtx := &HookContext{
|
|
||||||
Context: ctx,
|
|
||||||
Handler: h,
|
|
||||||
Schema: schema,
|
|
||||||
Entity: entity,
|
|
||||||
TableName: tableName,
|
|
||||||
Tx: h.db,
|
|
||||||
Model: model,
|
|
||||||
Options: options,
|
|
||||||
ID: id,
|
|
||||||
Data: data,
|
|
||||||
Writer: w,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
|
||||||
logger.Error("BeforeUpdate hook failed: %v", err)
|
|
||||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use potentially modified data from hook context
|
|
||||||
data = hookCtx.Data
|
|
||||||
|
|
||||||
// Convert data to map
|
// Convert data to map
|
||||||
dataMap, ok := data.(map[string]interface{})
|
dataMap, ok := data.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1234,11 +1276,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Variable to store the updated record
|
// Variable to store the updated record
|
||||||
var updatedRecord interface{}
|
var updatedRecord interface{}
|
||||||
|
|
||||||
|
// Declare hook context to be used inside and outside transaction
|
||||||
|
var hookCtx *HookContext
|
||||||
|
|
||||||
// Process nested relations if present
|
// Process nested relations if present
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
// Create temporary nested processor with transaction
|
// Create temporary nested processor with transaction
|
||||||
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
|
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
|
||||||
|
|
||||||
|
// First, read the existing record from the database
|
||||||
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("record not found with ID: %v", targetID)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to fetch existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert existing record to map
|
||||||
|
existingMap := make(map[string]interface{})
|
||||||
|
jsonData, err := json.Marshal(existingRecord)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal existing record: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Extract nested relations if present (but don't process them yet)
|
// Extract nested relations if present (but don't process them yet)
|
||||||
var nestedRelations map[string]interface{}
|
var nestedRelations map[string]interface{}
|
||||||
if h.shouldUseNestedProcessor(dataMap, model) {
|
if h.shouldUseNestedProcessor(dataMap, model) {
|
||||||
@@ -1251,15 +1316,54 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
nestedRelations = relations
|
nestedRelations = relations
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute BeforeUpdate hooks inside transaction
|
||||||
|
hookCtx = &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Tx: tx,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
ID: id,
|
||||||
|
Data: dataMap,
|
||||||
|
Writer: w,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data from hook context
|
||||||
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||||
|
dataMap = modifiedData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge only non-null and non-empty values from the incoming request into the existing record
|
||||||
|
for key, newValue := range dataMap {
|
||||||
|
// Skip if the value is nil
|
||||||
|
if newValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if the value is an empty string
|
||||||
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the existing map with the new value
|
||||||
|
existingMap[key] = newValue
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure ID is in the data map for the update
|
// Ensure ID is in the data map for the update
|
||||||
dataMap[pkName] = targetID
|
existingMap[pkName] = targetID
|
||||||
|
dataMap = existingMap
|
||||||
|
|
||||||
// Populate model instance from dataMap to preserve custom types (like SqlJSONB)
|
// Populate model instance from dataMap to preserve custom types (like SqlJSONB)
|
||||||
// Get the type of the model, handling both pointer and non-pointer types
|
// Get the type of the model, handling both pointer and non-pointer types
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
modelType = reflection.GetPointerElement(modelType)
|
||||||
modelType = modelType.Elem()
|
|
||||||
}
|
|
||||||
modelInstance := reflect.New(modelType).Interface()
|
modelInstance := reflect.New(modelType).Interface()
|
||||||
if err := reflection.MapToStruct(dataMap, modelInstance); err != nil {
|
if err := reflection.MapToStruct(dataMap, modelInstance); err != nil {
|
||||||
return fmt.Errorf("failed to populate model from data: %w", err)
|
return fmt.Errorf("failed to populate model from data: %w", err)
|
||||||
@@ -1297,7 +1401,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
// Fetch the updated record to return the new values
|
// Fetch the updated record to return the new values
|
||||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
selectQuery := tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to fetch updated record: %w", err)
|
return fmt.Errorf("failed to fetch updated record: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1563,9 +1667,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
// First, fetch the record that will be deleted
|
// First, fetch the record that will be deleted
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
modelType = reflection.GetPointerElement(modelType)
|
||||||
modelType = modelType.Elem()
|
|
||||||
}
|
|
||||||
recordToDelete := reflect.New(modelType).Interface()
|
recordToDelete := reflect.New(modelType).Interface()
|
||||||
|
|
||||||
selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
@@ -1825,10 +1927,46 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
parentIDs[baseName] = parentID
|
parentIDs[baseName] = parentID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine which field name to use for setting parent ID in child data
|
||||||
|
// Priority: Use foreign key field name if specified, otherwise use parent's PK name
|
||||||
|
var foreignKeyFieldName string
|
||||||
|
if relInfo.ForeignKey != "" {
|
||||||
|
// Get the JSON name for the foreign key field in the child model
|
||||||
|
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||||
|
if foreignKeyFieldName == "" {
|
||||||
|
// Fallback to lowercase field name
|
||||||
|
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fallback: use parent's primary key name
|
||||||
|
parentPKName := reflection.GetPrimaryKeyName(parentModelType)
|
||||||
|
foreignKeyFieldName = reflection.GetJSONNameForField(parentModelType, parentPKName)
|
||||||
|
if foreignKeyFieldName == "" {
|
||||||
|
foreignKeyFieldName = strings.ToLower(parentPKName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||||
|
childPKName := reflection.GetPrimaryKeyName(relatedModel)
|
||||||
|
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
|
||||||
|
if childPKFieldName == "" {
|
||||||
|
childPKFieldName = strings.ToLower(childPKName)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s, childPK=%s",
|
||||||
|
foreignKeyFieldName, parentID, relInfo.ForeignKey, childPKFieldName)
|
||||||
|
|
||||||
// Process based on relation type and data structure
|
// Process based on relation type and data structure
|
||||||
switch v := relationValue.(type) {
|
switch v := relationValue.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
// Single related object
|
// Single related object - add parent ID to foreign key field
|
||||||
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
|
v[foreignKeyFieldName] = parentID
|
||||||
|
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
|
||||||
|
} else if foreignKeyFieldName == childPKFieldName {
|
||||||
|
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
|
||||||
|
}
|
||||||
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to process single relation: %w", err)
|
return fmt.Errorf("failed to process single relation: %w", err)
|
||||||
@@ -1838,6 +1976,14 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
// Multiple related objects
|
// Multiple related objects
|
||||||
for i, item := range v {
|
for i, item := range v {
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
// Add parent ID to foreign key field
|
||||||
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
|
itemMap[foreignKeyFieldName] = parentID
|
||||||
|
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||||
|
} else if foreignKeyFieldName == childPKFieldName {
|
||||||
|
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||||
|
}
|
||||||
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||||
@@ -1848,6 +1994,14 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
// Multiple related objects (typed slice)
|
// Multiple related objects (typed slice)
|
||||||
for i, itemMap := range v {
|
for i, itemMap := range v {
|
||||||
|
// Add parent ID to foreign key field
|
||||||
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
|
itemMap[foreignKeyFieldName] = parentID
|
||||||
|
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||||
|
} else if foreignKeyFieldName == childPKFieldName {
|
||||||
|
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||||
|
}
|
||||||
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||||
@@ -1965,6 +2119,99 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyOrFilterGroup applies a group of OR filters as a single grouped condition
|
||||||
|
// This ensures OR conditions are properly grouped with parentheses to prevent OR logic from escaping
|
||||||
|
func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common.FilterOption, castInfo []ColumnCastInfo, tableName string) common.SelectQuery {
|
||||||
|
if len(filters) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build individual filter conditions
|
||||||
|
conditions := []string{}
|
||||||
|
args := []interface{}{}
|
||||||
|
|
||||||
|
for i, filter := range filters {
|
||||||
|
// Qualify the column name with table name if not already qualified
|
||||||
|
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
|
||||||
|
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||||
|
if castInfo[i].NeedsCast {
|
||||||
|
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the condition based on operator
|
||||||
|
condition, filterArgs := h.buildFilterCondition(qualifiedColumn, filter, tableName)
|
||||||
|
if condition != "" {
|
||||||
|
conditions = append(conditions, condition)
|
||||||
|
args = append(args, filterArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join all conditions with OR and wrap in parentheses
|
||||||
|
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
||||||
|
logger.Debug("Applying grouped OR conditions: %s", groupedCondition)
|
||||||
|
|
||||||
|
// Apply as AND condition (the OR is already inside the parentheses)
|
||||||
|
return query.Where(groupedCondition, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
||||||
|
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
||||||
|
switch strings.ToLower(filter.Operator) {
|
||||||
|
case "eq", "equals":
|
||||||
|
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "neq", "not_equals", "ne":
|
||||||
|
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "gt", "greater_than":
|
||||||
|
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "gte", "greater_than_equals", "ge":
|
||||||
|
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "lt", "less_than":
|
||||||
|
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "lte", "less_than_equals", "le":
|
||||||
|
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "like":
|
||||||
|
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "ilike":
|
||||||
|
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "in":
|
||||||
|
return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
case "between":
|
||||||
|
// Handle between operator - exclusive (> val1 AND < val2)
|
||||||
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
|
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||||
|
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||||
|
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||||
|
}
|
||||||
|
logger.Warn("Invalid BETWEEN filter value format")
|
||||||
|
return "", nil
|
||||||
|
case "between_inclusive":
|
||||||
|
// Handle between inclusive operator - inclusive (>= val1 AND <= val2)
|
||||||
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
|
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||||
|
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||||
|
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||||
|
}
|
||||||
|
logger.Warn("Invalid BETWEEN INCLUSIVE filter value format")
|
||||||
|
return "", nil
|
||||||
|
case "is_null", "isnull":
|
||||||
|
// Check for NULL values - don't use cast for NULL checks
|
||||||
|
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
return fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName), nil
|
||||||
|
case "is_not_null", "isnotnull":
|
||||||
|
// Check for NOT NULL values - don't use cast for NULL checks
|
||||||
|
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName), nil
|
||||||
|
default:
|
||||||
|
logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator)
|
||||||
|
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
|
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
|
||||||
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||||
@@ -2143,12 +2390,22 @@ func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metada
|
|||||||
// sendResponseWithOptions sends a response with optional formatting
|
// sendResponseWithOptions sends a response with optional formatting
|
||||||
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
|
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
|
||||||
w.SetHeader("Content-Type", "application/json")
|
w.SetHeader("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Handle nil data - convert to empty array
|
||||||
if data == nil {
|
if data == nil {
|
||||||
data = map[string]interface{}{}
|
data = []interface{}{}
|
||||||
w.WriteHeader(http.StatusPartialContent)
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculate data length after nil conversion
|
||||||
|
dataLen := reflection.Len(data)
|
||||||
|
|
||||||
|
// Add X-No-Data-Found header when no records were found
|
||||||
|
if dataLen == 0 {
|
||||||
|
w.SetHeader("X-No-Data-Found", "true")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
// Normalize single-record arrays to objects if requested
|
// Normalize single-record arrays to objects if requested
|
||||||
if options != nil && options.SingleRecordAsObject {
|
if options != nil && options.SingleRecordAsObject {
|
||||||
data = h.normalizeResultArray(data)
|
data = h.normalizeResultArray(data)
|
||||||
@@ -2165,7 +2422,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||||
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
return map[string]interface{}{}
|
return []interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use reflection to check if data is a slice or array
|
// Use reflection to check if data is a slice or array
|
||||||
@@ -2180,15 +2437,15 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
|||||||
// Return the single element
|
// Return the single element
|
||||||
return dataValue.Index(0).Interface()
|
return dataValue.Index(0).Interface()
|
||||||
} else if dataValue.Len() == 0 {
|
} else if dataValue.Len() == 0 {
|
||||||
// Return empty object instead of empty array
|
// Keep empty array as empty array, don't convert to empty object
|
||||||
return map[string]interface{}{}
|
return []interface{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if dataValue.Kind() == reflect.String {
|
if dataValue.Kind() == reflect.String {
|
||||||
str := dataValue.String()
|
str := dataValue.String()
|
||||||
if str == "" || str == "null" {
|
if str == "" || str == "null" {
|
||||||
return map[string]interface{}{}
|
return []interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -2199,16 +2456,25 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
|||||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
||||||
// Normalize single-record arrays to objects if requested
|
// Normalize single-record arrays to objects if requested
|
||||||
httpStatus := http.StatusOK
|
httpStatus := http.StatusOK
|
||||||
|
|
||||||
|
// Handle nil data - convert to empty array
|
||||||
if data == nil {
|
if data == nil {
|
||||||
data = map[string]interface{}{}
|
data = []interface{}{}
|
||||||
httpStatus = http.StatusPartialContent
|
|
||||||
} else {
|
|
||||||
dataLen := reflection.Len(data)
|
|
||||||
if dataLen == 0 {
|
|
||||||
httpStatus = http.StatusPartialContent
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculate data length after nil conversion
|
||||||
|
// Note: This is done BEFORE normalization because X-No-Data-Found indicates
|
||||||
|
// whether data was found in the database, not the final response format
|
||||||
|
dataLen := reflection.Len(data)
|
||||||
|
|
||||||
|
// Add X-No-Data-Found header when no records were found
|
||||||
|
if dataLen == 0 {
|
||||||
|
w.SetHeader("X-No-Data-Found", "true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply normalization after header is set
|
||||||
|
// normalizeResultArray may convert single-element arrays to objects,
|
||||||
|
// but the X-No-Data-Found header reflects the original query result
|
||||||
if options.SingleRecordAsObject {
|
if options.SingleRecordAsObject {
|
||||||
data = h.normalizeResultArray(data)
|
data = h.normalizeResultArray(data)
|
||||||
}
|
}
|
||||||
@@ -2518,10 +2784,10 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
|||||||
filteredExpand := expand
|
filteredExpand := expand
|
||||||
|
|
||||||
// Get the relationship info for this expand relation
|
// Get the relationship info for this expand relation
|
||||||
relInfo := h.getRelationshipInfo(modelType, expand.Relation)
|
relInfo := common.GetRelationshipInfo(modelType, expand.Relation)
|
||||||
if relInfo != nil && relInfo.relatedModel != nil {
|
if relInfo != nil && relInfo.RelatedModel != nil {
|
||||||
// Create a validator for the related model
|
// Create a validator for the related model
|
||||||
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
expandValidator := common.NewColumnValidator(relInfo.RelatedModel)
|
||||||
// Filter columns using the related model's validator
|
// Filter columns using the related model's validator
|
||||||
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||||
|
|
||||||
@@ -2598,110 +2864,7 @@ func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model in
|
|||||||
|
|
||||||
// GetRelationshipInfo implements common.RelationshipInfoProvider interface
|
// GetRelationshipInfo implements common.RelationshipInfoProvider interface
|
||||||
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
||||||
info := h.getRelationshipInfo(modelType, relationName)
|
return common.GetRelationshipInfo(modelType, relationName)
|
||||||
if info == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Convert internal type to common type
|
|
||||||
return &common.RelationshipInfo{
|
|
||||||
FieldName: info.fieldName,
|
|
||||||
JSONName: info.jsonName,
|
|
||||||
RelationType: info.relationType,
|
|
||||||
ForeignKey: info.foreignKey,
|
|
||||||
References: info.references,
|
|
||||||
JoinTable: info.joinTable,
|
|
||||||
RelatedModel: info.relatedModel,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type relationshipInfo struct {
|
|
||||||
fieldName string
|
|
||||||
jsonName string
|
|
||||||
relationType string // "belongsTo", "hasMany", "hasOne", "many2many"
|
|
||||||
foreignKey string
|
|
||||||
references string
|
|
||||||
joinTable string
|
|
||||||
relatedModel interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
|
||||||
// Ensure we have a struct type
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
|
||||||
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
|
||||||
field := modelType.Field(i)
|
|
||||||
jsonTag := field.Tag.Get("json")
|
|
||||||
jsonName := strings.Split(jsonTag, ",")[0]
|
|
||||||
|
|
||||||
if jsonName == relationName {
|
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
info := &relationshipInfo{
|
|
||||||
fieldName: field.Name,
|
|
||||||
jsonName: jsonName,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse GORM tag to determine relationship type and keys
|
|
||||||
if strings.Contains(gormTag, "foreignKey") {
|
|
||||||
info.foreignKey = h.extractTagValue(gormTag, "foreignKey")
|
|
||||||
info.references = h.extractTagValue(gormTag, "references")
|
|
||||||
|
|
||||||
// Determine if it's belongsTo or hasMany/hasOne
|
|
||||||
if field.Type.Kind() == reflect.Slice {
|
|
||||||
info.relationType = "hasMany"
|
|
||||||
// Get the element type for slice
|
|
||||||
elemType := field.Type.Elem()
|
|
||||||
if elemType.Kind() == reflect.Ptr {
|
|
||||||
elemType = elemType.Elem()
|
|
||||||
}
|
|
||||||
if elemType.Kind() == reflect.Struct {
|
|
||||||
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
|
||||||
}
|
|
||||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
|
||||||
info.relationType = "belongsTo"
|
|
||||||
elemType := field.Type
|
|
||||||
if elemType.Kind() == reflect.Ptr {
|
|
||||||
elemType = elemType.Elem()
|
|
||||||
}
|
|
||||||
if elemType.Kind() == reflect.Struct {
|
|
||||||
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if strings.Contains(gormTag, "many2many") {
|
|
||||||
info.relationType = "many2many"
|
|
||||||
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
|
||||||
// Get the element type for many2many (always slice)
|
|
||||||
if field.Type.Kind() == reflect.Slice {
|
|
||||||
elemType := field.Type.Elem()
|
|
||||||
if elemType.Kind() == reflect.Ptr {
|
|
||||||
elemType = elemType.Elem()
|
|
||||||
}
|
|
||||||
if elemType.Kind() == reflect.Struct {
|
|
||||||
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Field has no GORM relationship tags, so it's not a relation
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) extractTagValue(tag, key string) string {
|
|
||||||
parts := strings.Split(tag, ";")
|
|
||||||
for _, part := range parts {
|
|
||||||
part = strings.TrimSpace(part)
|
|
||||||
if strings.HasPrefix(part, key+":") {
|
|
||||||
return strings.TrimPrefix(part, key+":")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleOpenAPI generates and returns the OpenAPI specification
|
// HandleOpenAPI generates and returns the OpenAPI specification
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ type ExtendedRequestOptions struct {
|
|||||||
|
|
||||||
// Joins
|
// Joins
|
||||||
Expand []ExpandOption
|
Expand []ExpandOption
|
||||||
|
CustomSQLJoin []string // Custom SQL JOIN clauses
|
||||||
|
JoinAliases []string // Extracted table aliases from CustomSQLJoin for validation
|
||||||
|
|
||||||
// Advanced features
|
// Advanced features
|
||||||
AdvancedSQL map[string]string // Column -> SQL expression
|
AdvancedSQL map[string]string // Column -> SQL expression
|
||||||
@@ -47,6 +49,7 @@ type ExtendedRequestOptions struct {
|
|||||||
|
|
||||||
// X-Files configuration - comprehensive query options as a single JSON object
|
// X-Files configuration - comprehensive query options as a single JSON object
|
||||||
XFiles *XFiles
|
XFiles *XFiles
|
||||||
|
XFilesPresent bool // Flag to indicate if X-Files header was provided
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpandOption represents a relation expansion configuration
|
// ExpandOption represents a relation expansion configuration
|
||||||
@@ -111,6 +114,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
AdvancedSQL: make(map[string]string),
|
AdvancedSQL: make(map[string]string),
|
||||||
ComputedQL: make(map[string]string),
|
ComputedQL: make(map[string]string),
|
||||||
Expand: make([]ExpandOption, 0),
|
Expand: make([]ExpandOption, 0),
|
||||||
|
CustomSQLJoin: make([]string, 0),
|
||||||
ResponseFormat: "simple", // Default response format
|
ResponseFormat: "simple", // Default response format
|
||||||
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||||
}
|
}
|
||||||
@@ -185,8 +189,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
case strings.HasPrefix(key, "x-expand"):
|
case strings.HasPrefix(key, "x-expand"):
|
||||||
h.parseExpand(&options, decodedValue)
|
h.parseExpand(&options, decodedValue)
|
||||||
case strings.HasPrefix(key, "x-custom-sql-join"):
|
case strings.HasPrefix(key, "x-custom-sql-join"):
|
||||||
// TODO: Implement custom SQL join
|
h.parseCustomSQLJoin(&options, decodedValue)
|
||||||
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
|
||||||
|
|
||||||
// Sorting & Pagination
|
// Sorting & Pagination
|
||||||
case strings.HasPrefix(key, "x-sort"):
|
case strings.HasPrefix(key, "x-sort"):
|
||||||
@@ -272,7 +275,8 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve relation names (convert table names to field names) if model is provided
|
// Resolve relation names (convert table names to field names) if model is provided
|
||||||
if model != nil {
|
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
|
||||||
|
if model != nil && !options.XFilesPresent {
|
||||||
h.resolveRelationNamesInOptions(&options, model)
|
h.resolveRelationNamesInOptions(&options, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,6 +358,12 @@ func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu
|
|||||||
operator := parts[0]
|
operator := parts[0]
|
||||||
colName := parts[1]
|
colName := parts[1]
|
||||||
|
|
||||||
|
if strings.HasPrefix(colName, "cql") {
|
||||||
|
// Computed column - Will not filter on it
|
||||||
|
logger.Warn("Search operators on computed columns are not supported: %s", colName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Map operator names to filter operators
|
// Map operator names to filter operators
|
||||||
filterOp := h.mapSearchOperator(colName, operator, value)
|
filterOp := h.mapSearchOperator(colName, operator, value)
|
||||||
|
|
||||||
@@ -489,6 +499,101 @@ func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseCustomSQLJoin parses x-custom-sql-join header
|
||||||
|
// Format: Single JOIN clause or multiple JOIN clauses separated by |
|
||||||
|
// Example: "LEFT JOIN departments d ON d.id = employees.department_id"
|
||||||
|
// Example: "LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id"
|
||||||
|
func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value string) {
|
||||||
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split by | for multiple joins
|
||||||
|
joins := strings.Split(value, "|")
|
||||||
|
for _, joinStr := range joins {
|
||||||
|
joinStr = strings.TrimSpace(joinStr)
|
||||||
|
if joinStr == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic validation: should contain "JOIN" keyword
|
||||||
|
upperJoin := strings.ToUpper(joinStr)
|
||||||
|
if !strings.Contains(upperJoin, "JOIN") {
|
||||||
|
logger.Warn("Invalid custom SQL join (missing JOIN keyword): %s", joinStr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the join clause using common.SanitizeWhereClause
|
||||||
|
// Note: This is basic sanitization - in production you may want stricter validation
|
||||||
|
sanitizedJoin := common.SanitizeWhereClause(joinStr, "", nil)
|
||||||
|
if sanitizedJoin == "" {
|
||||||
|
logger.Warn("Custom SQL join failed sanitization: %s", joinStr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract table alias from the JOIN clause
|
||||||
|
alias := extractJoinAlias(sanitizedJoin)
|
||||||
|
if alias != "" {
|
||||||
|
options.JoinAliases = append(options.JoinAliases, alias)
|
||||||
|
// Also add to the embedded RequestOptions for validation
|
||||||
|
options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias)
|
||||||
|
logger.Debug("Extracted join alias: %s", alias)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Adding custom SQL join: %s", sanitizedJoin)
|
||||||
|
options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractJoinAlias extracts the table alias from a JOIN clause
|
||||||
|
// Examples:
|
||||||
|
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||||
|
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||||
|
// - "JOIN roles r ON ..." -> "r"
|
||||||
|
func extractJoinAlias(joinClause string) string {
|
||||||
|
// Pattern: JOIN table_name [AS] alias ON ...
|
||||||
|
// We need to extract the alias (word before ON)
|
||||||
|
|
||||||
|
upperJoin := strings.ToUpper(joinClause)
|
||||||
|
|
||||||
|
// Find the "JOIN" keyword position
|
||||||
|
joinIdx := strings.Index(upperJoin, "JOIN")
|
||||||
|
if joinIdx == -1 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the "ON" keyword position
|
||||||
|
onIdx := strings.Index(upperJoin, " ON ")
|
||||||
|
if onIdx == -1 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the part between JOIN and ON
|
||||||
|
betweenJoinAndOn := strings.TrimSpace(joinClause[joinIdx+4 : onIdx])
|
||||||
|
|
||||||
|
// Split by spaces to get words
|
||||||
|
words := strings.Fields(betweenJoinAndOn)
|
||||||
|
if len(words) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's an AS keyword, the alias is after it
|
||||||
|
for i, word := range words {
|
||||||
|
if strings.EqualFold(word, "AS") && i+1 < len(words) {
|
||||||
|
return words[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, the alias is the last word (if there are 2+ words)
|
||||||
|
// Format: "table_name alias" or just "table_name"
|
||||||
|
if len(words) >= 2 {
|
||||||
|
return words[len(words)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only one word means it's just the table name, no alias
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// parseSorting parses x-sort header
|
// parseSorting parses x-sort header
|
||||||
// Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC)
|
// Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC)
|
||||||
func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
||||||
@@ -590,6 +695,7 @@ func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
|
|||||||
|
|
||||||
// Store the original XFiles for reference
|
// Store the original XFiles for reference
|
||||||
options.XFiles = &xfiles
|
options.XFiles = &xfiles
|
||||||
|
options.XFilesPresent = true // Mark that X-Files header was provided
|
||||||
|
|
||||||
// Map XFiles fields to ExtendedRequestOptions
|
// Map XFiles fields to ExtendedRequestOptions
|
||||||
|
|
||||||
@@ -881,11 +987,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the table name as-is for now - it will be resolved to field name later
|
// Use the Prefix (e.g., "MAL") as the relation name, which matches the Go struct field name
|
||||||
// when we have the model instance available
|
// Fall back to TableName if Prefix is not specified
|
||||||
relationPath := xfile.TableName
|
relationName := xfile.Prefix
|
||||||
|
if relationName == "" {
|
||||||
|
relationName = xfile.TableName
|
||||||
|
}
|
||||||
|
|
||||||
|
// SPECIAL CASE: For recursive child tables, generate FK-based relation name
|
||||||
|
// Example: If prefix is "MAL" and relatedkey is "rid_parentmastertaskitem",
|
||||||
|
// the actual struct field is "MAL_RID_PARENTMASTERTASKITEM", not "MAL"
|
||||||
|
if xfile.Recursive && xfile.RelatedKey != "" && basePath != "" {
|
||||||
|
// Check if this is a self-referencing recursive relation (same table as parent)
|
||||||
|
// by comparing the last part of basePath with the current prefix
|
||||||
|
basePathParts := strings.Split(basePath, ".")
|
||||||
|
lastPrefix := basePathParts[len(basePathParts)-1]
|
||||||
|
|
||||||
|
if lastPrefix == relationName {
|
||||||
|
// This is a recursive self-reference, use FK-based name
|
||||||
|
fkUpper := strings.ToUpper(xfile.RelatedKey)
|
||||||
|
relationName = relationName + "_" + fkUpper
|
||||||
|
logger.Debug("X-Files: Generated FK-based relation name for recursive table: %s", relationName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
relationPath := relationName
|
||||||
if basePath != "" {
|
if basePath != "" {
|
||||||
relationPath = basePath + "." + xfile.TableName
|
relationPath = basePath + "." + relationName
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
|
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
|
||||||
@@ -893,6 +1021,7 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
// Create PreloadOption from XFiles configuration
|
// Create PreloadOption from XFiles configuration
|
||||||
preloadOpt := common.PreloadOption{
|
preloadOpt := common.PreloadOption{
|
||||||
Relation: relationPath,
|
Relation: relationPath,
|
||||||
|
TableName: xfile.TableName, // Store the actual database table name for WHERE clause processing
|
||||||
Columns: xfile.Columns,
|
Columns: xfile.Columns,
|
||||||
OmitColumns: xfile.OmitColumns,
|
OmitColumns: xfile.OmitColumns,
|
||||||
}
|
}
|
||||||
@@ -935,12 +1064,12 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
// Add WHERE clause if SQL conditions specified
|
// Add WHERE clause if SQL conditions specified
|
||||||
whereConditions := make([]string, 0)
|
whereConditions := make([]string, 0)
|
||||||
if len(xfile.SqlAnd) > 0 {
|
if len(xfile.SqlAnd) > 0 {
|
||||||
// Process each SQL condition: add table prefixes and sanitize
|
// Process each SQL condition
|
||||||
|
// Note: We don't add table prefixes here because they're only needed for JOINs
|
||||||
|
// The handler will add prefixes later if SqlJoins are present
|
||||||
for _, sqlCond := range xfile.SqlAnd {
|
for _, sqlCond := range xfile.SqlAnd {
|
||||||
// First add table prefixes to unqualified columns
|
// Sanitize the condition without adding prefixes
|
||||||
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName)
|
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
|
||||||
// Then sanitize the condition
|
|
||||||
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
|
|
||||||
if sanitizedCond != "" {
|
if sanitizedCond != "" {
|
||||||
whereConditions = append(whereConditions, sanitizedCond)
|
whereConditions = append(whereConditions, sanitizedCond)
|
||||||
}
|
}
|
||||||
@@ -985,13 +1114,72 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transfer SqlJoins from XFiles to PreloadOption
|
||||||
|
if len(xfile.SqlJoins) > 0 {
|
||||||
|
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
|
||||||
|
for _, joinClause := range xfile.SqlJoins {
|
||||||
|
// Sanitize the join clause
|
||||||
|
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||||
|
if sanitizedJoin == "" {
|
||||||
|
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||||
|
|
||||||
|
// Extract join alias for validation
|
||||||
|
alias := extractJoinAlias(sanitizedJoin)
|
||||||
|
if alias != "" {
|
||||||
|
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||||
|
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
||||||
|
// and store the recursive child's RelatedKey for recursion generation
|
||||||
|
hasRecursiveChild := false
|
||||||
|
if len(xfile.ChildTables) > 0 {
|
||||||
|
for _, childTable := range xfile.ChildTables {
|
||||||
|
if childTable.Recursive && childTable.TableName == xfile.TableName {
|
||||||
|
hasRecursiveChild = true
|
||||||
|
preloadOpt.Recursive = true
|
||||||
|
preloadOpt.RecursiveChildKey = childTable.RelatedKey
|
||||||
|
logger.Debug("X-Files: Detected recursive child for %s, marking parent as recursive (recursive FK: %s)",
|
||||||
|
relationPath, childTable.RelatedKey)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip adding this preload if it's a recursive child (it will be handled by parent's Recursive flag)
|
||||||
|
if xfile.Recursive && basePath != "" {
|
||||||
|
logger.Debug("X-Files: Skipping recursive child preload: %s (will be handled by parent)", relationPath)
|
||||||
|
// Still process its parent/child tables for relations like DEF
|
||||||
|
h.processXFilesRelations(xfile, options, relationPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Add the preload option
|
// Add the preload option
|
||||||
options.Preload = append(options.Preload, preloadOpt)
|
options.Preload = append(options.Preload, preloadOpt)
|
||||||
|
logger.Debug("X-Files: Added preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, RecursiveChildKey=%s, Where=%s",
|
||||||
|
len(options.Preload)-1, preloadOpt.Relation, preloadOpt.Recursive, preloadOpt.RelatedKey, preloadOpt.RecursiveChildKey, preloadOpt.Where)
|
||||||
|
|
||||||
// Recursively process nested ParentTables and ChildTables
|
// Recursively process nested ParentTables and ChildTables
|
||||||
if xfile.Recursive {
|
// Skip processing child tables if we already detected and handled a recursive child
|
||||||
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
|
if hasRecursiveChild {
|
||||||
h.processXFilesRelations(xfile, options, relationPath)
|
logger.Debug("X-Files: Skipping child table processing for %s (recursive child already handled)", relationPath)
|
||||||
|
// But still process parent tables
|
||||||
|
if len(xfile.ParentTables) > 0 {
|
||||||
|
logger.Debug("X-Files: Processing %d parent tables for %s", len(xfile.ParentTables), relationPath)
|
||||||
|
for _, parentTable := range xfile.ParentTables {
|
||||||
|
h.addXFilesPreload(parentTable, options, relationPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
|
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
|
||||||
h.processXFilesRelations(xfile, options, relationPath)
|
h.processXFilesRelations(xfile, options, relationPath)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDecodeHeaderValue(t *testing.T) {
|
func TestDecodeHeaderValue(t *testing.T) {
|
||||||
@@ -37,6 +39,121 @@ func TestDecodeHeaderValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddXFilesPreload_WithSqlJoins(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
options := &ExtendedRequestOptions{
|
||||||
|
RequestOptions: common.RequestOptions{
|
||||||
|
Preload: make([]common.PreloadOption, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an XFiles with SqlJoins
|
||||||
|
xfile := &XFiles{
|
||||||
|
TableName: "users",
|
||||||
|
SqlJoins: []string{
|
||||||
|
"LEFT JOIN departments d ON d.id = users.department_id",
|
||||||
|
"INNER JOIN roles r ON r.id = users.role_id",
|
||||||
|
},
|
||||||
|
FilterFields: []struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Operator string `json:"operator"`
|
||||||
|
}{
|
||||||
|
{Field: "d.active", Value: "true", Operator: "eq"},
|
||||||
|
{Field: "r.name", Value: "admin", Operator: "eq"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the XFiles preload
|
||||||
|
handler.addXFilesPreload(xfile, options, "")
|
||||||
|
|
||||||
|
// Verify that a preload was added
|
||||||
|
if len(options.Preload) != 1 {
|
||||||
|
t.Fatalf("Expected 1 preload, got %d", len(options.Preload))
|
||||||
|
}
|
||||||
|
|
||||||
|
preload := options.Preload[0]
|
||||||
|
|
||||||
|
// Verify relation name
|
||||||
|
if preload.Relation != "users" {
|
||||||
|
t.Errorf("Expected relation 'users', got '%s'", preload.Relation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJoins were transferred
|
||||||
|
if len(preload.SqlJoins) != 2 {
|
||||||
|
t.Fatalf("Expected 2 SQL joins, got %d", len(preload.SqlJoins))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify JoinAliases were extracted
|
||||||
|
if len(preload.JoinAliases) != 2 {
|
||||||
|
t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the aliases are correct
|
||||||
|
expectedAliases := []string{"d", "r"}
|
||||||
|
for i, expected := range expectedAliases {
|
||||||
|
if preload.JoinAliases[i] != expected {
|
||||||
|
t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify filters were added
|
||||||
|
if len(preload.Filters) != 2 {
|
||||||
|
t.Fatalf("Expected 2 filters, got %d", len(preload.Filters))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify filter columns reference joined tables
|
||||||
|
if preload.Filters[0].Column != "d.active" {
|
||||||
|
t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column)
|
||||||
|
}
|
||||||
|
if preload.Filters[1].Column != "r.name" {
|
||||||
|
t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractJoinAlias(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
joinClause string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LEFT JOIN with alias",
|
||||||
|
joinClause: "LEFT JOIN departments d ON d.id = users.department_id",
|
||||||
|
expected: "d",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "INNER JOIN with AS keyword",
|
||||||
|
joinClause: "INNER JOIN users AS u ON u.id = orders.user_id",
|
||||||
|
expected: "u",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JOIN without alias",
|
||||||
|
joinClause: "JOIN roles ON roles.id = users.role_id",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Complex join with multiple conditions",
|
||||||
|
joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true",
|
||||||
|
expected: "p",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid join (no ON clause)",
|
||||||
|
joinClause: "LEFT JOIN departments",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractJoinAlias(tt.joinClause)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected alias '%s', got '%s'", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
|
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
|
||||||
// - parseSelectFields
|
// - parseSelectFields
|
||||||
// - parseFieldFilter
|
// - parseFieldFilter
|
||||||
|
|||||||
110
pkg/restheadspec/preload_tablename_test.go
Normal file
110
pkg/restheadspec/preload_tablename_test.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestPreloadOption_TableName verifies that TableName field is properly used
|
||||||
|
// when provided in PreloadOption for WHERE clause processing
|
||||||
|
func TestPreloadOption_TableName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
preload common.PreloadOption
|
||||||
|
expectedTable string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TableName provided explicitly",
|
||||||
|
preload: common.PreloadOption{
|
||||||
|
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||||
|
TableName: "mastertaskitem",
|
||||||
|
Where: "rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
expectedTable: "mastertaskitem",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TableName empty, should use empty string",
|
||||||
|
preload: common.PreloadOption{
|
||||||
|
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||||
|
TableName: "",
|
||||||
|
Where: "rid_parentmastertaskitem is null",
|
||||||
|
},
|
||||||
|
expectedTable: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple relation without nested path",
|
||||||
|
preload: common.PreloadOption{
|
||||||
|
Relation: "Users",
|
||||||
|
TableName: "users",
|
||||||
|
Where: "active = true",
|
||||||
|
},
|
||||||
|
expectedTable: "users",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Test that the TableName field stores the correct value
|
||||||
|
if tt.preload.TableName != tt.expectedTable {
|
||||||
|
t.Errorf("PreloadOption.TableName = %q, want %q", tt.preload.TableName, tt.expectedTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that when TableName is provided, it should be used instead of extracting from relation
|
||||||
|
tableName := tt.preload.TableName
|
||||||
|
if tableName == "" {
|
||||||
|
// This simulates the fallback logic in handler.go
|
||||||
|
// In reality, reflection.ExtractTableNameOnly would be called
|
||||||
|
tableName = tt.expectedTable
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableName != tt.expectedTable {
|
||||||
|
t.Errorf("Resolved table name = %q, want %q", tableName, tt.expectedTable)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestXFilesPreload_StoresTableName verifies that XFiles processing
|
||||||
|
// stores the table name in PreloadOption and doesn't add table prefixes to WHERE clauses
|
||||||
|
func TestXFilesPreload_StoresTableName(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
xfiles := &XFiles{
|
||||||
|
TableName: "mastertaskitem",
|
||||||
|
Prefix: "MAL",
|
||||||
|
PrimaryKey: "rid_mastertaskitem",
|
||||||
|
RelatedKey: "rid_mastertask", // Changed from rid_parentmastertaskitem
|
||||||
|
Recursive: false, // Changed from true (recursive children are now skipped)
|
||||||
|
SqlAnd: []string{"rid_parentmastertaskitem is null"},
|
||||||
|
}
|
||||||
|
|
||||||
|
options := &ExtendedRequestOptions{}
|
||||||
|
|
||||||
|
// Process XFiles
|
||||||
|
handler.addXFilesPreload(xfiles, options, "MTL")
|
||||||
|
|
||||||
|
// Verify that a preload was added
|
||||||
|
if len(options.Preload) == 0 {
|
||||||
|
t.Fatal("Expected at least one preload to be added")
|
||||||
|
}
|
||||||
|
|
||||||
|
preload := options.Preload[0]
|
||||||
|
|
||||||
|
// Verify the table name is stored
|
||||||
|
if preload.TableName != "mastertaskitem" {
|
||||||
|
t.Errorf("PreloadOption.TableName = %q, want %q", preload.TableName, "mastertaskitem")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the relation path includes the prefix
|
||||||
|
expectedRelation := "MTL.MAL"
|
||||||
|
if preload.Relation != expectedRelation {
|
||||||
|
t.Errorf("PreloadOption.Relation = %q, want %q", preload.Relation, expectedRelation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify WHERE clause does NOT have table prefix (prefixes only needed for JOINs)
|
||||||
|
expectedWhere := "rid_parentmastertaskitem is null"
|
||||||
|
if preload.Where != expectedWhere {
|
||||||
|
t.Errorf("PreloadOption.Where = %q, want %q (no table prefix)", preload.Where, expectedWhere)
|
||||||
|
}
|
||||||
|
}
|
||||||
91
pkg/restheadspec/preload_where_joins_test.go
Normal file
91
pkg/restheadspec/preload_where_joins_test.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestPreloadWhereClause_WithJoins verifies that table prefixes are added
|
||||||
|
// to WHERE clauses when SqlJoins are present
|
||||||
|
func TestPreloadWhereClause_WithJoins(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
sqlJoins []string
|
||||||
|
expectedPrefix bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No joins - no prefix needed",
|
||||||
|
where: "status = 'active'",
|
||||||
|
sqlJoins: []string{},
|
||||||
|
expectedPrefix: false,
|
||||||
|
description: "Without JOINs, Bun knows the table context",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Has joins - prefix needed",
|
||||||
|
where: "status = 'active'",
|
||||||
|
sqlJoins: []string{"LEFT JOIN other_table ot ON ot.id = main.other_id"},
|
||||||
|
expectedPrefix: true,
|
||||||
|
description: "With JOINs, table prefix disambiguates columns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Already has prefix - no change",
|
||||||
|
where: "users.status = 'active'",
|
||||||
|
sqlJoins: []string{"LEFT JOIN roles r ON r.id = users.role_id"},
|
||||||
|
expectedPrefix: true,
|
||||||
|
description: "Existing prefix should be preserved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// This test documents the expected behavior
|
||||||
|
// The actual logic is in handler.go lines 916-937
|
||||||
|
|
||||||
|
hasJoins := len(tt.sqlJoins) > 0
|
||||||
|
if hasJoins != tt.expectedPrefix {
|
||||||
|
t.Errorf("Test expectation mismatch: hasJoins=%v, expectedPrefix=%v",
|
||||||
|
hasJoins, tt.expectedPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("%s: %s", tt.name, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestXFilesWithJoins_AddsTablePrefix verifies that XFiles with SqlJoins
|
||||||
|
// results in table prefixes being added to WHERE clauses
|
||||||
|
func TestXFilesWithJoins_AddsTablePrefix(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
xfiles := &XFiles{
|
||||||
|
TableName: "users",
|
||||||
|
Prefix: "USR",
|
||||||
|
PrimaryKey: "id",
|
||||||
|
SqlAnd: []string{"status = 'active'"},
|
||||||
|
SqlJoins: []string{"LEFT JOIN departments d ON d.id = users.department_id"},
|
||||||
|
}
|
||||||
|
|
||||||
|
options := &ExtendedRequestOptions{}
|
||||||
|
handler.addXFilesPreload(xfiles, options, "")
|
||||||
|
|
||||||
|
if len(options.Preload) == 0 {
|
||||||
|
t.Fatal("Expected at least one preload to be added")
|
||||||
|
}
|
||||||
|
|
||||||
|
preload := options.Preload[0]
|
||||||
|
|
||||||
|
// Verify SqlJoins were stored
|
||||||
|
if len(preload.SqlJoins) != 1 {
|
||||||
|
t.Errorf("Expected 1 SqlJoin, got %d", len(preload.SqlJoins))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify WHERE clause does NOT have prefix yet (added later in handler)
|
||||||
|
expectedWhere := "status = 'active'"
|
||||||
|
if preload.Where != expectedWhere {
|
||||||
|
t.Errorf("PreloadOption.Where = %q, want %q", preload.Where, expectedWhere)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: The handler will add the prefix when it sees SqlJoins
|
||||||
|
// This is tested in the handler itself, not during XFiles parsing
|
||||||
|
}
|
||||||
@@ -301,6 +301,163 @@ func TestParseOptionsFromQueryParams(t *testing.T) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Parse custom SQL JOIN from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.CustomSQLJoin) == 0 {
|
||||||
|
t.Error("Expected CustomSQLJoin to be set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(options.CustomSQLJoin) != 1 {
|
||||||
|
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected := `LEFT JOIN departments d ON d.id = employees.department_id`
|
||||||
|
if options.CustomSQLJoin[0] != expected {
|
||||||
|
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse multiple custom SQL JOINs from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.CustomSQLJoin) != 2 {
|
||||||
|
t.Errorf("Expected 2 custom SQL joins, got %d", len(options.CustomSQLJoin))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected1 := `LEFT JOIN departments d ON d.id = e.dept_id`
|
||||||
|
expected2 := `INNER JOIN roles r ON r.id = e.role_id`
|
||||||
|
if options.CustomSQLJoin[0] != expected1 {
|
||||||
|
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected1, options.CustomSQLJoin[0])
|
||||||
|
}
|
||||||
|
if options.CustomSQLJoin[1] != expected2 {
|
||||||
|
t.Errorf("Expected CustomSQLJoin[1]=%q, got %q", expected2, options.CustomSQLJoin[1])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse custom SQL JOIN from headers",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Custom-SQL-Join": `LEFT JOIN users u ON u.id = posts.user_id`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.CustomSQLJoin) == 0 {
|
||||||
|
t.Error("Expected CustomSQLJoin to be set from header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected := `LEFT JOIN users u ON u.id = posts.user_id`
|
||||||
|
if options.CustomSQLJoin[0] != expected {
|
||||||
|
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract aliases from custom SQL JOIN",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.JoinAliases) == 0 {
|
||||||
|
t.Error("Expected JoinAliases to be extracted")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(options.JoinAliases) != 1 {
|
||||||
|
t.Errorf("Expected 1 join alias, got %d", len(options.JoinAliases))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if options.JoinAliases[0] != "d" {
|
||||||
|
t.Errorf("Expected join alias 'd', got %q", options.JoinAliases[0])
|
||||||
|
}
|
||||||
|
// Also check that it's in the embedded RequestOptions
|
||||||
|
if len(options.RequestOptions.JoinAliases) != 1 || options.RequestOptions.JoinAliases[0] != "d" {
|
||||||
|
t.Error("Expected join alias to also be in RequestOptions.JoinAliases")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract multiple aliases from multiple custom SQL JOINs",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles AS r ON r.id = e.role_id`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.JoinAliases) != 2 {
|
||||||
|
t.Errorf("Expected 2 join aliases, got %d", len(options.JoinAliases))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expectedAliases := []string{"d", "r"}
|
||||||
|
for i, expected := range expectedAliases {
|
||||||
|
if options.JoinAliases[i] != expected {
|
||||||
|
t.Errorf("Expected join alias[%d]=%q, got %q", i, expected, options.JoinAliases[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Custom JOIN with sort on joined table",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||||
|
"x-sort": "d.name,employees.id",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
// Verify join was added
|
||||||
|
if len(options.CustomSQLJoin) != 1 {
|
||||||
|
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Verify alias was extracted
|
||||||
|
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
|
||||||
|
t.Error("Expected join alias 'd' to be extracted")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Verify sort was parsed
|
||||||
|
if len(options.Sort) != 2 {
|
||||||
|
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if options.Sort[0].Column != "d.name" {
|
||||||
|
t.Errorf("Expected first sort column 'd.name', got %q", options.Sort[0].Column)
|
||||||
|
}
|
||||||
|
if options.Sort[1].Column != "employees.id" {
|
||||||
|
t.Errorf("Expected second sort column 'employees.id', got %q", options.Sort[1].Column)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Custom JOIN with filter on joined table",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||||
|
"x-searchop-eq-d.name": "Engineering",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
// Verify join was added
|
||||||
|
if len(options.CustomSQLJoin) != 1 {
|
||||||
|
t.Error("Expected 1 custom SQL join")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Verify alias was extracted
|
||||||
|
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
|
||||||
|
t.Error("Expected join alias 'd' to be extracted")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Verify filter was parsed
|
||||||
|
if len(options.Filters) != 1 {
|
||||||
|
t.Errorf("Expected 1 filter, got %d", len(options.Filters))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if options.Filters[0].Column != "d.name" {
|
||||||
|
t.Errorf("Expected filter column 'd.name', got %q", options.Filters[0].Column)
|
||||||
|
}
|
||||||
|
if options.Filters[0].Operator != "eq" {
|
||||||
|
t.Errorf("Expected filter operator 'eq', got %q", options.Filters[0].Operator)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -395,6 +552,55 @@ func TestHeadersAndQueryParamsCombined(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCustomJoinAliasExtraction tests the extractJoinAlias helper function
|
||||||
|
func TestCustomJoinAliasExtraction(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
join string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LEFT JOIN with alias",
|
||||||
|
join: "LEFT JOIN departments d ON d.id = employees.department_id",
|
||||||
|
expected: "d",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "INNER JOIN with AS keyword",
|
||||||
|
join: "INNER JOIN users AS u ON u.id = posts.user_id",
|
||||||
|
expected: "u",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple JOIN with alias",
|
||||||
|
join: "JOIN roles r ON r.id = user_roles.role_id",
|
||||||
|
expected: "r",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JOIN without alias (just table name)",
|
||||||
|
join: "JOIN departments ON departments.id = employees.dept_id",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RIGHT JOIN with alias",
|
||||||
|
join: "RIGHT JOIN orders o ON o.customer_id = customers.id",
|
||||||
|
expected: "o",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "FULL OUTER JOIN with AS",
|
||||||
|
join: "FULL OUTER JOIN products AS p ON p.id = order_items.product_id",
|
||||||
|
expected: "p",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractJoinAlias(tt.join)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractJoinAlias(%q) = %q, want %q", tt.join, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to check if a string contains a substring
|
// Helper function to check if a string contains a substring
|
||||||
func contains(s, substr string) bool {
|
func contains(s, substr string) bool {
|
||||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
||||||
|
|||||||
391
pkg/restheadspec/recursive_preload_test.go
Normal file
391
pkg/restheadspec/recursive_preload_test.go
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
//go:build !integration
|
||||||
|
// +build !integration
|
||||||
|
|
||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestRecursivePreloadClearsWhereClause tests that recursive preloads
|
||||||
|
// correctly clear the WHERE clause from the parent level to allow
|
||||||
|
// Bun to use foreign key relationships for loading children
|
||||||
|
func TestRecursivePreloadClearsWhereClause(t *testing.T) {
|
||||||
|
// Create a mock handler
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Create a preload option with a WHERE clause that filters root items
|
||||||
|
// This simulates the xfiles use case where the first level has a filter
|
||||||
|
// like "rid_parentmastertaskitem is null" to get root items
|
||||||
|
preload := common.PreloadOption{
|
||||||
|
Relation: "MastertaskItems",
|
||||||
|
Recursive: true,
|
||||||
|
RelatedKey: "rid_parentmastertaskitem",
|
||||||
|
Where: "rid_parentmastertaskitem is null",
|
||||||
|
Filters: []common.FilterOption{
|
||||||
|
{
|
||||||
|
Column: "rid_parentmastertaskitem",
|
||||||
|
Operator: "is null",
|
||||||
|
Value: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a mock query that tracks operations
|
||||||
|
mockQuery := &mockSelectQuery{
|
||||||
|
operations: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the recursive preload at depth 0
|
||||||
|
// This should:
|
||||||
|
// 1. Apply the initial preload with the WHERE clause
|
||||||
|
// 2. Create a recursive preload without the WHERE clause
|
||||||
|
allPreloads := []common.PreloadOption{preload}
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||||
|
|
||||||
|
// Verify the mock query received the operations
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
// Check that we have at least 2 PreloadRelation calls:
|
||||||
|
// 1. The initial "MastertaskItems" with WHERE clause
|
||||||
|
// 2. The recursive "MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" without WHERE clause
|
||||||
|
preloadCount := 0
|
||||||
|
recursivePreloadFound := false
|
||||||
|
whereAppliedToRecursive := false
|
||||||
|
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MastertaskItems" {
|
||||||
|
preloadCount++
|
||||||
|
}
|
||||||
|
if op == "PreloadRelation:MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" {
|
||||||
|
recursivePreloadFound = true
|
||||||
|
}
|
||||||
|
// Check if WHERE was applied to the recursive preload (it shouldn't be)
|
||||||
|
if op == "Where:rid_parentmastertaskitem is null" && recursivePreloadFound {
|
||||||
|
whereAppliedToRecursive = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if preloadCount < 1 {
|
||||||
|
t.Errorf("Expected at least 1 PreloadRelation call, got %d", preloadCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !recursivePreloadFound {
|
||||||
|
t.Errorf("Expected recursive preload 'MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
|
||||||
|
}
|
||||||
|
|
||||||
|
if whereAppliedToRecursive {
|
||||||
|
t.Error("WHERE clause should not be applied to recursive preload levels")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRecursivePreloadWithChildRelations tests that child relations
|
||||||
|
// (like DEF in MAL.DEF) are properly extended to recursive levels
|
||||||
|
func TestRecursivePreloadWithChildRelations(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Create the main recursive preload
|
||||||
|
recursivePreload := common.PreloadOption{
|
||||||
|
Relation: "MAL",
|
||||||
|
Recursive: true,
|
||||||
|
RelatedKey: "rid_parentmastertaskitem",
|
||||||
|
Where: "rid_parentmastertaskitem is null",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a child relation that should be extended
|
||||||
|
childPreload := common.PreloadOption{
|
||||||
|
Relation: "MAL.DEF",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockQuery := &mockSelectQuery{
|
||||||
|
operations: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
allPreloads := []common.PreloadOption{recursivePreload, childPreload}
|
||||||
|
|
||||||
|
// Apply both preloads - the child preload should be extended when the recursive one processes
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, allPreloads, nil, 0)
|
||||||
|
|
||||||
|
// Also need to apply the child preload separately (as would happen in normal flow)
|
||||||
|
result = handler.applyPreloadWithRecursion(result, childPreload, allPreloads, nil, 0)
|
||||||
|
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
// Check that the child relation was extended to recursive levels
|
||||||
|
// We should see:
|
||||||
|
// - MAL (with WHERE)
|
||||||
|
// - MAL.DEF
|
||||||
|
// - MAL.MAL_RID_PARENTMASTERTASKITEM (without WHERE)
|
||||||
|
// - MAL.MAL_RID_PARENTMASTERTASKITEM.DEF (extended by recursive logic)
|
||||||
|
foundMALDEF := false
|
||||||
|
foundRecursiveMAL := false
|
||||||
|
foundMALMALDEF := false
|
||||||
|
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MAL.DEF" {
|
||||||
|
foundMALDEF = true
|
||||||
|
}
|
||||||
|
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundRecursiveMAL = true
|
||||||
|
}
|
||||||
|
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
|
||||||
|
foundMALMALDEF = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundMALDEF {
|
||||||
|
t.Errorf("Expected child preload 'MAL.DEF' to be applied. Operations: %v", mock.operations)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundRecursiveMAL {
|
||||||
|
t.Errorf("Expected recursive preload 'MAL.MAL_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundMALMALDEF {
|
||||||
|
t.Errorf("Expected child preload to be extended to 'MAL.MAL_RID_PARENTMASTERTASKITEM.DEF' at recursive level. Operations: %v", mock.operations)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRecursivePreloadGeneratesCorrectRelationName tests that the recursive
|
||||||
|
// preload generates the correct FK-based relation name using RelatedKey
|
||||||
|
func TestRecursivePreloadGeneratesCorrectRelationName(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Test case 1: With RelatedKey - should generate FK-based name
|
||||||
|
t.Run("WithRelatedKey", func(t *testing.T) {
|
||||||
|
preload := common.PreloadOption{
|
||||||
|
Relation: "MAL",
|
||||||
|
Recursive: true,
|
||||||
|
RelatedKey: "rid_parentmastertaskitem",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
allPreloads := []common.PreloadOption{preload}
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||||
|
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
// Should generate MAL.MAL_RID_PARENTMASTERTASKITEM
|
||||||
|
foundCorrectRelation := false
|
||||||
|
foundIncorrectRelation := false
|
||||||
|
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundCorrectRelation = true
|
||||||
|
}
|
||||||
|
if op == "PreloadRelation:MAL.MAL" {
|
||||||
|
foundIncorrectRelation = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundCorrectRelation {
|
||||||
|
t.Errorf("Expected 'MAL.MAL_RID_PARENTMASTERTASKITEM' relation, operations: %v", mock.operations)
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundIncorrectRelation {
|
||||||
|
t.Error("Should NOT generate 'MAL.MAL' relation when RelatedKey is specified")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 2: Without RelatedKey - should fallback to old behavior
|
||||||
|
t.Run("WithoutRelatedKey", func(t *testing.T) {
|
||||||
|
preload := common.PreloadOption{
|
||||||
|
Relation: "MAL",
|
||||||
|
Recursive: true,
|
||||||
|
// No RelatedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
allPreloads := []common.PreloadOption{preload}
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||||
|
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
// Should fallback to MAL.MAL
|
||||||
|
foundFallback := false
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MAL.MAL" {
|
||||||
|
foundFallback = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundFallback {
|
||||||
|
t.Errorf("Expected fallback 'MAL.MAL' relation when no RelatedKey, operations: %v", mock.operations)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 3: Depth limit of 8
|
||||||
|
t.Run("DepthLimit", func(t *testing.T) {
|
||||||
|
preload := common.PreloadOption{
|
||||||
|
Relation: "MAL",
|
||||||
|
Recursive: true,
|
||||||
|
RelatedKey: "rid_parentmastertaskitem",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
allPreloads := []common.PreloadOption{preload}
|
||||||
|
|
||||||
|
// Start at depth 7 - should create one more level
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
foundDepth8 := false
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundDepth8 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundDepth8 {
|
||||||
|
t.Error("Expected to create recursive level at depth 8")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start at depth 8 - should NOT create another level
|
||||||
|
mockQuery2 := &mockSelectQuery{operations: []string{}}
|
||||||
|
result2 := handler.applyPreloadWithRecursion(mockQuery2, preload, allPreloads, nil, 8)
|
||||||
|
mock2 := result2.(*mockSelectQuery)
|
||||||
|
|
||||||
|
foundDepth9 := false
|
||||||
|
for _, op := range mock2.operations {
|
||||||
|
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundDepth9 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundDepth9 {
|
||||||
|
t.Error("Should NOT create recursive level beyond depth 8")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockSelectQuery implements common.SelectQuery for testing
|
||||||
|
type mockSelectQuery struct {
|
||||||
|
operations []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Model")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Table:"+table)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||||
|
for _, col := range columns {
|
||||||
|
m.operations = append(m.operations, "Column:"+col)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "ColumnExpr:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Where:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "WhereOr:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "WhereIn:"+column)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Order:"+order)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "OrderExpr:"+order)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Limit")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Offset")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Join:"+join)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "LeftJoin:"+join)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Group")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Having:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Preload:"+relation)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "PreloadRelation:"+relation)
|
||||||
|
// Apply the preload modifiers
|
||||||
|
for _, fn := range apply {
|
||||||
|
fn(m)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "JoinRelation:"+relation)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||||
|
m.operations = append(m.operations, "Scan")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
|
||||||
|
m.operations = append(m.operations, "ScanModel")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
|
||||||
|
m.operations = append(m.operations, "Count")
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||||
|
m.operations = append(m.operations, "Exists")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) GetModel() interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -32,6 +32,7 @@
|
|||||||
// - X-Clean-JSON: Boolean to remove null/empty fields
|
// - X-Clean-JSON: Boolean to remove null/empty fields
|
||||||
// - X-Custom-SQL-Where: Custom SQL WHERE clause (AND)
|
// - X-Custom-SQL-Where: Custom SQL WHERE clause (AND)
|
||||||
// - X-Custom-SQL-Or: Custom SQL WHERE clause (OR)
|
// - X-Custom-SQL-Or: Custom SQL WHERE clause (OR)
|
||||||
|
// - X-Custom-SQL-Join: Custom SQL JOIN clauses (pipe-separated for multiple)
|
||||||
//
|
//
|
||||||
// # Usage Example
|
// # Usage Example
|
||||||
//
|
//
|
||||||
@@ -103,8 +104,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
|||||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||||
})
|
})
|
||||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||||
@@ -161,7 +163,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
|||||||
// Set CORS headers
|
// Set CORS headers
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
vars := make(map[string]string)
|
vars := make(map[string]string)
|
||||||
vars["schema"] = schema
|
vars["schema"] = schema
|
||||||
@@ -169,7 +172,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
|||||||
if idParam != "" {
|
if idParam != "" {
|
||||||
vars["id"] = mux.Vars(r)[idParam]
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, vars)
|
handler.Handle(respAdapter, reqAdapter, vars)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -180,7 +183,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
|||||||
// Set CORS headers
|
// Set CORS headers
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
vars := make(map[string]string)
|
vars := make(map[string]string)
|
||||||
vars["schema"] = schema
|
vars["schema"] = schema
|
||||||
@@ -188,7 +192,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
|||||||
if idParam != "" {
|
if idParam != "" {
|
||||||
vars["id"] = mux.Vars(r)[idParam]
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -200,13 +204,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
|
|||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
corsConfig.AllowedMethods = allowedMethods
|
corsConfig.AllowedMethods = allowedMethods
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
|
||||||
// Return metadata in the OPTIONS response body
|
// Return metadata in the OPTIONS response body
|
||||||
vars := make(map[string]string)
|
vars := make(map[string]string)
|
||||||
vars["schema"] = schema
|
vars["schema"] = schema
|
||||||
vars["entity"] = entity
|
vars["entity"] = entity
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -270,9 +275,14 @@ func ExampleWithBun(bunDB *bun.DB) {
|
|||||||
SetupMuxRoutes(muxRouter, handler, nil)
|
SetupMuxRoutes(muxRouter, handler, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BunRouterHandler is an interface that both bunrouter.Router and bunrouter.Group implement
|
||||||
|
type BunRouterHandler interface {
|
||||||
|
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||||
|
}
|
||||||
|
|
||||||
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
||||||
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
// Accepts bunrouter.Router or bunrouter.Group
|
||||||
r := bunRouter.GetBunRouter()
|
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||||
|
|
||||||
// CORS config
|
// CORS config
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
@@ -280,15 +290,8 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// Add global /openapi route
|
// Add global /openapi route
|
||||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -312,24 +315,26 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// GET and POST for /{schema}/{entity}
|
// GET and POST for /{schema}/{entity}
|
||||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -337,65 +342,70 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
||||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
"id": req.Param("id"),
|
"id": req.Param("id"),
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
"id": req.Param("id"),
|
"id": req.Param("id"),
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
"id": req.Param("id"),
|
"id": req.Param("id"),
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
"id": req.Param("id"),
|
"id": req.Param("id"),
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
"id": req.Param("id"),
|
"id": req.Param("id"),
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -403,12 +413,13 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// Metadata endpoint
|
// Metadata endpoint
|
||||||
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -416,14 +427,15 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// OPTIONS route without ID (returns metadata)
|
// OPTIONS route without ID (returns metadata)
|
||||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
optionsCorsConfig := corsConfig
|
optionsCorsConfig := corsConfig
|
||||||
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -431,14 +443,15 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
// OPTIONS route with ID (returns metadata)
|
// OPTIONS route with ID (returns metadata)
|
||||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
optionsCorsConfig := corsConfig
|
optionsCorsConfig := corsConfig
|
||||||
optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}
|
optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}
|
||||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": currentSchema,
|
"schema": currentSchema,
|
||||||
"entity": currentEntity,
|
"entity": currentEntity,
|
||||||
}
|
}
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -450,17 +463,34 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
|||||||
// Create handler
|
// Create handler
|
||||||
handler := NewHandlerWithBun(bunDB)
|
handler := NewHandlerWithBun(bunDB)
|
||||||
|
|
||||||
// Create BunRouter adapter
|
// Create bunrouter
|
||||||
routerAdapter := NewStandardBunRouter()
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
SetupBunRouterRoutes(routerAdapter, handler)
|
SetupBunRouterRoutes(bunRouter, handler)
|
||||||
|
|
||||||
// Get the underlying router for server setup
|
|
||||||
r := routerAdapter.GetBunRouter()
|
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
if err := http.ListenAndServe(":8080", r); err != nil {
|
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||||
|
logger.Error("Server failed to start: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleBunRouterWithGroup shows how to use SetupBunRouterRoutes with a bunrouter.Group
|
||||||
|
func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||||
|
// Create handler with Bun adapter
|
||||||
|
handler := NewHandlerWithBun(bunDB)
|
||||||
|
|
||||||
|
// Create bunrouter
|
||||||
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
|
// Create a route group with a prefix
|
||||||
|
apiGroup := bunRouter.NewGroup("/api")
|
||||||
|
|
||||||
|
// Setup RestHeadSpec routes on the group - routes will be under /api
|
||||||
|
SetupBunRouterRoutes(apiGroup, handler)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||||
logger.Error("Server failed to start: %v", err)
|
logger.Error("Server failed to start: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseModelName(t *testing.T) {
|
func TestParseModelName(t *testing.T) {
|
||||||
@@ -112,3 +114,88 @@ func TestNewStandardBunRouter(t *testing.T) {
|
|||||||
t.Error("Expected router to be created, got nil")
|
t.Error("Expected router to be created, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractTagValue(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tag string
|
||||||
|
key string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Extract existing key",
|
||||||
|
tag: "json:name;validate:required",
|
||||||
|
key: "json",
|
||||||
|
expected: "name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract key with spaces",
|
||||||
|
tag: "json:name ; validate:required",
|
||||||
|
key: "validate",
|
||||||
|
expected: "required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract key at end",
|
||||||
|
tag: "json:name;validate:required;db:column_name",
|
||||||
|
key: "db",
|
||||||
|
expected: "column_name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Extract key at beginning",
|
||||||
|
tag: "primary:true;json:id;db:user_id",
|
||||||
|
key: "primary",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Key not found",
|
||||||
|
tag: "json:name;validate:required",
|
||||||
|
key: "db",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty tag",
|
||||||
|
tag: "",
|
||||||
|
key: "json",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single key-value pair",
|
||||||
|
tag: "json:name",
|
||||||
|
key: "json",
|
||||||
|
expected: "name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Key with empty value",
|
||||||
|
tag: "json:;validate:required",
|
||||||
|
key: "json",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Key with complex value",
|
||||||
|
tag: "json:user_name,omitempty;validate:required,min=3",
|
||||||
|
key: "json",
|
||||||
|
expected: "user_name,omitempty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple semicolons",
|
||||||
|
tag: "json:name;;validate:required",
|
||||||
|
key: "validate",
|
||||||
|
expected: "required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "BUN Tag",
|
||||||
|
tag: "rel:has-many,join:rid_hub=rid_hub_child",
|
||||||
|
key: "join",
|
||||||
|
expected: "rid_hub=rid_hub_child",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := common.ExtractTagValue(tt.tag, tt.key)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
527
pkg/restheadspec/xfiles_integration_test.go
Normal file
527
pkg/restheadspec/xfiles_integration_test.go
Normal file
@@ -0,0 +1,527 @@
|
|||||||
|
//go:build integration
|
||||||
|
// +build integration
|
||||||
|
|
||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockSelectQuery implements common.SelectQuery for testing (integration version)
|
||||||
|
type mockSelectQuery struct {
|
||||||
|
operations []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Model")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Table:"+table)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||||
|
for _, col := range columns {
|
||||||
|
m.operations = append(m.operations, "Column:"+col)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "ColumnExpr:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Where:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "WhereOr:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "WhereIn:"+column)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Order:"+order)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "OrderExpr:"+order)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Limit")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Offset")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Join:"+join)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "LeftJoin:"+join)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Group")
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Having:"+query)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "Preload:"+relation)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "PreloadRelation:"+relation)
|
||||||
|
// Apply the preload modifiers
|
||||||
|
for _, fn := range apply {
|
||||||
|
fn(m)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
m.operations = append(m.operations, "JoinRelation:"+relation)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||||
|
m.operations = append(m.operations, "Scan")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
|
||||||
|
m.operations = append(m.operations, "ScanModel")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
|
||||||
|
m.operations = append(m.operations, "Count")
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||||
|
m.operations = append(m.operations, "Exists")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSelectQuery) GetModel() interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestXFilesRecursivePreload is an integration test that validates the XFiles
|
||||||
|
// recursive preload functionality using real test data files.
|
||||||
|
//
|
||||||
|
// This test ensures:
|
||||||
|
// 1. XFiles request JSON is correctly parsed into PreloadOptions
|
||||||
|
// 2. Recursive preload generates correct FK-based relation names (MAL_RID_PARENTMASTERTASKITEM)
|
||||||
|
// 3. Parent WHERE clauses don't leak to child levels
|
||||||
|
// 4. Child relations (like DEF) are extended to all recursive levels
|
||||||
|
// 5. Hierarchical data structure matches expected output
|
||||||
|
func TestXFilesRecursivePreload(t *testing.T) {
|
||||||
|
// Load the XFiles request configuration
|
||||||
|
requestPath := filepath.Join("..", "..", "tests", "data", "xfiles.request.json")
|
||||||
|
requestData, err := os.ReadFile(requestPath)
|
||||||
|
require.NoError(t, err, "Failed to read xfiles.request.json")
|
||||||
|
|
||||||
|
var xfileConfig XFiles
|
||||||
|
err = json.Unmarshal(requestData, &xfileConfig)
|
||||||
|
require.NoError(t, err, "Failed to parse xfiles.request.json")
|
||||||
|
|
||||||
|
// Create handler and parse XFiles into PreloadOptions
|
||||||
|
handler := &Handler{}
|
||||||
|
options := &ExtendedRequestOptions{
|
||||||
|
RequestOptions: common.RequestOptions{
|
||||||
|
Preload: []common.PreloadOption{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the XFiles configuration - start with the root table
|
||||||
|
handler.processXFilesRelations(&xfileConfig, options, "")
|
||||||
|
|
||||||
|
// Verify that preload options were created
|
||||||
|
require.NotEmpty(t, options.Preload, "Expected preload options to be created")
|
||||||
|
|
||||||
|
// Test 1: Verify mastertaskitem preload is marked as recursive with correct RelatedKey
|
||||||
|
t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) {
|
||||||
|
// Find the mastertaskitem preload - it should be marked as recursive
|
||||||
|
var recursivePreload *common.PreloadOption
|
||||||
|
for i := range options.Preload {
|
||||||
|
preload := &options.Preload[i]
|
||||||
|
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||||
|
recursivePreload = preload
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||||
|
|
||||||
|
// RelatedKey should be the parent relationship key (MTL -> MAL)
|
||||||
|
assert.Equal(t, "rid_mastertask", recursivePreload.RelatedKey,
|
||||||
|
"Recursive preload should preserve original RelatedKey for parent relationship")
|
||||||
|
|
||||||
|
// RecursiveChildKey should be set from the recursive child config
|
||||||
|
assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RecursiveChildKey,
|
||||||
|
"Recursive preload should have RecursiveChildKey set from recursive child config")
|
||||||
|
|
||||||
|
assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 2: Verify mastertaskitem has WHERE clause for filtering root items
|
||||||
|
t.Run("RootLevelHasWhereClause", func(t *testing.T) {
|
||||||
|
var rootPreload *common.PreloadOption
|
||||||
|
for i := range options.Preload {
|
||||||
|
preload := &options.Preload[i]
|
||||||
|
if preload.Relation == "MTL.MAL" {
|
||||||
|
rootPreload = preload
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, rootPreload, "Expected to find mastertaskitem preload")
|
||||||
|
assert.NotEmpty(t, rootPreload.Where, "Mastertaskitem should have WHERE clause")
|
||||||
|
// The WHERE clause should filter for root items (rid_parentmastertaskitem is null)
|
||||||
|
assert.True(t, rootPreload.Recursive, "Mastertaskitem preload should be marked as recursive")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 3: Verify actiondefinition relation exists for mastertaskitem
|
||||||
|
t.Run("DEFRelationExists", func(t *testing.T) {
|
||||||
|
var defPreload *common.PreloadOption
|
||||||
|
for i := range options.Preload {
|
||||||
|
preload := &options.Preload[i]
|
||||||
|
if preload.Relation == "MTL.MAL.DEF" {
|
||||||
|
defPreload = preload
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, defPreload, "Expected to find actiondefinition preload for mastertaskitem")
|
||||||
|
assert.Equal(t, "rid_actiondefinition", defPreload.ForeignKey,
|
||||||
|
"actiondefinition preload should have ForeignKey set")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 4: Verify relation name generation with mock query
|
||||||
|
t.Run("RelationNameGeneration", func(t *testing.T) {
|
||||||
|
// Find the mastertaskitem preload - it should be marked as recursive
|
||||||
|
var recursivePreload common.PreloadOption
|
||||||
|
found := false
|
||||||
|
for _, preload := range options.Preload {
|
||||||
|
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||||
|
recursivePreload = preload
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||||
|
|
||||||
|
// Create mock query to track operations
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
|
||||||
|
// Apply the recursive preload
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
// Verify the correct FK-based relation name was generated
|
||||||
|
foundCorrectRelation := false
|
||||||
|
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
// Should generate: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM
|
||||||
|
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundCorrectRelation = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, foundCorrectRelation,
|
||||||
|
"Expected FK-based relation name 'MTL.MAL.MAL_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v",
|
||||||
|
mock.operations)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 5: Verify WHERE clause is cleared for recursive levels
|
||||||
|
t.Run("WhereClauseClearedForChildren", func(t *testing.T) {
|
||||||
|
// Find the mastertaskitem preload - it should be marked as recursive
|
||||||
|
var recursivePreload common.PreloadOption
|
||||||
|
found := false
|
||||||
|
for _, preload := range options.Preload {
|
||||||
|
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||||
|
recursivePreload = preload
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||||
|
|
||||||
|
// The root level has a WHERE clause (rid_parentmastertaskitem is null)
|
||||||
|
// But when we apply recursion, it should be cleared
|
||||||
|
assert.NotEmpty(t, recursivePreload.Where, "Root preload should have WHERE clause")
|
||||||
|
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
// After the first level, WHERE clauses should not be reapplied
|
||||||
|
// We check that the recursive relation was created (which means WHERE was cleared internally)
|
||||||
|
foundRecursiveRelation := false
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundRecursiveRelation = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, foundRecursiveRelation,
|
||||||
|
"Recursive relation should be created (WHERE clause should be cleared internally)")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 6: Verify child relations are extended to recursive levels
|
||||||
|
t.Run("ChildRelationsExtended", func(t *testing.T) {
|
||||||
|
// Find the mastertaskitem preload - it should be marked as recursive
|
||||||
|
var recursivePreload common.PreloadOption
|
||||||
|
foundRecursive := false
|
||||||
|
|
||||||
|
for _, preload := range options.Preload {
|
||||||
|
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||||
|
recursivePreload = preload
|
||||||
|
foundRecursive = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||||
|
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
// actiondefinition should be extended to the recursive level
|
||||||
|
// Expected: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF
|
||||||
|
foundExtendedDEF := false
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
|
||||||
|
foundExtendedDEF = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, foundExtendedDEF,
|
||||||
|
"Expected actiondefinition relation to be extended to recursive level. Operations: %v",
|
||||||
|
mock.operations)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestXFilesRecursivePreloadDepth tests that recursive preloads respect the depth limit of 8
|
||||||
|
func TestXFilesRecursivePreloadDepth(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
preload := common.PreloadOption{
|
||||||
|
Relation: "MAL",
|
||||||
|
Recursive: true,
|
||||||
|
RelatedKey: "rid_parentmastertaskitem",
|
||||||
|
}
|
||||||
|
|
||||||
|
allPreloads := []common.PreloadOption{preload}
|
||||||
|
|
||||||
|
t.Run("Depth7CreatesLevel8", func(t *testing.T) {
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
foundDepth8 := false
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundDepth8 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, foundDepth8, "Should create level 8 when starting at depth 7")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Depth8DoesNotCreateLevel9", func(t *testing.T) {
|
||||||
|
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||||
|
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 8)
|
||||||
|
mock := result.(*mockSelectQuery)
|
||||||
|
|
||||||
|
foundDepth9 := false
|
||||||
|
for _, op := range mock.operations {
|
||||||
|
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||||
|
foundDepth9 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, foundDepth9, "Should NOT create level 9 (depth limit is 8)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestXFilesResponseStructure validates the actual structure of the response
|
||||||
|
// This test can be expanded when we have a full database integration test environment
|
||||||
|
func TestXFilesResponseStructure(t *testing.T) {
|
||||||
|
// Load the expected correct response
|
||||||
|
correctResponsePath := filepath.Join("..", "..", "tests", "data", "xfiles.response.correct.json")
|
||||||
|
correctData, err := os.ReadFile(correctResponsePath)
|
||||||
|
require.NoError(t, err, "Failed to read xfiles.response.correct.json")
|
||||||
|
|
||||||
|
var correctResponse []map[string]interface{}
|
||||||
|
err = json.Unmarshal(correctData, &correctResponse)
|
||||||
|
require.NoError(t, err, "Failed to parse xfiles.response.correct.json")
|
||||||
|
|
||||||
|
// Test 1: Verify root level has exactly 1 masterprocess
|
||||||
|
t.Run("RootLevelHasOneItem", func(t *testing.T) {
|
||||||
|
assert.Len(t, correctResponse, 1, "Root level should have exactly 1 masterprocess record")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 2: Verify the root item has MTL relation
|
||||||
|
t.Run("RootHasMTLRelation", func(t *testing.T) {
|
||||||
|
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||||
|
rootItem := correctResponse[0]
|
||||||
|
|
||||||
|
mtl, exists := rootItem["MTL"]
|
||||||
|
assert.True(t, exists, "Root item should have MTL relation")
|
||||||
|
assert.NotNil(t, mtl, "MTL relation should not be null")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 3: Verify MTL has MAL items
|
||||||
|
t.Run("MTLHasMALItems", func(t *testing.T) {
|
||||||
|
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||||
|
rootItem := correctResponse[0]
|
||||||
|
|
||||||
|
mtl, ok := rootItem["MTL"].([]interface{})
|
||||||
|
require.True(t, ok, "MTL should be an array")
|
||||||
|
require.NotEmpty(t, mtl, "MTL should have items")
|
||||||
|
|
||||||
|
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||||
|
require.True(t, ok, "MTL item should be a map")
|
||||||
|
|
||||||
|
mal, exists := firstMTL["MAL"]
|
||||||
|
assert.True(t, exists, "MTL item should have MAL relation")
|
||||||
|
assert.NotNil(t, mal, "MAL relation should not be null")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 4: Verify MAL items have MAL_RID_PARENTMASTERTASKITEM relation (recursive)
|
||||||
|
t.Run("MALHasRecursiveRelation", func(t *testing.T) {
|
||||||
|
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||||
|
rootItem := correctResponse[0]
|
||||||
|
|
||||||
|
mtl, ok := rootItem["MTL"].([]interface{})
|
||||||
|
require.True(t, ok, "MTL should be an array")
|
||||||
|
require.NotEmpty(t, mtl, "MTL should have items")
|
||||||
|
|
||||||
|
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||||
|
require.True(t, ok, "MTL item should be a map")
|
||||||
|
|
||||||
|
mal, ok := firstMTL["MAL"].([]interface{})
|
||||||
|
require.True(t, ok, "MAL should be an array")
|
||||||
|
require.NotEmpty(t, mal, "MAL should have items")
|
||||||
|
|
||||||
|
firstMAL, ok := mal[0].(map[string]interface{})
|
||||||
|
require.True(t, ok, "MAL item should be a map")
|
||||||
|
|
||||||
|
// The key assertion: check for FK-based relation name
|
||||||
|
recursiveRelation, exists := firstMAL["MAL_RID_PARENTMASTERTASKITEM"]
|
||||||
|
assert.True(t, exists,
|
||||||
|
"MAL item should have MAL_RID_PARENTMASTERTASKITEM relation (FK-based name)")
|
||||||
|
|
||||||
|
// It can be null or an array, depending on whether this item has children
|
||||||
|
if recursiveRelation != nil {
|
||||||
|
_, isArray := recursiveRelation.([]interface{})
|
||||||
|
assert.True(t, isArray,
|
||||||
|
"MAL_RID_PARENTMASTERTASKITEM should be an array when not null")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 5: Verify "Receive COB Document for" appears as a child, not at root
|
||||||
|
t.Run("ChildItemsAreNested", func(t *testing.T) {
|
||||||
|
// This test verifies that "Receive COB Document for" doesn't appear
|
||||||
|
// multiple times at the wrong level, but is properly nested
|
||||||
|
|
||||||
|
// Count how many times we find this description at the MAL level (should be 0 or 1)
|
||||||
|
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||||
|
rootItem := correctResponse[0]
|
||||||
|
|
||||||
|
mtl, ok := rootItem["MTL"].([]interface{})
|
||||||
|
require.True(t, ok, "MTL should be an array")
|
||||||
|
require.NotEmpty(t, mtl, "MTL should have items")
|
||||||
|
|
||||||
|
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||||
|
require.True(t, ok, "MTL item should be a map")
|
||||||
|
|
||||||
|
mal, ok := firstMTL["MAL"].([]interface{})
|
||||||
|
require.True(t, ok, "MAL should be an array")
|
||||||
|
|
||||||
|
// Count root-level MAL items (before the fix, there were 12; should be 1)
|
||||||
|
assert.Len(t, mal, 1,
|
||||||
|
"MAL should have exactly 1 root-level item (before fix: 12 duplicates)")
|
||||||
|
|
||||||
|
// Verify the root item has a description
|
||||||
|
firstMAL, ok := mal[0].(map[string]interface{})
|
||||||
|
require.True(t, ok, "MAL item should be a map")
|
||||||
|
|
||||||
|
description, exists := firstMAL["description"]
|
||||||
|
assert.True(t, exists, "MAL item should have a description")
|
||||||
|
assert.Equal(t, "Capture COB Information", description,
|
||||||
|
"Root MAL item should be 'Capture COB Information'")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 6: Verify DEF relation exists at MAL level
|
||||||
|
t.Run("DEFRelationExists", func(t *testing.T) {
|
||||||
|
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||||
|
rootItem := correctResponse[0]
|
||||||
|
|
||||||
|
mtl, ok := rootItem["MTL"].([]interface{})
|
||||||
|
require.True(t, ok, "MTL should be an array")
|
||||||
|
require.NotEmpty(t, mtl, "MTL should have items")
|
||||||
|
|
||||||
|
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||||
|
require.True(t, ok, "MTL item should be a map")
|
||||||
|
|
||||||
|
mal, ok := firstMTL["MAL"].([]interface{})
|
||||||
|
require.True(t, ok, "MAL should be an array")
|
||||||
|
require.NotEmpty(t, mal, "MAL should have items")
|
||||||
|
|
||||||
|
firstMAL, ok := mal[0].(map[string]interface{})
|
||||||
|
require.True(t, ok, "MAL item should be a map")
|
||||||
|
|
||||||
|
// Verify DEF relation exists (child relation extension)
|
||||||
|
def, exists := firstMAL["DEF"]
|
||||||
|
assert.True(t, exists, "MAL item should have DEF relation")
|
||||||
|
|
||||||
|
// DEF can be null or an object
|
||||||
|
if def != nil {
|
||||||
|
_, isMap := def.(map[string]interface{})
|
||||||
|
assert.True(t, isMap, "DEF should be an object when not null")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
527
pkg/security/OAUTH2.md
Normal file
527
pkg/security/OAUTH2.md
Normal file
@@ -0,0 +1,527 @@
|
|||||||
|
# OAuth2 Authentication Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The security package provides OAuth2 authentication support for any OAuth2-compliant provider including Google, GitHub, Microsoft, Facebook, and custom providers.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Universal OAuth2 Support**: Works with any OAuth2 provider
|
||||||
|
- **Pre-configured Providers**: Google, GitHub, Microsoft, Facebook
|
||||||
|
- **Multi-Provider Support**: Use all OAuth2 providers simultaneously
|
||||||
|
- **Custom Providers**: Easy configuration for any OAuth2 service
|
||||||
|
- **Session Management**: Database-backed session storage
|
||||||
|
- **Token Refresh**: Automatic token refresh support
|
||||||
|
- **State Validation**: Built-in CSRF protection
|
||||||
|
- **User Auto-Creation**: Automatically creates users on first login
|
||||||
|
- **Unified Authentication**: OAuth2 and traditional auth share same session storage
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Database Setup
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Run the schema from database_schema.sql
|
||||||
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
username VARCHAR(255) NOT NULL UNIQUE,
|
||||||
|
email VARCHAR(255) NOT NULL UNIQUE,
|
||||||
|
password VARCHAR(255),
|
||||||
|
user_level INTEGER DEFAULT 0,
|
||||||
|
roles VARCHAR(500),
|
||||||
|
is_active BOOLEAN DEFAULT true,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_login_at TIMESTAMP,
|
||||||
|
remote_id VARCHAR(255),
|
||||||
|
auth_provider VARCHAR(50)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||||
|
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
ip_address VARCHAR(45),
|
||||||
|
user_agent TEXT,
|
||||||
|
access_token TEXT,
|
||||||
|
refresh_token TEXT,
|
||||||
|
token_type VARCHAR(50) DEFAULT 'Bearer',
|
||||||
|
auth_provider VARCHAR(50)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- OAuth2 stored procedures (7 functions)
|
||||||
|
-- See database_schema.sql for full implementation
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Google OAuth2
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
|
||||||
|
// Create authenticator
|
||||||
|
oauth2Auth := security.NewGoogleAuthenticator(
|
||||||
|
"your-google-client-id",
|
||||||
|
"your-google-client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Login route - redirects to Google
|
||||||
|
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := oauth2Auth.OAuth2GetAuthURL(state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Callback route - handles Google response
|
||||||
|
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set session cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. GitHub OAuth2
|
||||||
|
|
||||||
|
```go
|
||||||
|
oauth2Auth := security.NewGitHubAuthenticator(
|
||||||
|
"your-github-client-id",
|
||||||
|
"your-github-client-secret",
|
||||||
|
"http://localhost:8080/auth/github/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Same routes pattern as Google
|
||||||
|
router.HandleFunc("/auth/github/login", ...)
|
||||||
|
router.HandleFunc("/auth/github/callback", ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Microsoft OAuth2
|
||||||
|
|
||||||
|
```go
|
||||||
|
oauth2Auth := security.NewMicrosoftAuthenticator(
|
||||||
|
"your-microsoft-client-id",
|
||||||
|
"your-microsoft-client-secret",
|
||||||
|
"http://localhost:8080/auth/microsoft/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Facebook OAuth2
|
||||||
|
|
||||||
|
```go
|
||||||
|
oauth2Auth := security.NewFacebookAuthenticator(
|
||||||
|
"your-facebook-client-id",
|
||||||
|
"your-facebook-client-secret",
|
||||||
|
"http://localhost:8080/auth/facebook/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom OAuth2 Provider
|
||||||
|
|
||||||
|
```go
|
||||||
|
oauth2Auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||||
|
ClientID: "your-client-id",
|
||||||
|
ClientSecret: "your-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://your-provider.com/oauth/authorize",
|
||||||
|
TokenURL: "https://your-provider.com/oauth/token",
|
||||||
|
UserInfoURL: "https://your-provider.com/oauth/userinfo",
|
||||||
|
DB: db,
|
||||||
|
ProviderName: "custom",
|
||||||
|
|
||||||
|
// Optional: Custom user info parser
|
||||||
|
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
|
||||||
|
return &security.UserContext{
|
||||||
|
UserName: userInfo["username"].(string),
|
||||||
|
Email: userInfo["email"].(string),
|
||||||
|
RemoteID: userInfo["id"].(string),
|
||||||
|
UserLevel: 1,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
Claims: userInfo,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Protected Routes
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Create security provider
|
||||||
|
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||||
|
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||||
|
provider, _ := security.NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||||
|
securityList, _ := security.NewSecurityList(provider)
|
||||||
|
|
||||||
|
// Apply middleware to protected routes
|
||||||
|
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||||
|
protectedRouter.Use(security.NewAuthMiddleware(securityList))
|
||||||
|
protectedRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||||
|
|
||||||
|
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := security.GetUserContext(r.Context())
|
||||||
|
json.NewEncoder(w).Encode(userCtx)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Token Refresh
|
||||||
|
|
||||||
|
OAuth2 access tokens expire after a period of time. Use the refresh token to obtain a new access token without requiring the user to log in again.
|
||||||
|
|
||||||
|
```go
|
||||||
|
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Provider string `json:"provider"` // "google", "github", etc.
|
||||||
|
}
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
// Default to google if not specified
|
||||||
|
if req.Provider == "" {
|
||||||
|
req.Provider = "google"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use OAuth2-specific refresh method
|
||||||
|
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set new session cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important Notes:**
|
||||||
|
- The refresh token is returned in the `LoginResponse.RefreshToken` field after successful OAuth2 callback
|
||||||
|
- Store the refresh token securely on the client side
|
||||||
|
- Each provider must be configured with the appropriate scopes to receive a refresh token (e.g., `access_type=offline` for Google)
|
||||||
|
- The `OAuth2RefreshToken` method requires the provider name to identify which OAuth2 provider to use for refreshing
|
||||||
|
|
||||||
|
## Logout
|
||||||
|
|
||||||
|
```go
|
||||||
|
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := security.GetUserContext(r.Context())
|
||||||
|
|
||||||
|
oauth2Auth.Logout(r.Context(), security.LogoutRequest{
|
||||||
|
Token: userCtx.SessionID,
|
||||||
|
UserID: userCtx.UserID,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: -1,
|
||||||
|
})
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-Provider Setup
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Single DatabaseAuthenticator with ALL OAuth2 providers
|
||||||
|
auth := security.NewDatabaseAuthenticator(db).
|
||||||
|
WithOAuth2(security.OAuth2Config{
|
||||||
|
ClientID: "google-client-id",
|
||||||
|
ClientSecret: "google-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||||
|
ProviderName: "google",
|
||||||
|
}).
|
||||||
|
WithOAuth2(security.OAuth2Config{
|
||||||
|
ClientID: "github-client-id",
|
||||||
|
ClientSecret: "github-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||||
|
Scopes: []string{"user:email"},
|
||||||
|
AuthURL: "https://github.com/login/oauth/authorize",
|
||||||
|
TokenURL: "https://github.com/login/oauth/access_token",
|
||||||
|
UserInfoURL: "https://api.github.com/user",
|
||||||
|
ProviderName: "github",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get list of configured providers
|
||||||
|
providers := auth.OAuth2GetProviders() // ["google", "github"]
|
||||||
|
|
||||||
|
// Google routes
|
||||||
|
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google",
|
||||||
|
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||||
|
// ... handle response
|
||||||
|
})
|
||||||
|
|
||||||
|
// GitHub routes
|
||||||
|
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("github", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github",
|
||||||
|
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||||
|
// ... handle response
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use same authenticator for protected routes - works for ALL providers
|
||||||
|
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList, _ := security.NewSecurityList(provider)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### OAuth2Config Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| ClientID | string | OAuth2 client ID from provider |
|
||||||
|
| ClientSecret | string | OAuth2 client secret |
|
||||||
|
| RedirectURL | string | Callback URL registered with provider |
|
||||||
|
| Scopes | []string | OAuth2 scopes to request |
|
||||||
|
| AuthURL | string | Provider's authorization endpoint |
|
||||||
|
| TokenURL | string | Provider's token endpoint |
|
||||||
|
| UserInfoURL | string | Provider's user info endpoint |
|
||||||
|
| DB | *sql.DB | Database connection for sessions |
|
||||||
|
| UserInfoParser | func | Custom parser for user info (optional) |
|
||||||
|
| StateValidator | func | Custom state validator (optional) |
|
||||||
|
| ProviderName | string | Provider name for logging (optional) |
|
||||||
|
|
||||||
|
## User Info Parsing
|
||||||
|
|
||||||
|
The default parser extracts these standard fields:
|
||||||
|
- `sub` → RemoteID
|
||||||
|
- `email` → Email, UserName
|
||||||
|
- `name` → UserName
|
||||||
|
- `login` → UserName (GitHub)
|
||||||
|
|
||||||
|
Custom parser example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
|
||||||
|
// Extract custom fields
|
||||||
|
ctx := &security.UserContext{
|
||||||
|
UserName: userInfo["preferred_username"].(string),
|
||||||
|
Email: userInfo["email"].(string),
|
||||||
|
RemoteID: userInfo["sub"].(string),
|
||||||
|
UserLevel: 1,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
Claims: userInfo, // Store all claims
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add custom roles based on provider data
|
||||||
|
if groups, ok := userInfo["groups"].([]interface{}); ok {
|
||||||
|
for _, g := range groups {
|
||||||
|
ctx.Roles = append(ctx.Roles, g.(string))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Security Best Practices
|
||||||
|
|
||||||
|
1. **Always use HTTPS in production**
|
||||||
|
```go
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Secure: true, // Only send over HTTPS
|
||||||
|
HttpOnly: true, // Prevent XSS access
|
||||||
|
SameSite: http.SameSiteLaxMode, // CSRF protection
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Store secrets securely**
|
||||||
|
```go
|
||||||
|
clientID := os.Getenv("GOOGLE_CLIENT_ID")
|
||||||
|
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Validate redirect URLs**
|
||||||
|
- Only register trusted redirect URLs with OAuth2 providers
|
||||||
|
- Never accept redirect URL from request parameters
|
||||||
|
|
||||||
|
4. **Session expiration**
- OAuth2 sessions automatically expire based on token expiry
- Clean up expired sessions periodically:
```sql
DELETE FROM user_sessions WHERE expires_at < NOW();
```

5. **State parameter**
- Automatically generated with cryptographic randomness
- One-time use and expires after 10 minutes
- Prevents CSRF attacks
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
All database operations use stored procedures for consistency and security:
|
||||||
|
- `resolvespec_oauth_getorcreateuser` - Find or create OAuth2 user
|
||||||
|
- `resolvespec_oauth_createsession` - Create OAuth2 session
|
||||||
|
- `resolvespec_oauth_getsession` - Validate and retrieve session
|
||||||
|
- `resolvespec_oauth_deletesession` - Logout/delete session
|
||||||
|
- `resolvespec_oauth_getrefreshtoken` - Get session by refresh token
|
||||||
|
- `resolvespec_oauth_updaterefreshtoken` - Update tokens after refresh
|
||||||
|
- `resolvespec_oauth_getuser` - Get user data by ID
|
||||||
|
|
||||||
|
## Provider Setup Guides
|
||||||
|
|
||||||
|
### Google
|
||||||
|
|
||||||
|
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||||
|
2. Create a new project or select existing
|
||||||
|
3. Enable Google+ API
|
||||||
|
4. Create OAuth 2.0 credentials
|
||||||
|
5. Add authorized redirect URI: `http://localhost:8080/auth/google/callback`
|
||||||
|
6. Copy Client ID and Client Secret
|
||||||
|
|
||||||
|
### GitHub
|
||||||
|
|
||||||
|
1. Go to [GitHub Developer Settings](https://github.com/settings/developers)
|
||||||
|
2. Click "New OAuth App"
|
||||||
|
3. Set Homepage URL: `http://localhost:8080`
|
||||||
|
4. Set Authorization callback URL: `http://localhost:8080/auth/github/callback`
|
||||||
|
5. Copy Client ID and Client Secret
|
||||||
|
|
||||||
|
### Microsoft
|
||||||
|
|
||||||
|
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||||
|
2. Register new application in Azure AD
|
||||||
|
3. Add redirect URI: `http://localhost:8080/auth/microsoft/callback`
|
||||||
|
4. Create client secret
|
||||||
|
5. Copy Application (client) ID and secret value
|
||||||
|
|
||||||
|
### Facebook
|
||||||
|
|
||||||
|
1. Go to [Facebook Developers](https://developers.facebook.com/)
|
||||||
|
2. Create new app
|
||||||
|
3. Add Facebook Login product
|
||||||
|
4. Set Valid OAuth Redirect URIs: `http://localhost:8080/auth/facebook/callback`
|
||||||
|
5. Copy App ID and App Secret
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### "redirect_uri_mismatch" error
|
||||||
|
- Ensure the redirect URL in code matches exactly with provider configuration
|
||||||
|
- Include protocol (http/https), domain, port, and path
|
||||||
|
|
||||||
|
### "invalid_client" error
|
||||||
|
- Verify Client ID and Client Secret are correct
|
||||||
|
- Check if credentials are for the correct environment (dev/prod)
|
||||||
|
|
||||||
|
### "invalid_grant" error during token exchange
|
||||||
|
- State parameter validation failed
|
||||||
|
- Token might have expired
|
||||||
|
- Check server time synchronization
|
||||||
|
|
||||||
|
### User not created after successful OAuth2 login
|
||||||
|
- Check database constraints (username/email unique)
|
||||||
|
- Verify UserInfoParser is extracting required fields
|
||||||
|
- Check database logs for constraint violations
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestOAuth2Flow(t *testing.T) {
|
||||||
|
// Mock database
|
||||||
|
db, mock, _ := sqlmock.New()
|
||||||
|
|
||||||
|
oauth2Auth := security.NewGoogleAuthenticator(
|
||||||
|
"test-client-id",
|
||||||
|
"test-client-secret",
|
||||||
|
"http://localhost/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
    // Test state generation
    state, err := oauth2Auth.OAuth2GenerateState()
    assert.NoError(t, err)
    assert.NotEmpty(t, state)

    // Test auth URL generation
    authURL, err := oauth2Auth.OAuth2GetAuthURL(state)
    assert.NoError(t, err)
    assert.Contains(t, authURL, "accounts.google.com")
    assert.Contains(t, authURL, state)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### DatabaseAuthenticator with OAuth2
|
||||||
|
|
||||||
|
| Method | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| WithOAuth2(cfg) | Adds OAuth2 provider (can be called multiple times, returns *DatabaseAuthenticator) |
|
||||||
|
| OAuth2GetAuthURL(provider, state) | Returns OAuth2 authorization URL for specified provider |
|
||||||
|
| OAuth2GenerateState() | Generates random state for CSRF protection |
|
||||||
|
| OAuth2HandleCallback(ctx, provider, code, state) | Exchanges code for token and creates session |
|
||||||
|
| OAuth2RefreshToken(ctx, refreshToken, provider) | Refreshes expired access token using refresh token |
|
||||||
|
| OAuth2GetProviders() | Returns list of configured OAuth2 provider names |
|
||||||
|
| Login(ctx, req) | Standard username/password login |
|
||||||
|
| Logout(ctx, req) | Invalidates session (works for both OAuth2 and regular sessions) |
|
||||||
|
| Authenticate(r) | Validates session token from request (works for both OAuth2 and regular sessions) |
|
||||||
|
|
||||||
|
### Pre-configured Constructors
|
||||||
|
|
||||||
|
- `NewGoogleAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||||
|
- `NewGitHubAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||||
|
- `NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||||
|
- `NewFacebookAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||||
|
- `NewMultiProviderAuthenticator(db, configs)` - Multiple providers at once
|
||||||
|
|
||||||
|
All return `*DatabaseAuthenticator` with OAuth2 pre-configured.
|
||||||
|
|
||||||
|
For multiple providers, use `WithOAuth2()` multiple times or `NewMultiProviderAuthenticator()`.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
Complete working examples available in `oauth2_examples.go`:
|
||||||
|
- Basic Google OAuth2
|
||||||
|
- GitHub OAuth2
|
||||||
|
- Custom provider
|
||||||
|
- Multi-provider setup
|
||||||
|
- Token refresh
|
||||||
|
- Logout flow
|
||||||
|
- Complete integration with security middleware
|
||||||
281
pkg/security/OAUTH2_REFRESH_QUICK_REFERENCE.md
Normal file
281
pkg/security/OAUTH2_REFRESH_QUICK_REFERENCE.md
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
# OAuth2 Refresh Token - Quick Reference
|
||||||
|
|
||||||
|
## Quick Setup (3 Steps)
|
||||||
|
|
||||||
|
### 1. Initialize Authenticator
|
||||||
|
```go
|
||||||
|
auth := security.NewGoogleAuthenticator(
|
||||||
|
"client-id",
|
||||||
|
"client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. OAuth2 Login Flow
|
||||||
|
```go
|
||||||
|
// Login - Redirect to Google
|
||||||
|
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Callback - Store tokens
|
||||||
|
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginResp, _ := auth.OAuth2HandleCallback(
|
||||||
|
r.Context(),
|
||||||
|
"google",
|
||||||
|
r.URL.Query().Get("code"),
|
||||||
|
r.URL.Query().Get("state"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Save refresh_token on client
|
||||||
|
// loginResp.RefreshToken - Store this securely!
|
||||||
|
// loginResp.Token - Session token for API calls
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Refresh Endpoint
|
||||||
|
```go
|
||||||
|
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
// Refresh token
|
||||||
|
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Multi-Provider Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Configure multiple providers
|
||||||
|
auth := security.NewDatabaseAuthenticator(db).
|
||||||
|
WithOAuth2(security.OAuth2Config{
|
||||||
|
ProviderName: "google",
|
||||||
|
ClientID: "google-client-id",
|
||||||
|
ClientSecret: "google-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||||
|
}).
|
||||||
|
WithOAuth2(security.OAuth2Config{
|
||||||
|
ProviderName: "github",
|
||||||
|
ClientID: "github-client-id",
|
||||||
|
ClientSecret: "github-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||||
|
Scopes: []string{"user:email"},
|
||||||
|
AuthURL: "https://github.com/login/oauth/authorize",
|
||||||
|
TokenURL: "https://github.com/login/oauth/access_token",
|
||||||
|
UserInfoURL: "https://api.github.com/user",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Refresh with provider selection
|
||||||
|
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Provider string `json:"provider"` // "google" or "github"
|
||||||
|
}
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Client-Side JavaScript
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
// Automatic token refresh on 401
|
||||||
|
async function apiCall(url) {
|
||||||
|
let response = await fetch(url, {
|
||||||
|
headers: {
|
||||||
|
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Token expired - refresh it
|
||||||
|
if (response.status === 401) {
|
||||||
|
await refreshToken();
|
||||||
|
|
||||||
|
// Retry request with new token
|
||||||
|
response = await fetch(url, {
|
||||||
|
headers: {
|
||||||
|
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
async function refreshToken() {
|
||||||
|
const response = await fetch('/auth/refresh', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
refresh_token: localStorage.getItem('refresh_token'),
|
||||||
|
provider: localStorage.getItem('provider')
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
localStorage.setItem('access_token', data.token);
|
||||||
|
localStorage.setItem('refresh_token', data.refresh_token);
|
||||||
|
} else {
|
||||||
|
// Refresh failed - redirect to login
|
||||||
|
window.location.href = '/login';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Methods
|
||||||
|
|
||||||
|
| Method | Parameters | Returns |
|
||||||
|
|--------|-----------|---------|
|
||||||
|
| `OAuth2RefreshToken` | `ctx, refreshToken, provider` | `*LoginResponse, error` |
|
||||||
|
| `OAuth2HandleCallback` | `ctx, provider, code, state` | `*LoginResponse, error` |
|
||||||
|
| `OAuth2GetAuthURL` | `provider, state` | `string, error` |
|
||||||
|
| `OAuth2GenerateState` | none | `string, error` |
|
||||||
|
| `OAuth2GetProviders` | none | `[]string` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## LoginResponse Structure
|
||||||
|
|
||||||
|
```go
|
||||||
|
type LoginResponse struct {
|
||||||
|
Token string // New session token for API calls
|
||||||
|
RefreshToken string // Refresh token (store securely)
|
||||||
|
User *UserContext // User information
|
||||||
|
ExpiresIn int64 // Seconds until token expires
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Database Stored Procedures
|
||||||
|
|
||||||
|
- `resolvespec_oauth_getrefreshtoken(refresh_token)` - Get session by refresh token
|
||||||
|
- `resolvespec_oauth_updaterefreshtoken(update_data)` - Update tokens after refresh
|
||||||
|
- `resolvespec_oauth_getuser(user_id)` - Get user data
|
||||||
|
|
||||||
|
All procedures return: `{p_success bool, p_error text, p_data jsonb}`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Errors
|
||||||
|
|
||||||
|
| Error | Cause | Solution |
|
||||||
|
|-------|-------|----------|
|
||||||
|
| `invalid or expired refresh token` | Token revoked/expired | Re-authenticate user |
|
||||||
|
| `OAuth2 provider 'xxx' not found` | Provider not configured | Add with `WithOAuth2()` |
|
||||||
|
| `failed to refresh token with provider` | Provider rejected request | Check credentials, re-auth user |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Security Checklist
|
||||||
|
|
||||||
|
- [ ] Use HTTPS for all OAuth2 endpoints
|
||||||
|
- [ ] Store refresh tokens securely (HttpOnly cookies or encrypted storage)
|
||||||
|
- [ ] Set cookie flags: `HttpOnly`, `Secure`, `SameSite=Strict`
|
||||||
|
- [ ] Implement rate limiting on refresh endpoint
|
||||||
|
- [ ] Log refresh attempts for audit
|
||||||
|
- [ ] Rotate tokens on refresh
|
||||||
|
- [ ] Revoke old sessions after successful refresh
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Login and get refresh token
|
||||||
|
curl http://localhost:8080/auth/google/login
|
||||||
|
# Follow OAuth2 flow, get refresh_token from callback response
|
||||||
|
|
||||||
|
# 2. Refresh token
|
||||||
|
curl -X POST http://localhost:8080/auth/refresh \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"refresh_token":"ya29.xxx","provider":"google"}'
|
||||||
|
|
||||||
|
# 3. Use new token
|
||||||
|
curl http://localhost:8080/api/protected \
|
||||||
|
-H "Authorization: Bearer sess_abc123..."
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Pre-configured Providers
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Google
|
||||||
|
auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||||
|
|
||||||
|
// GitHub
|
||||||
|
auth := security.NewGitHubAuthenticator(clientID, secret, redirectURL, db)
|
||||||
|
|
||||||
|
// Microsoft
|
||||||
|
auth := security.NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)
|
||||||
|
|
||||||
|
// Facebook
|
||||||
|
auth := security.NewFacebookAuthenticator(clientID, secret, redirectURL, db)
|
||||||
|
|
||||||
|
// All providers at once
|
||||||
|
auth := security.NewMultiProviderAuthenticator(db, map[string]security.OAuth2Config{
|
||||||
|
"google": {...},
|
||||||
|
"github": {...},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Provider-Specific Notes
|
||||||
|
|
||||||
|
### Google
|
||||||
|
- Add `access_type=offline` to obtain a refresh token
|
||||||
|
- Add `prompt=consent` to force consent screen
|
||||||
|
```go
|
||||||
|
authURL += "&access_type=offline&prompt=consent"
|
||||||
|
```
|
||||||
|
|
||||||
|
### GitHub
|
||||||
|
- Refresh tokens are not always provided
|
||||||
|
- May need to request `offline_access` scope
|
||||||
|
|
||||||
|
### Microsoft
|
||||||
|
- Use `offline_access` scope for refresh token
|
||||||
|
|
||||||
|
### Facebook
|
||||||
|
- Tokens expire after 60 days by default
|
||||||
|
- Check app settings for token expiration policy
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
See `/pkg/security/oauth2_examples.go` line 250 for a full working example.
|
||||||
|
|
||||||
|
For detailed documentation see `/pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md`.
|
||||||
495
pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md
Normal file
495
pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
# OAuth2 Refresh Token Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
OAuth2 refresh token functionality is **fully implemented** in the ResolveSpec security package. This allows refreshing expired access tokens without requiring users to re-authenticate.
|
||||||
|
|
||||||
|
## Implementation Status: ✅ COMPLETE
|
||||||
|
|
||||||
|
### Components Implemented
|
||||||
|
|
||||||
|
1. **✅ Database Schema** - Tables and stored procedures
|
||||||
|
2. **✅ Go Methods** - OAuth2RefreshToken implementation
|
||||||
|
3. **✅ Thread Safety** - Mutex protection for provider map
|
||||||
|
4. **✅ Examples** - Working code examples
|
||||||
|
5. **✅ Documentation** - Complete API reference
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Database Schema
|
||||||
|
|
||||||
|
### Tables Modified
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- user_sessions table with OAuth2 token fields
|
||||||
|
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||||
|
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
ip_address VARCHAR(45),
|
||||||
|
user_agent TEXT,
|
||||||
|
access_token TEXT, -- OAuth2 access token
|
||||||
|
refresh_token TEXT, -- OAuth2 refresh token
|
||||||
|
token_type VARCHAR(50), -- "Bearer", etc.
|
||||||
|
auth_provider VARCHAR(50) -- "google", "github", etc.
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Stored Procedures
|
||||||
|
|
||||||
|
**`resolvespec_oauth_getrefreshtoken(p_refresh_token)`**
|
||||||
|
- Gets OAuth2 session data by refresh token
|
||||||
|
- Returns: `{user_id, access_token, token_type, expiry}`
|
||||||
|
- Location: `database_schema.sql:714`
|
||||||
|
|
||||||
|
**`resolvespec_oauth_updaterefreshtoken(p_update_data)`**
|
||||||
|
- Updates session with new tokens after refresh
|
||||||
|
- Input: `{user_id, old_refresh_token, new_session_token, new_access_token, new_refresh_token, expires_at}`
|
||||||
|
- Location: `database_schema.sql:752`
|
||||||
|
|
||||||
|
**`resolvespec_oauth_getuser(p_user_id)`**
|
||||||
|
- Gets user data by ID for building UserContext
|
||||||
|
- Location: `database_schema.sql:791`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Go Implementation
|
||||||
|
|
||||||
|
### Method Signature
|
||||||
|
|
||||||
|
```go
|
||||||
|
func (a *DatabaseAuthenticator) OAuth2RefreshToken(
|
||||||
|
ctx context.Context,
|
||||||
|
refreshToken string,
|
||||||
|
providerName string,
|
||||||
|
) (*LoginResponse, error)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Location:** `pkg/security/oauth2_methods.go:375`
|
||||||
|
|
||||||
|
### Implementation Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
1. Validate provider exists
|
||||||
|
├─ getOAuth2Provider(providerName) with RLock
|
||||||
|
└─ Return error if provider not configured
|
||||||
|
|
||||||
|
2. Get session from database
|
||||||
|
├─ Call resolvespec_oauth_getrefreshtoken(refreshToken)
|
||||||
|
└─ Parse session data {user_id, access_token, token_type, expiry}
|
||||||
|
|
||||||
|
3. Refresh token with OAuth2 provider
|
||||||
|
├─ Create oauth2.Token from stored data
|
||||||
|
├─ Use provider.config.TokenSource(ctx, oldToken)
|
||||||
|
└─ Call tokenSource.Token() to get new token
|
||||||
|
|
||||||
|
4. Generate new session token
|
||||||
|
└─ Use OAuth2GenerateState() for secure random token
|
||||||
|
|
||||||
|
5. Update database
|
||||||
|
├─ Call resolvespec_oauth_updaterefreshtoken()
|
||||||
|
└─ Store new session_token, access_token, refresh_token
|
||||||
|
|
||||||
|
6. Get user data
|
||||||
|
├─ Call resolvespec_oauth_getuser(user_id)
|
||||||
|
└─ Build UserContext
|
||||||
|
|
||||||
|
7. Return LoginResponse
|
||||||
|
└─ {Token, RefreshToken, User, ExpiresIn}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Thread Safety
|
||||||
|
|
||||||
|
**Mutex Protection:** All access to `oauth2Providers` map is protected with `sync.RWMutex`
|
||||||
|
|
||||||
|
```go
|
||||||
|
type DatabaseAuthenticator struct {
|
||||||
|
oauth2Providers map[string]*OAuth2Provider
|
||||||
|
oauth2ProvidersMutex sync.RWMutex // Thread-safe access
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read operations use RLock
|
||||||
|
func (a *DatabaseAuthenticator) getOAuth2Provider(name string) {
|
||||||
|
a.oauth2ProvidersMutex.RLock()
|
||||||
|
defer a.oauth2ProvidersMutex.RUnlock()
|
||||||
|
// ... access map
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write operations use Lock
|
||||||
|
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) {
|
||||||
|
a.oauth2ProvidersMutex.Lock()
|
||||||
|
defer a.oauth2ProvidersMutex.Unlock()
|
||||||
|
// ... modify map
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Usage Examples
|
||||||
|
|
||||||
|
### Single Provider (Google)
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
// Create Google OAuth2 authenticator
|
||||||
|
auth := security.NewGoogleAuthenticator(
|
||||||
|
"your-client-id",
|
||||||
|
"your-client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Token refresh endpoint
|
||||||
|
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
// Refresh token (provider name defaults to "google")
|
||||||
|
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set new session cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-Provider Setup
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Single authenticator with multiple OAuth2 providers
|
||||||
|
auth := security.NewDatabaseAuthenticator(db).
|
||||||
|
WithOAuth2(security.OAuth2Config{
|
||||||
|
ClientID: "google-client-id",
|
||||||
|
ClientSecret: "google-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||||
|
ProviderName: "google",
|
||||||
|
}).
|
||||||
|
WithOAuth2(security.OAuth2Config{
|
||||||
|
ClientID: "github-client-id",
|
||||||
|
ClientSecret: "github-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||||
|
Scopes: []string{"user:email"},
|
||||||
|
AuthURL: "https://github.com/login/oauth/authorize",
|
||||||
|
TokenURL: "https://github.com/login/oauth/access_token",
|
||||||
|
UserInfoURL: "https://api.github.com/user",
|
||||||
|
ProviderName: "github",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Refresh endpoint with provider selection
|
||||||
|
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Provider string `json:"provider"` // "google" or "github"
|
||||||
|
}
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
// Refresh with specific provider
|
||||||
|
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Client-Side Usage
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
// JavaScript client example
|
||||||
|
async function refreshAccessToken() {
|
||||||
|
const refreshToken = localStorage.getItem('refresh_token');
|
||||||
|
const provider = localStorage.getItem('auth_provider'); // "google", "github", etc.
|
||||||
|
|
||||||
|
const response = await fetch('/auth/refresh', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
refresh_token: refreshToken,
|
||||||
|
provider: provider
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
// Store new tokens
|
||||||
|
localStorage.setItem('access_token', data.token);
|
||||||
|
localStorage.setItem('refresh_token', data.refresh_token);
|
||||||
|
|
||||||
|
console.log('Token refreshed successfully');
|
||||||
|
return data.token;
|
||||||
|
} else {
|
||||||
|
// Refresh failed - redirect to login
|
||||||
|
window.location.href = '/login';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Automatically refresh token when API returns 401
|
||||||
|
async function apiCall(endpoint) {
|
||||||
|
let response = await fetch(endpoint, {
|
||||||
|
headers: {
|
||||||
|
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.status === 401) {
|
||||||
|
// Token expired - try refresh
|
||||||
|
const newToken = await refreshAccessToken();
|
||||||
|
|
||||||
|
// Retry with new token
|
||||||
|
response = await fetch(endpoint, {
|
||||||
|
headers: {
|
||||||
|
'Authorization': 'Bearer ' + newToken
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.json();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. API Reference
|
||||||
|
|
||||||
|
### DatabaseAuthenticator Methods
|
||||||
|
|
||||||
|
| Method | Signature | Description |
|
||||||
|
|--------|-----------|-------------|
|
||||||
|
| `OAuth2RefreshToken` | `(ctx, refreshToken, provider) (*LoginResponse, error)` | Refreshes expired OAuth2 access token |
|
||||||
|
| `WithOAuth2` | `(cfg OAuth2Config) *DatabaseAuthenticator` | Adds OAuth2 provider (chainable) |
|
||||||
|
| `OAuth2GetAuthURL` | `(provider, state) (string, error)` | Gets authorization URL |
|
||||||
|
| `OAuth2HandleCallback` | `(ctx, provider, code, state) (*LoginResponse, error)` | Handles OAuth2 callback |
|
||||||
|
| `OAuth2GenerateState` | `() (string, error)` | Generates CSRF state token |
|
||||||
|
| `OAuth2GetProviders` | `() []string` | Lists configured providers |
|
||||||
|
|
||||||
|
### LoginResponse Structure
|
||||||
|
|
||||||
|
```go
|
||||||
|
type LoginResponse struct {
|
||||||
|
Token string // New session token
|
||||||
|
RefreshToken string // New refresh token (may be same as input)
|
||||||
|
User *UserContext // User information
|
||||||
|
ExpiresIn int64 // Seconds until expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserContext struct {
|
||||||
|
UserID int // Database user ID
|
||||||
|
UserName string // Username
|
||||||
|
Email string // Email address
|
||||||
|
UserLevel int // Permission level
|
||||||
|
SessionID string // Session token
|
||||||
|
RemoteID string // OAuth2 provider user ID
|
||||||
|
Roles []string // User roles
|
||||||
|
Claims map[string]any // Additional claims
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Important Notes
|
||||||
|
|
||||||
|
### Provider Configuration
|
||||||
|
|
||||||
|
**For Google:** Add `access_type=offline` to receive a refresh token on first login:
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth := security.NewGoogleAuthenticator(clientID, clientSecret, redirectURL, db)
|
||||||
|
// When generating auth URL, add access_type parameter
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||||
|
authURL += "&access_type=offline&prompt=consent"
|
||||||
|
```
|
||||||
|
|
||||||
|
**For GitHub:** Refresh tokens are not always provided. Check provider documentation.
|
||||||
|
|
||||||
|
### Token Storage
|
||||||
|
|
||||||
|
- Store refresh tokens securely on the client (prefer `HttpOnly`, `Secure` cookies or encrypted storage; avoid plain localStorage, which is readable by any XSS payload)
|
||||||
|
- Never log refresh tokens
|
||||||
|
- Refresh tokens are long-lived (days/months depending on provider)
|
||||||
|
- Access tokens are short-lived (minutes/hours)
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
|
||||||
|
Common errors:
|
||||||
|
- `"invalid or expired refresh token"` - Token expired or revoked
|
||||||
|
- `"OAuth2 provider 'xxx' not found"` - Provider not configured
|
||||||
|
- `"failed to refresh token with provider"` - Provider rejected refresh request
|
||||||
|
|
||||||
|
### Security Best Practices
|
||||||
|
|
||||||
|
1. **Always use HTTPS** for token transmission
|
||||||
|
2. **Store refresh tokens securely** on client
|
||||||
|
3. **Set appropriate cookie flags**: `HttpOnly`, `Secure`, `SameSite`
|
||||||
|
4. **Implement token rotation** - issue new refresh token on each refresh
|
||||||
|
5. **Revoke old tokens** after successful refresh
|
||||||
|
6. **Rate limit** refresh endpoints
|
||||||
|
7. **Log refresh attempts** for audit trail
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Testing
|
||||||
|
|
||||||
|
### Manual Test Flow
|
||||||
|
|
||||||
|
1. **Initial Login:**
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/auth/google/login
|
||||||
|
# Follow redirect to Google
|
||||||
|
# Returns to callback with LoginResponse containing refresh_token
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Wait for Token Expiry (or manually expire in DB)**
|
||||||
|
|
||||||
|
3. **Refresh Token:**
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/auth/refresh \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"refresh_token": "ya29.a0AfH6SMB...",
|
||||||
|
"provider": "google"
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Response:
|
||||||
|
{
|
||||||
|
"token": "sess_abc123...",
|
||||||
|
"refresh_token": "ya29.a0AfH6SMB...",
|
||||||
|
"user": {
|
||||||
|
"user_id": 1,
|
||||||
|
"user_name": "john_doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
"session_id": "sess_abc123..."
|
||||||
|
},
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Use New Token:**
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/api/protected \
|
||||||
|
-H "Authorization: Bearer sess_abc123..."
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Verification
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Check session with refresh token
|
||||||
|
SELECT session_token, user_id, expires_at, refresh_token, auth_provider
|
||||||
|
FROM user_sessions
|
||||||
|
WHERE refresh_token = 'ya29.a0AfH6SMB...';
|
||||||
|
|
||||||
|
-- Verify token was updated after refresh
|
||||||
|
SELECT session_token, access_token, refresh_token,
|
||||||
|
expires_at, last_activity_at
|
||||||
|
FROM user_sessions
|
||||||
|
WHERE user_id = 1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT 1;
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Troubleshooting
|
||||||
|
|
||||||
|
### "Refresh token not found or expired"
|
||||||
|
|
||||||
|
**Cause:** Refresh token doesn't exist in database or session expired
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
- Check if initial OAuth2 login stored refresh token
|
||||||
|
- Verify provider returns refresh token (some require `access_type=offline`)
|
||||||
|
- Check session hasn't been deleted from database
|
||||||
|
|
||||||
|
### "Failed to refresh token with provider"
|
||||||
|
|
||||||
|
**Cause:** OAuth2 provider rejected the refresh request
|
||||||
|
|
||||||
|
**Possible reasons:**
|
||||||
|
- Refresh token was revoked by user
|
||||||
|
- OAuth2 app credentials changed
|
||||||
|
- Network connectivity issues
|
||||||
|
- Provider rate limiting
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
- Re-authenticate user (full OAuth2 flow)
|
||||||
|
- Check provider dashboard for app status
|
||||||
|
- Verify client credentials are correct
|
||||||
|
|
||||||
|
### "OAuth2 provider 'xxx' not found"
|
||||||
|
|
||||||
|
**Cause:** Provider not registered with `WithOAuth2()`
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
```go
|
||||||
|
// Make sure provider is configured
|
||||||
|
auth := security.NewDatabaseAuthenticator(db).
|
||||||
|
WithOAuth2(security.OAuth2Config{
|
||||||
|
ProviderName: "google", // This name must match refresh call
|
||||||
|
// ... other config
|
||||||
|
})
|
||||||
|
|
||||||
|
// Then use same name in refresh
|
||||||
|
auth.OAuth2RefreshToken(ctx, token, "google") // Must match ProviderName
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Complete Working Example
|
||||||
|
|
||||||
|
See `pkg/security/oauth2_examples.go:250` for full working example with token refresh.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
OAuth2 refresh token functionality is **production-ready** with:
|
||||||
|
|
||||||
|
- ✅ Complete database schema with stored procedures
|
||||||
|
- ✅ Thread-safe Go implementation with mutex protection
|
||||||
|
- ✅ Multi-provider support (Google, GitHub, Microsoft, Facebook, custom)
|
||||||
|
- ✅ Comprehensive error handling
|
||||||
|
- ✅ Working code examples
|
||||||
|
- ✅ Full API documentation
|
||||||
|
- ✅ Security best practices implemented
|
||||||
|
|
||||||
|
**No additional implementation needed - feature is complete and functional.**
|
||||||
208
pkg/security/PASSKEY_QUICK_REFERENCE.md
Normal file
208
pkg/security/PASSKEY_QUICK_REFERENCE.md
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
# Passkey Authentication Quick Reference
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Passkey authentication (WebAuthn/FIDO2) is now integrated into the DatabaseAuthenticator. This provides passwordless authentication using biometrics, security keys, or device credentials.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
### Database Schema
|
||||||
|
Run the passkey SQL schema (in database_schema.sql):
|
||||||
|
- Creates `user_passkey_credentials` table
|
||||||
|
- Adds stored procedures for passkey operations
|
||||||
|
|
||||||
|
### Go Code
|
||||||
|
```go
|
||||||
|
// Create passkey provider
|
||||||
|
passkeyProvider := security.NewDatabasePasskeyProvider(db,
|
||||||
|
security.DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
RPOrigin: "https://example.com",
|
||||||
|
Timeout: 60000,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create authenticator with passkey support
|
||||||
|
auth := security.NewDatabaseAuthenticatorWithOptions(db,
|
||||||
|
security.DatabaseAuthenticatorOptions{
|
||||||
|
PasskeyProvider: passkeyProvider,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Or add passkey to existing authenticator
|
||||||
|
auth = security.NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Registration Flow
|
||||||
|
|
||||||
|
### Backend - Step 1: Begin Registration
|
||||||
|
```go
|
||||||
|
options, err := auth.BeginPasskeyRegistration(ctx,
|
||||||
|
security.PasskeyBeginRegistrationRequest{
|
||||||
|
UserID: 1,
|
||||||
|
Username: "alice",
|
||||||
|
DisplayName: "Alice Smith",
|
||||||
|
})
|
||||||
|
// Send options to client as JSON
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend - Step 2: Create Credential
|
||||||
|
```javascript
|
||||||
|
// Convert options from server
|
||||||
|
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||||
|
options.user.id = base64ToArrayBuffer(options.user.id);
|
||||||
|
|
||||||
|
// Create credential
|
||||||
|
const credential = await navigator.credentials.create({
|
||||||
|
publicKey: options
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send credential back to server
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend - Step 3: Complete Registration
|
||||||
|
```go
|
||||||
|
credential, err := auth.CompletePasskeyRegistration(ctx,
|
||||||
|
security.PasskeyRegisterRequest{
|
||||||
|
UserID: 1,
|
||||||
|
Response: clientResponse,
|
||||||
|
ExpectedChallenge: storedChallenge,
|
||||||
|
CredentialName: "My iPhone",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Authentication Flow
|
||||||
|
|
||||||
|
### Backend - Step 1: Begin Authentication
|
||||||
|
```go
|
||||||
|
options, err := auth.BeginPasskeyAuthentication(ctx,
|
||||||
|
security.PasskeyBeginAuthenticationRequest{
|
||||||
|
Username: "alice", // Optional for resident key
|
||||||
|
})
|
||||||
|
// Send options to client as JSON
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend - Step 2: Get Credential
|
||||||
|
```javascript
|
||||||
|
// Convert options from server
|
||||||
|
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||||
|
|
||||||
|
// Get credential
|
||||||
|
const credential = await navigator.credentials.get({
|
||||||
|
publicKey: options
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send assertion back to server
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend - Step 3: Complete Authentication
|
||||||
|
```go
|
||||||
|
loginResponse, err := auth.LoginWithPasskey(ctx,
|
||||||
|
security.PasskeyLoginRequest{
|
||||||
|
Response: clientAssertion,
|
||||||
|
ExpectedChallenge: storedChallenge,
|
||||||
|
Claims: map[string]any{
|
||||||
|
"ip_address": "192.168.1.1",
|
||||||
|
"user_agent": "Mozilla/5.0...",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
// Returns session token and user info
|
||||||
|
```
|
||||||
|
|
||||||
|
## Credential Management
|
||||||
|
|
||||||
|
### List Credentials
|
||||||
|
```go
|
||||||
|
credentials, err := auth.GetPasskeyCredentials(ctx, userID)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Update Credential Name
|
||||||
|
```go
|
||||||
|
err := auth.UpdatePasskeyCredentialName(ctx, userID, credentialID, "New Name")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Delete Credential
|
||||||
|
```go
|
||||||
|
err := auth.DeletePasskeyCredential(ctx, userID, credentialID)
|
||||||
|
```
|
||||||
|
|
||||||
|
## HTTP Endpoints Example
|
||||||
|
|
||||||
|
### POST /api/passkey/register/begin
|
||||||
|
Request: `{user_id, username, display_name}`
|
||||||
|
Response: PasskeyRegistrationOptions
|
||||||
|
|
||||||
|
### POST /api/passkey/register/complete
|
||||||
|
Request: `{user_id, response, credential_name}`
|
||||||
|
Response: PasskeyCredential
|
||||||
|
|
||||||
|
### POST /api/passkey/login/begin
|
||||||
|
Request: `{username}` (optional)
|
||||||
|
Response: PasskeyAuthenticationOptions
|
||||||
|
|
||||||
|
### POST /api/passkey/login/complete
|
||||||
|
Request: `{response}`
|
||||||
|
Response: LoginResponse with session token
|
||||||
|
|
||||||
|
### GET /api/passkey/credentials
|
||||||
|
Response: Array of PasskeyCredential
|
||||||
|
|
||||||
|
### DELETE /api/passkey/credentials/{id}
|
||||||
|
Request: `{credential_id}`
|
||||||
|
Response: 204 No Content
|
||||||
|
|
||||||
|
## Database Stored Procedures
|
||||||
|
|
||||||
|
- `resolvespec_passkey_store_credential` - Store new credential
|
||||||
|
- `resolvespec_passkey_get_credential` - Get credential by ID
|
||||||
|
- `resolvespec_passkey_get_user_credentials` - Get all user credentials
|
||||||
|
- `resolvespec_passkey_update_counter` - Update sign counter (clone detection)
|
||||||
|
- `resolvespec_passkey_delete_credential` - Delete credential
|
||||||
|
- `resolvespec_passkey_update_name` - Update credential name
|
||||||
|
- `resolvespec_passkey_get_credentials_by_username` - Get credentials for login
|
||||||
|
|
||||||
|
## Security Features
|
||||||
|
|
||||||
|
- **Clone Detection**: Sign counter validation detects credential cloning
|
||||||
|
- **Attestation Support**: Stores attestation type (none, indirect, direct)
|
||||||
|
- **Transport Options**: Tracks authenticator transports (usb, nfc, ble, internal)
|
||||||
|
- **Backup State**: Tracks if credential is backed up/synced
|
||||||
|
- **User Verification**: Supports preferred/required user verification
|
||||||
|
|
||||||
|
## Important Notes
|
||||||
|
|
||||||
|
1. **WebAuthn Library**: Current implementation is simplified. For production, use a proper WebAuthn library like `github.com/go-webauthn/webauthn` for full verification.
|
||||||
|
|
||||||
|
2. **Challenge Storage**: Store challenges securely in session/cache. Never expose challenges to client beyond initial request.
|
||||||
|
|
||||||
|
3. **HTTPS Required**: Passkeys only work over HTTPS (except localhost).
|
||||||
|
|
||||||
|
4. **Browser Support**: Check browser compatibility for WebAuthn API.
|
||||||
|
|
||||||
|
5. **Relying Party ID**: Must match your domain exactly.
|
||||||
|
|
||||||
|
## Client-Side Helper Functions
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
function base64ToArrayBuffer(base64) {
|
||||||
|
const binary = atob(base64);
|
||||||
|
const bytes = new Uint8Array(binary.length);
|
||||||
|
for (let i = 0; i < binary.length; i++) {
|
||||||
|
bytes[i] = binary.charCodeAt(i);
|
||||||
|
}
|
||||||
|
return bytes.buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
function arrayBufferToBase64(buffer) {
|
||||||
|
const bytes = new Uint8Array(buffer);
|
||||||
|
let binary = '';
|
||||||
|
for (let i = 0; i < bytes.length; i++) {
|
||||||
|
binary += String.fromCharCode(bytes[i]);
|
||||||
|
}
|
||||||
|
return btoa(binary);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Run tests: `go test -v ./pkg/security -run Passkey`
|
||||||
|
|
||||||
|
All passkey functionality includes comprehensive tests using sqlmock.
|
||||||
@@ -7,15 +7,16 @@
|
|||||||
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
|
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
|
||||||
// OR: auth := security.NewJWTAuthenticator("secret-key", db)
|
// OR: auth := security.NewJWTAuthenticator("secret-key", db)
|
||||||
// OR: auth := security.NewHeaderAuthenticator()
|
// OR: auth := security.NewHeaderAuthenticator()
|
||||||
|
// OR: auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db) // OAuth2
|
||||||
|
|
||||||
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||||
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||||
|
|
||||||
// Step 2: Combine providers
|
// Step 2: Combine providers
|
||||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
|
||||||
// Step 3: Setup and apply middleware
|
// Step 3: Setup and apply middleware
|
||||||
securityList := security.SetupSecurityProvider(handler, provider)
|
securityList, _ := security.SetupSecurityProvider(handler, provider)
|
||||||
router.Use(security.NewAuthMiddleware(securityList))
|
router.Use(security.NewAuthMiddleware(securityList))
|
||||||
router.Use(security.SetSecurityMiddleware(securityList))
|
router.Use(security.SetSecurityMiddleware(securityList))
|
||||||
```
|
```
|
||||||
@@ -30,6 +31,7 @@ router.Use(security.SetSecurityMiddleware(securityList))
|
|||||||
```go
|
```go
|
||||||
// DatabaseAuthenticator uses these stored procedures:
|
// DatabaseAuthenticator uses these stored procedures:
|
||||||
resolvespec_login(jsonb) // Login with credentials
|
resolvespec_login(jsonb) // Login with credentials
|
||||||
|
resolvespec_register(jsonb) // Register new user
|
||||||
resolvespec_logout(jsonb) // Invalidate session
|
resolvespec_logout(jsonb) // Invalidate session
|
||||||
resolvespec_session(text, text) // Validate session token
|
resolvespec_session(text, text) // Validate session token
|
||||||
resolvespec_session_update(text, jsonb) // Update activity timestamp
|
resolvespec_session_update(text, jsonb) // Update activity timestamp
|
||||||
@@ -502,10 +504,31 @@ func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema,
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Login/Logout Endpoints
|
## Login/Logout/Register Endpoints
|
||||||
|
|
||||||
```go
|
```go
|
||||||
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
|
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
|
||||||
|
// Register
|
||||||
|
router.HandleFunc("/auth/register", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req security.RegisterRequest
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
// Check if provider supports registration
|
||||||
|
registrable, ok := securityList.Provider().(security.Registrable)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Registration not supported", http.StatusNotImplemented)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := registrable.Register(r.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}).Methods("POST")
|
||||||
|
|
||||||
// Login
|
// Login
|
||||||
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
var req security.LoginRequest
|
var req security.LoginRequest
|
||||||
@@ -707,6 +730,7 @@ meta, ok := security.GetUserMeta(ctx)
|
|||||||
| File | Description |
|
| File | Description |
|
||||||
|------|-------------|
|
|------|-------------|
|
||||||
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
|
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
|
||||||
|
| `OAUTH2.md` | **OAuth2 Guide** - Google, GitHub, Microsoft, Facebook, custom providers |
|
||||||
| `examples.go` | Working provider implementations to copy |
|
| `examples.go` | Working provider implementations to copy |
|
||||||
| `setup_example.go` | 6 complete integration examples |
|
| `setup_example.go` | 6 complete integration examples |
|
||||||
| `README.md` | Architecture overview and migration guide |
|
| `README.md` | Architecture overview and migration guide |
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
|||||||
|
|
||||||
- ✅ **Interface-Based** - Type-safe providers instead of callbacks
|
- ✅ **Interface-Based** - Type-safe providers instead of callbacks
|
||||||
- ✅ **Login/Logout Support** - Built-in authentication lifecycle
|
- ✅ **Login/Logout Support** - Built-in authentication lifecycle
|
||||||
|
- ✅ **Two-Factor Authentication (2FA)** - Optional TOTP support for enhanced security
|
||||||
- ✅ **Composable** - Mix and match different providers
|
- ✅ **Composable** - Mix and match different providers
|
||||||
- ✅ **No Global State** - Each handler has its own security configuration
|
- ✅ **No Global State** - Each handler has its own security configuration
|
||||||
- ✅ **Testable** - Easy to mock and test
|
- ✅ **Testable** - Easy to mock and test
|
||||||
@@ -212,6 +213,23 @@ auth := security.NewJWTAuthenticator("secret-key", db)
|
|||||||
// Note: Requires JWT library installation for token signing/verification
|
// Note: Requires JWT library installation for token signing/verification
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**TwoFactorAuthenticator** - Wraps any authenticator with TOTP 2FA:
|
||||||
|
```go
|
||||||
|
baseAuth := security.NewDatabaseAuthenticator(db)
|
||||||
|
|
||||||
|
// Use in-memory provider (for testing)
|
||||||
|
tfaProvider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
|
||||||
|
// Or use database provider (for production)
|
||||||
|
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||||
|
// Requires: users table with totp fields, user_totp_backup_codes table
|
||||||
|
// Requires: resolvespec_totp_* stored procedures (see totp_database_schema.sql)
|
||||||
|
|
||||||
|
auth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||||
|
// Supports: TOTP codes, backup codes, QR code generation
|
||||||
|
// Compatible with Google Authenticator, Microsoft Authenticator, Authy, etc.
|
||||||
|
```
|
||||||
|
|
||||||
### Column Security Providers
|
### Column Security Providers
|
||||||
|
|
||||||
**DatabaseColumnSecurityProvider** - Loads rules from database:
|
**DatabaseColumnSecurityProvider** - Loads rules from database:
|
||||||
@@ -334,7 +352,182 @@ func handleRefresh(securityList *security.SecurityList) http.HandlerFunc {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Two-Factor Authentication (2FA)
|
||||||
|
|
||||||
|
### Overview
|
||||||
|
|
||||||
|
- **Optional per-user** - Enable/disable 2FA individually
|
||||||
|
- **TOTP standard** - Compatible with Google Authenticator, Microsoft Authenticator, Authy, 1Password, etc.
|
||||||
|
- **Configurable** - SHA1/SHA256/SHA512, 6/8 digits, custom time periods
|
||||||
|
- **Backup codes** - One-time recovery codes with secure hashing
|
||||||
|
- **Clock skew** - Handles time differences between client/server
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 1. Wrap existing authenticator with 2FA support
|
||||||
|
baseAuth := security.NewDatabaseAuthenticator(db)
|
||||||
|
tfaProvider := security.NewMemoryTwoFactorProvider(nil) // Use custom DB implementation in production
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||||
|
|
||||||
|
// 2. Use as normal authenticator
|
||||||
|
provider := security.NewCompositeSecurityProvider(tfaAuth, colSec, rowSec)
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Enable 2FA for User
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 1. Initiate 2FA setup
|
||||||
|
secret, err := tfaAuth.Setup2FA(userID, "MyApp", "user@example.com")
|
||||||
|
// Returns: secret.Secret, secret.QRCodeURL, secret.BackupCodes
|
||||||
|
|
||||||
|
// 2. User scans QR code with authenticator app
|
||||||
|
// Display secret.QRCodeURL as QR code image
|
||||||
|
|
||||||
|
// 3. User enters verification code from app
|
||||||
|
code := "123456" // From authenticator app
|
||||||
|
err = tfaAuth.Enable2FA(userID, secret.Secret, code)
|
||||||
|
// 2FA is now enabled for this user
|
||||||
|
|
||||||
|
// 4. Store backup codes securely and show to user once
|
||||||
|
// Display: secret.BackupCodes (10 codes)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Login Flow with 2FA
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 1. User provides credentials
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "user@example.com",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tfaAuth.Login(ctx, req)
|
||||||
|
|
||||||
|
// 2. Check if 2FA required
|
||||||
|
if resp.Requires2FA {
|
||||||
|
// Prompt user for 2FA code
|
||||||
|
code := getUserInput() // From authenticator app or backup code
|
||||||
|
|
||||||
|
// 3. Login again with 2FA code
|
||||||
|
req.TwoFactorCode = code
|
||||||
|
resp, err = tfaAuth.Login(ctx, req)
|
||||||
|
|
||||||
|
// 4. Success - token is returned
|
||||||
|
token := resp.Token
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manage 2FA
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Disable 2FA
|
||||||
|
err := tfaAuth.Disable2FA(userID)
|
||||||
|
|
||||||
|
// Regenerate backup codes
|
||||||
|
newCodes, err := tfaAuth.RegenerateBackupCodes(userID, 10)
|
||||||
|
|
||||||
|
// Check status
|
||||||
|
has2FA, err := tfaProvider.Get2FAStatus(userID)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom 2FA Storage
|
||||||
|
|
||||||
|
**Option 1: Use DatabaseTwoFactorProvider (Recommended)**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Uses PostgreSQL stored procedures for all operations
|
||||||
|
db := setupDatabase()
|
||||||
|
|
||||||
|
// Run migrations from totp_database_schema.sql
|
||||||
|
// - Add totp_secret, totp_enabled, totp_enabled_at to users table
|
||||||
|
// - Create user_totp_backup_codes table
|
||||||
|
// - Create resolvespec_totp_* stored procedures
|
||||||
|
|
||||||
|
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option 2: Implement Custom Provider**
|
||||||
|
|
||||||
|
Implement `TwoFactorAuthProvider` for custom storage:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type DBTwoFactorProvider struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DBTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
|
||||||
|
// Store secret and hashed backup codes in database
|
||||||
|
return p.db.Exec("UPDATE users SET totp_secret = ?, backup_codes = ? WHERE id = ?",
|
||||||
|
secret, hashCodes(backupCodes), userID).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DBTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||||
|
var secret string
|
||||||
|
err := p.db.Raw("SELECT totp_secret FROM users WHERE id = ?", userID).Scan(&secret).Error
|
||||||
|
return secret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement remaining methods: Generate2FASecret, Validate2FACode, Disable2FA,
|
||||||
|
// Get2FAStatus, GenerateBackupCodes, ValidateBackupCode
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
```go
|
||||||
|
config := &security.TwoFactorConfig{
|
||||||
|
Algorithm: "SHA256", // SHA1, SHA256, SHA512
|
||||||
|
Digits: 8, // 6 or 8
|
||||||
|
Period: 30, // Seconds per code
|
||||||
|
SkewWindow: 2, // Accept codes ±2 periods
|
||||||
|
}
|
||||||
|
|
||||||
|
totp := security.NewTOTPGenerator(config)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, config)
|
||||||
|
```
|
||||||
|
|
||||||
|
### API Response Structure
|
||||||
|
|
||||||
|
```go
|
||||||
|
// LoginResponse with 2FA
|
||||||
|
type LoginResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Requires2FA bool `json:"requires_2fa"`
|
||||||
|
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"`
|
||||||
|
User *UserContext `json:"user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TwoFactorSecret for setup
|
||||||
|
type TwoFactorSecret struct {
|
||||||
|
Secret string `json:"secret"` // Base32 encoded
|
||||||
|
QRCodeURL string `json:"qr_code_url"` // otpauth://totp/...
|
||||||
|
BackupCodes []string `json:"backup_codes"` // 10 recovery codes
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserContext includes 2FA status
|
||||||
|
type UserContext struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
TwoFactorEnabled bool `json:"two_factor_enabled"`
|
||||||
|
// ... other fields
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Security Best Practices
|
||||||
|
|
||||||
|
- **Store secrets encrypted** - Never store TOTP secrets in plain text
|
||||||
|
- **Hash backup codes** - Use SHA-256 before storing
|
||||||
|
- **Rate limit** - Limit 2FA verification attempts
|
||||||
|
- **Require password** - Always verify password before disabling 2FA
|
||||||
|
- **Show backup codes once** - Display only during setup/regeneration
|
||||||
|
- **Log 2FA events** - Track enable/disable/failed attempts
|
||||||
|
- **Mark codes as used** - Backup codes are single-use only
|
||||||
|
|
||||||
|
|
||||||
json.NewEncoder(w).Encode(resp)
|
json.NewEncoder(w).Encode(resp)
|
||||||
} else {
|
} else {
|
||||||
http.Error(w, "Refresh not supported", http.StatusNotImplemented)
|
http.Error(w, "Refresh not supported", http.StatusNotImplemented)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -17,22 +17,37 @@ type UserContext struct {
|
|||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Claims map[string]any `json:"claims"`
|
Claims map[string]any `json:"claims"`
|
||||||
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
||||||
|
TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginRequest contains credentials for login
|
// LoginRequest contains credentials for login
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
|
TwoFactorCode string `json:"two_factor_code,omitempty"` // TOTP or backup code
|
||||||
Claims map[string]any `json:"claims"` // Additional login data
|
Claims map[string]any `json:"claims"` // Additional login data
|
||||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterRequest contains information for new user registration
|
||||||
|
type RegisterRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
UserLevel int `json:"user_level"`
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
Claims map[string]any `json:"claims"` // Additional registration data
|
||||||
|
Meta map[string]any `json:"meta"` // Additional metadata
|
||||||
|
}
|
||||||
|
|
||||||
// LoginResponse contains the result of a login attempt
|
// LoginResponse contains the result of a login attempt
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
User *UserContext `json:"user"`
|
User *UserContext `json:"user"`
|
||||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||||
|
Requires2FA bool `json:"requires_2fa"` // True if 2FA code is required
|
||||||
|
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"` // Present when setting up 2FA
|
||||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,6 +70,12 @@ type Authenticator interface {
|
|||||||
Authenticate(r *http.Request) (*UserContext, error)
|
Authenticate(r *http.Request) (*UserContext, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Registrable allows providers to support user registration
|
||||||
|
type Registrable interface {
|
||||||
|
// Register creates a new user account
|
||||||
|
Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
// ColumnSecurityProvider handles column-level security (masking/hiding)
|
// ColumnSecurityProvider handles column-level security (masking/hiding)
|
||||||
type ColumnSecurityProvider interface {
|
type ColumnSecurityProvider interface {
|
||||||
// GetColumnSecurity loads column security rules for a user and entity
|
// GetColumnSecurity loads column security rules for a user and entity
|
||||||
|
|||||||
615
pkg/security/oauth2_examples.go
Normal file
615
pkg/security/oauth2_examples.go
Normal file
@@ -0,0 +1,615 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Example: OAuth2 Authentication with Google
|
||||||
|
func ExampleOAuth2Google() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
// Create OAuth2 authenticator for Google
|
||||||
|
oauth2Auth := NewGoogleAuthenticator(
|
||||||
|
"your-client-id",
|
||||||
|
"your-client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Login endpoint - redirects to Google
|
||||||
|
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := oauth2Auth.OAuth2GetAuthURL("google", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Callback endpoint - handles Google response
|
||||||
|
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set session cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Return user info as JSON
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example: OAuth2 Authentication with GitHub
|
||||||
|
func ExampleOAuth2GitHub() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
oauth2Auth := NewGitHubAuthenticator(
|
||||||
|
"your-github-client-id",
|
||||||
|
"your-github-client-secret",
|
||||||
|
"http://localhost:8080/auth/github/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example: Custom OAuth2 Provider
|
||||||
|
func ExampleOAuth2Custom() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
// Custom OAuth2 provider configuration
|
||||||
|
oauth2Auth := NewDatabaseAuthenticator(db).WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: "your-client-id",
|
||||||
|
ClientSecret: "your-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://your-provider.com/oauth/authorize",
|
||||||
|
TokenURL: "https://your-provider.com/oauth/token",
|
||||||
|
UserInfoURL: "https://your-provider.com/oauth/userinfo",
|
||||||
|
ProviderName: "custom-provider",
|
||||||
|
|
||||||
|
// Custom user info parser
|
||||||
|
UserInfoParser: func(userInfo map[string]any) (*UserContext, error) {
|
||||||
|
// Extract custom fields from your provider
|
||||||
|
return &UserContext{
|
||||||
|
UserName: userInfo["username"].(string),
|
||||||
|
Email: userInfo["email"].(string),
|
||||||
|
RemoteID: userInfo["id"].(string),
|
||||||
|
UserLevel: 1,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
Claims: userInfo,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := oauth2Auth.OAuth2GetAuthURL("custom-provider", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "custom-provider", code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example: Multi-Provider OAuth2 with Security Integration
|
||||||
|
func ExampleOAuth2MultiProvider() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
// Create OAuth2 authenticators for multiple providers
|
||||||
|
googleAuth := NewGoogleAuthenticator(
|
||||||
|
"google-client-id",
|
||||||
|
"google-client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
githubAuth := NewGitHubAuthenticator(
|
||||||
|
"github-client-id",
|
||||||
|
"github-client-secret",
|
||||||
|
"http://localhost:8080/auth/github/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create column and row security providers
|
||||||
|
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||||
|
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Google OAuth2 routes
|
||||||
|
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := googleAuth.OAuth2GenerateState()
|
||||||
|
authURL, _ := googleAuth.OAuth2GetAuthURL("google", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
loginResp, err := googleAuth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
// GitHub OAuth2 routes
|
||||||
|
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := githubAuth.OAuth2GenerateState()
|
||||||
|
authURL, _ := githubAuth.OAuth2GetAuthURL("github", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
loginResp, err := githubAuth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use Google auth for protected routes (or GitHub - both work)
|
||||||
|
provider, _ := NewCompositeSecurityProvider(googleAuth, colSec, rowSec)
|
||||||
|
securityList, _ := NewSecurityList(provider)
|
||||||
|
|
||||||
|
// Protected route with authentication
|
||||||
|
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||||
|
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||||
|
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||||
|
|
||||||
|
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := GetUserContext(r.Context())
|
||||||
|
_ = json.NewEncoder(w).Encode(userCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example: OAuth2 with Token Refresh
|
||||||
|
func ExampleOAuth2TokenRefresh() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
oauth2Auth := NewGoogleAuthenticator(
|
||||||
|
"your-client-id",
|
||||||
|
"your-client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Refresh token endpoint
|
||||||
|
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Provider string `json:"provider"` // "google", "github", etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to google if not specified
|
||||||
|
if req.Provider == "" {
|
||||||
|
req.Provider = "google"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use OAuth2-specific refresh method
|
||||||
|
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set new session cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example: OAuth2 Logout
|
||||||
|
func ExampleOAuth2Logout() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
oauth2Auth := NewGoogleAuthenticator(
|
||||||
|
"your-client-id",
|
||||||
|
"your-client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
token := r.Header.Get("Authorization")
|
||||||
|
if token == "" {
|
||||||
|
cookie, err := r.Cookie("session_token")
|
||||||
|
if err == nil {
|
||||||
|
token = cookie.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token != "" {
|
||||||
|
// Get user ID from session
|
||||||
|
userCtx, err := oauth2Auth.Authenticate(r)
|
||||||
|
if err == nil {
|
||||||
|
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||||
|
Token: token,
|
||||||
|
UserID: userCtx.UserID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("Logged out successfully"))
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example: Complete OAuth2 Integration with Database Setup
|
||||||
|
func ExampleOAuth2Complete() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
// Create tables (run once)
|
||||||
|
setupOAuth2Tables(db)
|
||||||
|
|
||||||
|
// Create OAuth2 authenticator
|
||||||
|
oauth2Auth := NewGoogleAuthenticator(
|
||||||
|
"your-client-id",
|
||||||
|
"your-client-secret",
|
||||||
|
"http://localhost:8080/auth/google/callback",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create security providers
|
||||||
|
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||||
|
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||||
|
provider, _ := NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||||
|
securityList, _ := NewSecurityList(provider)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Public routes
|
||||||
|
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("Welcome! <a href='/auth/google/login'>Login with Google</a>"))
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: int(loginResp.ExpiresIn),
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Protected routes
|
||||||
|
protectedRouter := router.PathPrefix("/").Subrouter()
|
||||||
|
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||||
|
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||||
|
|
||||||
|
protectedRouter.HandleFunc("/dashboard", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := GetUserContext(r.Context())
|
||||||
|
_, _ = fmt.Fprintf(w, "Welcome, %s! Your email: %s", userCtx.UserName, userCtx.Email)
|
||||||
|
})
|
||||||
|
|
||||||
|
protectedRouter.HandleFunc("/api/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := GetUserContext(r.Context())
|
||||||
|
_ = json.NewEncoder(w).Encode(userCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
protectedRouter.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := GetUserContext(r.Context())
|
||||||
|
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||||
|
Token: userCtx.SessionID,
|
||||||
|
UserID: userCtx.UserID,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupOAuth2Tables creates the users and user_sessions tables used by the
// OAuth2 examples. Errors are deliberately ignored (best-effort bootstrap).
func setupOAuth2Tables(db *sql.DB) {
	// Create tables from database_schema.sql
	// This is a helper function - in production, use migrations
	ctx := context.Background()

	// Create users table if not exists
	_, _ = db.ExecContext(ctx, `
		CREATE TABLE IF NOT EXISTS users (
			id SERIAL PRIMARY KEY,
			username VARCHAR(255) NOT NULL UNIQUE,
			email VARCHAR(255) NOT NULL UNIQUE,
			password VARCHAR(255),
			user_level INTEGER DEFAULT 0,
			roles VARCHAR(500),
			is_active BOOLEAN DEFAULT true,
			created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			last_login_at TIMESTAMP,
			remote_id VARCHAR(255),
			auth_provider VARCHAR(50)
		)
	`)

	// Create user_sessions table (used for both regular and OAuth2 sessions)
	_, _ = db.ExecContext(ctx, `
		CREATE TABLE IF NOT EXISTS user_sessions (
			id SERIAL PRIMARY KEY,
			session_token VARCHAR(500) NOT NULL UNIQUE,
			user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
			expires_at TIMESTAMP NOT NULL,
			created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
			ip_address VARCHAR(45),
			user_agent TEXT,
			access_token TEXT,
			refresh_token TEXT,
			token_type VARCHAR(50) DEFAULT 'Bearer',
			auth_provider VARCHAR(50)
		)
	`)
}
|
||||||
|
|
||||||
|
// Example: All OAuth2 Providers at Once
|
||||||
|
func ExampleOAuth2AllProviders() {
|
||||||
|
db, _ := sql.Open("postgres", "connection-string")
|
||||||
|
|
||||||
|
// Create authenticator with ALL OAuth2 providers
|
||||||
|
auth := NewDatabaseAuthenticator(db).
|
||||||
|
WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: "google-client-id",
|
||||||
|
ClientSecret: "google-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||||
|
ProviderName: "google",
|
||||||
|
}).
|
||||||
|
WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: "github-client-id",
|
||||||
|
ClientSecret: "github-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||||
|
Scopes: []string{"user:email"},
|
||||||
|
AuthURL: "https://github.com/login/oauth/authorize",
|
||||||
|
TokenURL: "https://github.com/login/oauth/access_token",
|
||||||
|
UserInfoURL: "https://api.github.com/user",
|
||||||
|
ProviderName: "github",
|
||||||
|
}).
|
||||||
|
WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: "microsoft-client-id",
|
||||||
|
ClientSecret: "microsoft-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/microsoft/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||||
|
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||||
|
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||||
|
ProviderName: "microsoft",
|
||||||
|
}).
|
||||||
|
WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: "facebook-client-id",
|
||||||
|
ClientSecret: "facebook-client-secret",
|
||||||
|
RedirectURL: "http://localhost:8080/auth/facebook/callback",
|
||||||
|
Scopes: []string{"email"},
|
||||||
|
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
|
||||||
|
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
|
||||||
|
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
|
||||||
|
ProviderName: "facebook",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get list of configured providers
|
||||||
|
providers := auth.OAuth2GetProviders()
|
||||||
|
fmt.Printf("Configured OAuth2 providers: %v\n", providers)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Google routes
|
||||||
|
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// GitHub routes
|
||||||
|
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("github", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Microsoft routes
|
||||||
|
router.HandleFunc("/auth/microsoft/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("microsoft", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
router.HandleFunc("/auth/microsoft/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "microsoft", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Facebook routes
|
||||||
|
router.HandleFunc("/auth/facebook/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, _ := auth.OAuth2GenerateState()
|
||||||
|
authURL, _ := auth.OAuth2GetAuthURL("facebook", state)
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
})
|
||||||
|
router.HandleFunc("/auth/facebook/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "facebook", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create security list for protected routes
|
||||||
|
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||||
|
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||||
|
provider, _ := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList, _ := NewSecurityList(provider)
|
||||||
|
|
||||||
|
// Protected routes work for ALL OAuth2 providers + regular sessions
|
||||||
|
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||||
|
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||||
|
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||||
|
|
||||||
|
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := GetUserContext(r.Context())
|
||||||
|
_ = json.NewEncoder(w).Encode(userCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
579
pkg/security/oauth2_methods.go
Normal file
579
pkg/security/oauth2_methods.go
Normal file
@@ -0,0 +1,579 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/oauth2"
)
|
||||||
|
|
||||||
|
// OAuth2Config contains configuration for OAuth2 authentication
type OAuth2Config struct {
	ClientID     string   // OAuth2 client identifier issued by the provider
	ClientSecret string   // OAuth2 client secret issued by the provider
	RedirectURL  string   // callback URL registered with the provider
	Scopes       []string // scopes requested during authorization
	AuthURL      string   // provider's authorization endpoint
	TokenURL     string   // provider's token-exchange endpoint
	UserInfoURL  string   // endpoint queried (with the access token) for user info
	ProviderName string   // registry key for this provider; defaults to "oauth2" if empty

	// Optional: Custom user info parser
	// If not provided, will use standard claims (sub, email, name)
	UserInfoParser func(userInfo map[string]any) (*UserContext, error)
}
|
||||||
|
|
||||||
|
// OAuth2Provider holds configuration and state for a single OAuth2 provider
type OAuth2Provider struct {
	config         *oauth2.Config                                   // underlying golang.org/x/oauth2 config
	userInfoURL    string                                           // endpoint for fetching the authenticated user's profile
	userInfoParser func(userInfo map[string]any) (*UserContext, error) // maps raw user-info JSON to a UserContext
	providerName   string                                           // name this provider was registered under
	states         map[string]time.Time                             // state -> expiry time (CSRF protection, in-memory)
	statesMutex    sync.RWMutex                                     // guards states
}
|
||||||
|
|
||||||
|
// WithOAuth2 configures OAuth2 support for the DatabaseAuthenticator
// Can be called multiple times to add multiple OAuth2 providers
// Returns the same DatabaseAuthenticator instance for method chaining
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) *DatabaseAuthenticator {
	// Default the registry key so an unnamed config is still reachable.
	if cfg.ProviderName == "" {
		cfg.ProviderName = "oauth2"
	}

	// Fall back to the standard-claims parser (sub/email/name) when the
	// caller supplied no custom parser.
	if cfg.UserInfoParser == nil {
		cfg.UserInfoParser = defaultOAuth2UserInfoParser
	}

	provider := &OAuth2Provider{
		config: &oauth2.Config{
			ClientID:     cfg.ClientID,
			ClientSecret: cfg.ClientSecret,
			RedirectURL:  cfg.RedirectURL,
			Scopes:       cfg.Scopes,
			Endpoint: oauth2.Endpoint{
				AuthURL:  cfg.AuthURL,
				TokenURL: cfg.TokenURL,
			},
		},
		userInfoURL:    cfg.UserInfoURL,
		userInfoParser: cfg.UserInfoParser,
		providerName:   cfg.ProviderName,
		states:         make(map[string]time.Time),
	}

	// Initialize providers map if needed
	a.oauth2ProvidersMutex.Lock()
	if a.oauth2Providers == nil {
		a.oauth2Providers = make(map[string]*OAuth2Provider)
	}

	// Register provider (re-registering the same name replaces the old one)
	a.oauth2Providers[cfg.ProviderName] = provider
	a.oauth2ProvidersMutex.Unlock()

	// Start state cleanup goroutine for this provider
	// NOTE(review): this goroutine has no stop mechanism and runs for the
	// process lifetime even if the provider is replaced — confirm whether
	// teardown is ever required.
	go provider.cleanupStates()

	return a
}
|
||||||
|
|
||||||
|
// OAuth2GetAuthURL returns the OAuth2 authorization URL for redirecting users
|
||||||
|
func (a *DatabaseAuthenticator) OAuth2GetAuthURL(providerName, state string) (string, error) {
|
||||||
|
provider, err := a.getOAuth2Provider(providerName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store state for validation
|
||||||
|
provider.statesMutex.Lock()
|
||||||
|
provider.states[state] = time.Now().Add(10 * time.Minute)
|
||||||
|
provider.statesMutex.Unlock()
|
||||||
|
|
||||||
|
return provider.config.AuthCodeURL(state), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2GenerateState generates a random state string for CSRF protection
|
||||||
|
func (a *DatabaseAuthenticator) OAuth2GenerateState() (string, error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2HandleCallback handles the OAuth2 callback and exchanges code for token
|
||||||
|
func (a *DatabaseAuthenticator) OAuth2HandleCallback(ctx context.Context, providerName, code, state string) (*LoginResponse, error) {
|
||||||
|
provider, err := a.getOAuth2Provider(providerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate state
|
||||||
|
if !provider.validateState(state) {
|
||||||
|
return nil, fmt.Errorf("invalid state parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange code for token
|
||||||
|
token, err := provider.config.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch user info
|
||||||
|
client := provider.config.Client(ctx, token)
|
||||||
|
resp, err := client.Get(provider.userInfoURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch user info: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read user info: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var userInfo map[string]any
|
||||||
|
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse user info: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse user info
|
||||||
|
userCtx, err := provider.userInfoParser(userInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get or create user in database
|
||||||
|
userID, err := a.oauth2GetOrCreateUser(ctx, userCtx, providerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get or create user: %w", err)
|
||||||
|
}
|
||||||
|
userCtx.UserID = userID
|
||||||
|
|
||||||
|
// Create session token
|
||||||
|
sessionToken, err := a.OAuth2GenerateState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate session token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(24 * time.Hour)
|
||||||
|
if token.Expiry.After(time.Now()) {
|
||||||
|
expiresAt = token.Expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store session in database
|
||||||
|
err = a.oauth2CreateSession(ctx, sessionToken, userCtx.UserID, token, expiresAt, providerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userCtx.SessionID = sessionToken
|
||||||
|
|
||||||
|
return &LoginResponse{
|
||||||
|
Token: sessionToken,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
User: userCtx,
|
||||||
|
ExpiresIn: int64(time.Until(expiresAt).Seconds()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2GetProviders returns list of configured OAuth2 provider names
|
||||||
|
func (a *DatabaseAuthenticator) OAuth2GetProviders() []string {
|
||||||
|
a.oauth2ProvidersMutex.RLock()
|
||||||
|
defer a.oauth2ProvidersMutex.RUnlock()
|
||||||
|
|
||||||
|
if a.oauth2Providers == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
providers := make([]string, 0, len(a.oauth2Providers))
|
||||||
|
for name := range a.oauth2Providers {
|
||||||
|
providers = append(providers, name)
|
||||||
|
}
|
||||||
|
return providers
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOAuth2Provider retrieves a registered OAuth2 provider by name
|
||||||
|
func (a *DatabaseAuthenticator) getOAuth2Provider(providerName string) (*OAuth2Provider, error) {
|
||||||
|
a.oauth2ProvidersMutex.RLock()
|
||||||
|
defer a.oauth2ProvidersMutex.RUnlock()
|
||||||
|
|
||||||
|
if a.oauth2Providers == nil {
|
||||||
|
return nil, fmt.Errorf("OAuth2 not configured - call WithOAuth2() first")
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, ok := a.oauth2Providers[providerName]
|
||||||
|
if !ok {
|
||||||
|
// Build provider list without calling OAuth2GetProviders to avoid recursion
|
||||||
|
providerNames := make([]string, 0, len(a.oauth2Providers))
|
||||||
|
for name := range a.oauth2Providers {
|
||||||
|
providerNames = append(providerNames, name)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("OAuth2 provider '%s' not found - available providers: %v", providerName, providerNames)
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// oauth2GetOrCreateUser finds or creates a user based on OAuth2 info using stored procedure
|
||||||
|
func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userCtx *UserContext, providerName string) (int, error) {
|
||||||
|
userData := map[string]interface{}{
|
||||||
|
"username": userCtx.UserName,
|
||||||
|
"email": userCtx.Email,
|
||||||
|
"remote_id": userCtx.RemoteID,
|
||||||
|
"user_level": userCtx.UserLevel,
|
||||||
|
"roles": userCtx.Roles,
|
||||||
|
"auth_provider": providerName,
|
||||||
|
}
|
||||||
|
|
||||||
|
userJSON, err := json.Marshal(userData)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to marshal user data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var success bool
|
||||||
|
var errMsg *string
|
||||||
|
var userID *int
|
||||||
|
|
||||||
|
err = a.db.QueryRowContext(ctx, `
|
||||||
|
SELECT p_success, p_error, p_user_id
|
||||||
|
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
|
||||||
|
`, userJSON).Scan(&success, &errMsg, &userID)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errMsg != nil {
|
||||||
|
return 0, fmt.Errorf("%s", *errMsg)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("failed to get or create user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID == nil {
|
||||||
|
return 0, fmt.Errorf("user ID not returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
return *userID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// oauth2CreateSession creates a new OAuth2 session using stored procedure
|
||||||
|
func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, sessionToken string, userID int, token *oauth2.Token, expiresAt time.Time, providerName string) error {
|
||||||
|
sessionData := map[string]interface{}{
|
||||||
|
"session_token": sessionToken,
|
||||||
|
"user_id": userID,
|
||||||
|
"access_token": token.AccessToken,
|
||||||
|
"refresh_token": token.RefreshToken,
|
||||||
|
"token_type": token.TokenType,
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
"auth_provider": providerName,
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionJSON, err := json.Marshal(sessionData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal session data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var success bool
|
||||||
|
var errMsg *string
|
||||||
|
|
||||||
|
err = a.db.QueryRowContext(ctx, `
|
||||||
|
SELECT p_success, p_error
|
||||||
|
FROM resolvespec_oauth_createsession($1::jsonb)
|
||||||
|
`, sessionJSON).Scan(&success, &errMsg)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errMsg != nil {
|
||||||
|
return fmt.Errorf("%s", *errMsg)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to create session")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateState validates state using in-memory storage
|
||||||
|
func (p *OAuth2Provider) validateState(state string) bool {
|
||||||
|
p.statesMutex.Lock()
|
||||||
|
defer p.statesMutex.Unlock()
|
||||||
|
|
||||||
|
expiry, ok := p.states[state]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(expiry) {
|
||||||
|
delete(p.states, state)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(p.states, state) // One-time use
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupStates removes expired states periodically
//
// Runs on its own goroutine (started by WithOAuth2) and sweeps the in-memory
// state map every 5 minutes.
// NOTE(review): the loop has no stop channel or context, so it runs for the
// process lifetime even if this provider is replaced — confirm whether
// provider teardown is ever needed.
func (p *OAuth2Provider) cleanupStates() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		p.statesMutex.Lock()
		now := time.Now()
		for state, expiry := range p.states {
			if now.After(expiry) {
				delete(p.states, state)
			}
		}
		p.statesMutex.Unlock()
	}
}
|
||||||
|
|
||||||
|
// defaultOAuth2UserInfoParser parses standard OAuth2 user info claims
|
||||||
|
func defaultOAuth2UserInfoParser(userInfo map[string]any) (*UserContext, error) {
|
||||||
|
ctx := &UserContext{
|
||||||
|
Claims: userInfo,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract standard claims
|
||||||
|
if sub, ok := userInfo["sub"].(string); ok {
|
||||||
|
ctx.RemoteID = sub
|
||||||
|
}
|
||||||
|
if email, ok := userInfo["email"].(string); ok {
|
||||||
|
ctx.Email = email
|
||||||
|
// Use email as username if name not available
|
||||||
|
ctx.UserName = strings.Split(email, "@")[0]
|
||||||
|
}
|
||||||
|
if name, ok := userInfo["name"].(string); ok {
|
||||||
|
ctx.UserName = name
|
||||||
|
}
|
||||||
|
if login, ok := userInfo["login"].(string); ok {
|
||||||
|
ctx.UserName = login // GitHub uses "login"
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.UserName == "" {
|
||||||
|
return nil, fmt.Errorf("could not extract username from user info")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2RefreshToken refreshes an expired OAuth2 access token using the refresh token
// Takes the refresh token and returns a new LoginResponse with updated tokens
//
// Steps: (1) look up the stored session by refresh token, (2) ask the OAuth2
// provider for fresh tokens via a TokenSource, (3) rotate the local session
// token and persist the new token set, (4) reload the user and return a new
// LoginResponse. Each database step goes through a resolvespec_oauth_* stored
// procedure returning (success, error-message[, data]).
func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshToken, providerName string) (*LoginResponse, error) {
	provider, err := a.getOAuth2Provider(providerName)
	if err != nil {
		return nil, err
	}

	// Get session by refresh token from database
	var success bool
	var errMsg *string
	var sessionData []byte

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error, p_data::text
		FROM resolvespec_oauth_getrefreshtoken($1)
	`, refreshToken).Scan(&success, &errMsg, &sessionData)

	if err != nil {
		return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
	}

	if !success {
		if errMsg != nil {
			return nil, fmt.Errorf("%s", *errMsg)
		}
		return nil, fmt.Errorf("invalid or expired refresh token")
	}

	// Parse session data
	var session struct {
		UserID      int       `json:"user_id"`
		AccessToken string    `json:"access_token"`
		TokenType   string    `json:"token_type"`
		Expiry      time.Time `json:"expiry"`
	}
	if err := json.Unmarshal(sessionData, &session); err != nil {
		return nil, fmt.Errorf("failed to parse session data: %w", err)
	}

	// Create oauth2.Token from stored data
	oldToken := &oauth2.Token{
		AccessToken:  session.AccessToken,
		TokenType:    session.TokenType,
		RefreshToken: refreshToken,
		Expiry:       session.Expiry,
	}

	// Use OAuth2 provider to refresh the token
	// (TokenSource refreshes automatically because the stored token is expired)
	tokenSource := provider.config.TokenSource(ctx, oldToken)
	newToken, err := tokenSource.Token()
	if err != nil {
		return nil, fmt.Errorf("failed to refresh token with provider: %w", err)
	}

	// Generate new session token (session token is rotated on refresh)
	newSessionToken, err := a.OAuth2GenerateState()
	if err != nil {
		return nil, fmt.Errorf("failed to generate new session token: %w", err)
	}

	// Update session in database with new tokens
	updateData := map[string]interface{}{
		"user_id":           session.UserID,
		"old_refresh_token": refreshToken,
		"new_session_token": newSessionToken,
		"new_access_token":  newToken.AccessToken,
		"new_refresh_token": newToken.RefreshToken,
		"expires_at":        newToken.Expiry,
	}

	updateJSON, err := json.Marshal(updateData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal update data: %w", err)
	}

	var updateSuccess bool
	var updateErrMsg *string

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error
		FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
	`, updateJSON).Scan(&updateSuccess, &updateErrMsg)

	if err != nil {
		return nil, fmt.Errorf("failed to update session: %w", err)
	}

	if !updateSuccess {
		if updateErrMsg != nil {
			return nil, fmt.Errorf("%s", *updateErrMsg)
		}
		return nil, fmt.Errorf("failed to update session")
	}

	// Get user data (for the User field of the response)
	var userSuccess bool
	var userErrMsg *string
	var userData []byte

	err = a.db.QueryRowContext(ctx, `
		SELECT p_success, p_error, p_data::text
		FROM resolvespec_oauth_getuser($1)
	`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)

	if err != nil {
		return nil, fmt.Errorf("failed to get user data: %w", err)
	}

	if !userSuccess {
		if userErrMsg != nil {
			return nil, fmt.Errorf("%s", *userErrMsg)
		}
		return nil, fmt.Errorf("failed to get user data")
	}

	// Parse user context
	var userCtx UserContext
	if err := json.Unmarshal(userData, &userCtx); err != nil {
		return nil, fmt.Errorf("failed to parse user context: %w", err)
	}

	userCtx.SessionID = newSessionToken

	return &LoginResponse{
		Token:        newSessionToken,
		RefreshToken: newToken.RefreshToken,
		User:         &userCtx,
		ExpiresIn:    int64(time.Until(newToken.Expiry).Seconds()),
	}, nil
}
|
||||||
|
|
||||||
|
// Pre-configured OAuth2 factory methods
|
||||||
|
|
||||||
|
// NewGoogleAuthenticator creates a DatabaseAuthenticator configured for Google OAuth2
|
||||||
|
func NewGoogleAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||||
|
auth := NewDatabaseAuthenticator(db)
|
||||||
|
return auth.WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||||
|
ProviderName: "google",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGitHubAuthenticator creates a DatabaseAuthenticator configured for GitHub OAuth2
|
||||||
|
func NewGitHubAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||||
|
auth := NewDatabaseAuthenticator(db)
|
||||||
|
return auth.WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
Scopes: []string{"user:email"},
|
||||||
|
AuthURL: "https://github.com/login/oauth/authorize",
|
||||||
|
TokenURL: "https://github.com/login/oauth/access_token",
|
||||||
|
UserInfoURL: "https://api.github.com/user",
|
||||||
|
ProviderName: "github",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMicrosoftAuthenticator creates a DatabaseAuthenticator configured for Microsoft OAuth2
|
||||||
|
func NewMicrosoftAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||||
|
auth := NewDatabaseAuthenticator(db)
|
||||||
|
return auth.WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||||
|
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||||
|
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||||
|
ProviderName: "microsoft",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFacebookAuthenticator creates a DatabaseAuthenticator configured for Facebook OAuth2
|
||||||
|
func NewFacebookAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||||
|
auth := NewDatabaseAuthenticator(db)
|
||||||
|
return auth.WithOAuth2(OAuth2Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
Scopes: []string{"email"},
|
||||||
|
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
|
||||||
|
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
|
||||||
|
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
|
||||||
|
ProviderName: "facebook",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMultiProviderAuthenticator creates a DatabaseAuthenticator with all major OAuth2 providers configured
|
||||||
|
func NewMultiProviderAuthenticator(db *sql.DB, configs map[string]OAuth2Config) *DatabaseAuthenticator {
|
||||||
|
auth := NewDatabaseAuthenticator(db)
|
||||||
|
|
||||||
|
//nolint:gocritic // OAuth2Config is copied but kept for API simplicity
|
||||||
|
for _, cfg := range configs {
|
||||||
|
auth.WithOAuth2(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth
|
||||||
|
}
|
||||||
185
pkg/security/passkey.go
Normal file
185
pkg/security/passkey.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PasskeyCredential represents a stored WebAuthn/FIDO2 credential
|
||||||
|
type PasskeyCredential struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
CredentialID []byte `json:"credential_id"` // Raw credential ID from authenticator
|
||||||
|
PublicKey []byte `json:"public_key"` // COSE public key
|
||||||
|
AttestationType string `json:"attestation_type"` // none, indirect, direct
|
||||||
|
AAGUID []byte `json:"aaguid"` // Authenticator AAGUID
|
||||||
|
SignCount uint32 `json:"sign_count"` // Signature counter
|
||||||
|
CloneWarning bool `json:"clone_warning"` // True if cloning detected
|
||||||
|
Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
|
||||||
|
BackupEligible bool `json:"backup_eligible"` // Credential can be backed up
|
||||||
|
BackupState bool `json:"backup_state"` // Credential is currently backed up
|
||||||
|
Name string `json:"name,omitempty"` // User-friendly name
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
LastUsedAt time.Time `json:"last_used_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyRegistrationOptions contains options for beginning passkey registration
|
||||||
|
type PasskeyRegistrationOptions struct {
|
||||||
|
Challenge []byte `json:"challenge"`
|
||||||
|
RelyingParty PasskeyRelyingParty `json:"rp"`
|
||||||
|
User PasskeyUser `json:"user"`
|
||||||
|
PubKeyCredParams []PasskeyCredentialParam `json:"pubKeyCredParams"`
|
||||||
|
Timeout int64 `json:"timeout,omitempty"` // Milliseconds
|
||||||
|
ExcludeCredentials []PasskeyCredentialDescriptor `json:"excludeCredentials,omitempty"`
|
||||||
|
AuthenticatorSelection *PasskeyAuthenticatorSelection `json:"authenticatorSelection,omitempty"`
|
||||||
|
Attestation string `json:"attestation,omitempty"` // none, indirect, direct, enterprise
|
||||||
|
Extensions map[string]any `json:"extensions,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyAuthenticationOptions contains options for beginning passkey authentication
|
||||||
|
type PasskeyAuthenticationOptions struct {
|
||||||
|
Challenge []byte `json:"challenge"`
|
||||||
|
Timeout int64 `json:"timeout,omitempty"`
|
||||||
|
RelyingPartyID string `json:"rpId,omitempty"`
|
||||||
|
AllowCredentials []PasskeyCredentialDescriptor `json:"allowCredentials,omitempty"`
|
||||||
|
UserVerification string `json:"userVerification,omitempty"` // required, preferred, discouraged
|
||||||
|
Extensions map[string]any `json:"extensions,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyRelyingParty identifies the relying party
|
||||||
|
type PasskeyRelyingParty struct {
|
||||||
|
ID string `json:"id"` // Domain (e.g., "example.com")
|
||||||
|
Name string `json:"name"` // Display name
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyUser identifies the user
|
||||||
|
type PasskeyUser struct {
|
||||||
|
ID []byte `json:"id"` // User handle (unique, persistent)
|
||||||
|
Name string `json:"name"` // Username
|
||||||
|
DisplayName string `json:"displayName"` // Display name
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyCredentialParam specifies supported public key algorithm
|
||||||
|
type PasskeyCredentialParam struct {
|
||||||
|
Type string `json:"type"` // "public-key"
|
||||||
|
Alg int `json:"alg"` // COSE algorithm identifier (e.g., -7 for ES256, -257 for RS256)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyCredentialDescriptor describes a credential
|
||||||
|
type PasskeyCredentialDescriptor struct {
|
||||||
|
Type string `json:"type"` // "public-key"
|
||||||
|
ID []byte `json:"id"` // Credential ID
|
||||||
|
Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyAuthenticatorSelection specifies authenticator requirements
|
||||||
|
type PasskeyAuthenticatorSelection struct {
|
||||||
|
AuthenticatorAttachment string `json:"authenticatorAttachment,omitempty"` // platform, cross-platform
|
||||||
|
RequireResidentKey bool `json:"requireResidentKey,omitempty"`
|
||||||
|
ResidentKey string `json:"residentKey,omitempty"` // discouraged, preferred, required
|
||||||
|
UserVerification string `json:"userVerification,omitempty"` // required, preferred, discouraged
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyRegistrationResponse contains the client's registration response
|
||||||
|
type PasskeyRegistrationResponse struct {
|
||||||
|
ID string `json:"id"` // Base64URL encoded credential ID
|
||||||
|
RawID []byte `json:"rawId"` // Raw credential ID
|
||||||
|
Type string `json:"type"` // "public-key"
|
||||||
|
Response PasskeyAuthenticatorAttestationResponse `json:"response"`
|
||||||
|
ClientExtensionResults map[string]any `json:"clientExtensionResults,omitempty"`
|
||||||
|
Transports []string `json:"transports,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyAuthenticatorAttestationResponse contains attestation data
|
||||||
|
type PasskeyAuthenticatorAttestationResponse struct {
|
||||||
|
ClientDataJSON []byte `json:"clientDataJSON"`
|
||||||
|
AttestationObject []byte `json:"attestationObject"`
|
||||||
|
Transports []string `json:"transports,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyAuthenticationResponse contains the client's authentication response
|
||||||
|
type PasskeyAuthenticationResponse struct {
|
||||||
|
ID string `json:"id"` // Base64URL encoded credential ID
|
||||||
|
RawID []byte `json:"rawId"` // Raw credential ID
|
||||||
|
Type string `json:"type"` // "public-key"
|
||||||
|
Response PasskeyAuthenticatorAssertionResponse `json:"response"`
|
||||||
|
ClientExtensionResults map[string]any `json:"clientExtensionResults,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyAuthenticatorAssertionResponse contains assertion data
|
||||||
|
type PasskeyAuthenticatorAssertionResponse struct {
|
||||||
|
ClientDataJSON []byte `json:"clientDataJSON"`
|
||||||
|
AuthenticatorData []byte `json:"authenticatorData"`
|
||||||
|
Signature []byte `json:"signature"`
|
||||||
|
UserHandle []byte `json:"userHandle,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyProvider handles passkey registration and authentication
|
||||||
|
type PasskeyProvider interface {
|
||||||
|
// BeginRegistration creates registration options for a new passkey
|
||||||
|
BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error)
|
||||||
|
|
||||||
|
// CompleteRegistration verifies and stores a new passkey credential
|
||||||
|
CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error)
|
||||||
|
|
||||||
|
// BeginAuthentication creates authentication options for passkey login
|
||||||
|
BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error)
|
||||||
|
|
||||||
|
// CompleteAuthentication verifies a passkey assertion and returns the user
|
||||||
|
CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error)
|
||||||
|
|
||||||
|
// GetCredentials returns all passkey credentials for a user
|
||||||
|
GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error)
|
||||||
|
|
||||||
|
// DeleteCredential removes a passkey credential
|
||||||
|
DeleteCredential(ctx context.Context, userID int, credentialID string) error
|
||||||
|
|
||||||
|
// UpdateCredentialName updates the friendly name of a credential
|
||||||
|
UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyLoginRequest contains passkey authentication data
|
||||||
|
type PasskeyLoginRequest struct {
|
||||||
|
Response PasskeyAuthenticationResponse `json:"response"`
|
||||||
|
ExpectedChallenge []byte `json:"expected_challenge"`
|
||||||
|
Claims map[string]any `json:"claims"` // Additional login data
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyRegisterRequest contains passkey registration data
|
||||||
|
type PasskeyRegisterRequest struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
Response PasskeyRegistrationResponse `json:"response"`
|
||||||
|
ExpectedChallenge []byte `json:"expected_challenge"`
|
||||||
|
CredentialName string `json:"credential_name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyBeginRegistrationRequest contains options for starting passkey registration
|
||||||
|
type PasskeyBeginRegistrationRequest struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyBeginAuthenticationRequest contains options for starting passkey authentication
|
||||||
|
type PasskeyBeginAuthenticationRequest struct {
|
||||||
|
Username string `json:"username,omitempty"` // Optional for resident key flow
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePasskeyRegistrationResponse parses a JSON passkey registration response
|
||||||
|
func ParsePasskeyRegistrationResponse(data []byte) (*PasskeyRegistrationResponse, error) {
|
||||||
|
var response PasskeyRegistrationResponse
|
||||||
|
if err := json.Unmarshal(data, &response); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePasskeyAuthenticationResponse parses a JSON passkey authentication response
|
||||||
|
func ParsePasskeyAuthenticationResponse(data []byte) (*PasskeyAuthenticationResponse, error) {
|
||||||
|
var response PasskeyAuthenticationResponse
|
||||||
|
if err := json.Unmarshal(data, &response); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
432
pkg/security/passkey_examples.go
Normal file
432
pkg/security/passkey_examples.go
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PasskeyAuthenticationExample demonstrates passkey (WebAuthn/FIDO2) authentication
|
||||||
|
func PasskeyAuthenticationExample() {
|
||||||
|
// Setup database connection
|
||||||
|
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
|
||||||
|
|
||||||
|
// Create passkey provider
|
||||||
|
passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com", // Your domain
|
||||||
|
RPName: "Example Application", // Display name
|
||||||
|
RPOrigin: "https://example.com", // Expected origin
|
||||||
|
Timeout: 60000, // 60 seconds
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create authenticator with passkey support
|
||||||
|
// Option 1: Pass during creation
|
||||||
|
_ = NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
|
||||||
|
PasskeyProvider: passkeyProvider,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Option 2: Use WithPasskey method
|
||||||
|
auth := NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// === REGISTRATION FLOW ===
|
||||||
|
|
||||||
|
// Step 1: Begin registration
|
||||||
|
regOptions, _ := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
|
||||||
|
UserID: 1,
|
||||||
|
Username: "alice",
|
||||||
|
DisplayName: "Alice Smith",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Send regOptions to client as JSON
|
||||||
|
// Client will call navigator.credentials.create() with these options
|
||||||
|
_ = regOptions
|
||||||
|
|
||||||
|
// Step 2: Complete registration (after client returns credential)
|
||||||
|
// This would come from the client's navigator.credentials.create() response
|
||||||
|
clientResponse := PasskeyRegistrationResponse{
|
||||||
|
ID: "base64-credential-id",
|
||||||
|
RawID: []byte("raw-credential-id"),
|
||||||
|
Type: "public-key",
|
||||||
|
Response: PasskeyAuthenticatorAttestationResponse{
|
||||||
|
ClientDataJSON: []byte("..."),
|
||||||
|
AttestationObject: []byte("..."),
|
||||||
|
},
|
||||||
|
Transports: []string{"internal"},
|
||||||
|
}
|
||||||
|
|
||||||
|
credential, _ := auth.CompletePasskeyRegistration(ctx, PasskeyRegisterRequest{
|
||||||
|
UserID: 1,
|
||||||
|
Response: clientResponse,
|
||||||
|
ExpectedChallenge: regOptions.Challenge,
|
||||||
|
CredentialName: "My iPhone",
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Printf("Registered credential: %s\n", credential.ID)
|
||||||
|
|
||||||
|
// === AUTHENTICATION FLOW ===
|
||||||
|
|
||||||
|
// Step 1: Begin authentication
|
||||||
|
authOptions, _ := auth.BeginPasskeyAuthentication(ctx, PasskeyBeginAuthenticationRequest{
|
||||||
|
Username: "alice", // Optional - omit for resident key flow
|
||||||
|
})
|
||||||
|
|
||||||
|
// Send authOptions to client as JSON
|
||||||
|
// Client will call navigator.credentials.get() with these options
|
||||||
|
_ = authOptions
|
||||||
|
|
||||||
|
// Step 2: Complete authentication (after client returns assertion)
|
||||||
|
// This would come from the client's navigator.credentials.get() response
|
||||||
|
clientAssertion := PasskeyAuthenticationResponse{
|
||||||
|
ID: "base64-credential-id",
|
||||||
|
RawID: []byte("raw-credential-id"),
|
||||||
|
Type: "public-key",
|
||||||
|
Response: PasskeyAuthenticatorAssertionResponse{
|
||||||
|
ClientDataJSON: []byte("..."),
|
||||||
|
AuthenticatorData: []byte("..."),
|
||||||
|
Signature: []byte("..."),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
loginResponse, _ := auth.LoginWithPasskey(ctx, PasskeyLoginRequest{
|
||||||
|
Response: clientAssertion,
|
||||||
|
ExpectedChallenge: authOptions.Challenge,
|
||||||
|
Claims: map[string]any{
|
||||||
|
"ip_address": "192.168.1.1",
|
||||||
|
"user_agent": "Mozilla/5.0...",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Printf("Logged in user: %s with token: %s\n",
|
||||||
|
loginResponse.User.UserName, loginResponse.Token)
|
||||||
|
|
||||||
|
// === CREDENTIAL MANAGEMENT ===
|
||||||
|
|
||||||
|
// Get all credentials for a user
|
||||||
|
credentials, _ := auth.GetPasskeyCredentials(ctx, 1)
|
||||||
|
for i := range credentials {
|
||||||
|
fmt.Printf("Credential: %s (created: %s, last used: %s)\n",
|
||||||
|
credentials[i].Name, credentials[i].CreatedAt, credentials[i].LastUsedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update credential name
|
||||||
|
_ = auth.UpdatePasskeyCredentialName(ctx, 1, credential.ID, "My New iPhone")
|
||||||
|
|
||||||
|
// Delete credential
|
||||||
|
_ = auth.DeletePasskeyCredential(ctx, 1, credential.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyHTTPHandlersExample shows HTTP handlers for passkey authentication
|
||||||
|
func PasskeyHTTPHandlersExample(auth *DatabaseAuthenticator) {
|
||||||
|
// Store challenges in session/cache in production
|
||||||
|
challenges := make(map[string][]byte)
|
||||||
|
|
||||||
|
// Begin registration endpoint
|
||||||
|
http.HandleFunc("/api/passkey/register/begin", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
}
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
options, err := auth.BeginPasskeyRegistration(r.Context(), PasskeyBeginRegistrationRequest{
|
||||||
|
UserID: req.UserID,
|
||||||
|
Username: req.Username,
|
||||||
|
DisplayName: req.DisplayName,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store challenge for verification (use session ID as key in production)
|
||||||
|
sessionID := "session-123"
|
||||||
|
challenges[sessionID] = options.Challenge
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(options)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Complete registration endpoint
|
||||||
|
http.HandleFunc("/api/passkey/register/complete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
Response PasskeyRegistrationResponse `json:"response"`
|
||||||
|
CredentialName string `json:"credential_name"`
|
||||||
|
}
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
// Get stored challenge (from session in production)
|
||||||
|
sessionID := "session-123"
|
||||||
|
challenge := challenges[sessionID]
|
||||||
|
delete(challenges, sessionID)
|
||||||
|
|
||||||
|
credential, err := auth.CompletePasskeyRegistration(r.Context(), PasskeyRegisterRequest{
|
||||||
|
UserID: req.UserID,
|
||||||
|
Response: req.Response,
|
||||||
|
ExpectedChallenge: challenge,
|
||||||
|
CredentialName: req.CredentialName,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(credential)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Begin authentication endpoint
|
||||||
|
http.HandleFunc("/api/passkey/login/begin", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"` // Optional
|
||||||
|
}
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
options, err := auth.BeginPasskeyAuthentication(r.Context(), PasskeyBeginAuthenticationRequest{
|
||||||
|
Username: req.Username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store challenge for verification (use session ID as key in production)
|
||||||
|
sessionID := "session-456"
|
||||||
|
challenges[sessionID] = options.Challenge
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(options)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Complete authentication endpoint
|
||||||
|
http.HandleFunc("/api/passkey/login/complete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req struct {
|
||||||
|
Response PasskeyAuthenticationResponse `json:"response"`
|
||||||
|
}
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
// Get stored challenge (from session in production)
|
||||||
|
sessionID := "session-456"
|
||||||
|
challenge := challenges[sessionID]
|
||||||
|
delete(challenges, sessionID)
|
||||||
|
|
||||||
|
loginResponse, err := auth.LoginWithPasskey(r.Context(), PasskeyLoginRequest{
|
||||||
|
Response: req.Response,
|
||||||
|
ExpectedChallenge: challenge,
|
||||||
|
Claims: map[string]any{
|
||||||
|
"ip_address": r.RemoteAddr,
|
||||||
|
"user_agent": r.UserAgent(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set session cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: loginResponse.Token,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(loginResponse)
|
||||||
|
})
|
||||||
|
|
||||||
|
// List credentials endpoint
|
||||||
|
http.HandleFunc("/api/passkey/credentials", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get user from authenticated session
|
||||||
|
userCtx, err := auth.Authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials, err := auth.GetPasskeyCredentials(r.Context(), userCtx.UserID)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(credentials)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Delete credential endpoint
|
||||||
|
http.HandleFunc("/api/passkey/credentials/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, err := auth.Authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
CredentialID string `json:"credential_id"`
|
||||||
|
}
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
err = auth.DeletePasskeyCredential(r.Context(), userCtx.UserID, req.CredentialID)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasskeyClientSideExample shows the client-side JavaScript code needed
|
||||||
|
func PasskeyClientSideExample() string {
|
||||||
|
return `
|
||||||
|
// === CLIENT-SIDE JAVASCRIPT FOR PASSKEY AUTHENTICATION ===
|
||||||
|
|
||||||
|
// Helper function to convert base64 to ArrayBuffer
|
||||||
|
function base64ToArrayBuffer(base64) {
|
||||||
|
const binary = atob(base64);
|
||||||
|
const bytes = new Uint8Array(binary.length);
|
||||||
|
for (let i = 0; i < binary.length; i++) {
|
||||||
|
bytes[i] = binary.charCodeAt(i);
|
||||||
|
}
|
||||||
|
return bytes.buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to convert ArrayBuffer to base64
|
||||||
|
function arrayBufferToBase64(buffer) {
|
||||||
|
const bytes = new Uint8Array(buffer);
|
||||||
|
let binary = '';
|
||||||
|
for (let i = 0; i < bytes.length; i++) {
|
||||||
|
binary += String.fromCharCode(bytes[i]);
|
||||||
|
}
|
||||||
|
return btoa(binary);
|
||||||
|
}
|
||||||
|
|
||||||
|
// === REGISTRATION ===
|
||||||
|
|
||||||
|
async function registerPasskey(userId, username, displayName) {
|
||||||
|
// Step 1: Get registration options from server
|
||||||
|
const optionsResponse = await fetch('/api/passkey/register/begin', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ user_id: userId, username, display_name: displayName })
|
||||||
|
});
|
||||||
|
const options = await optionsResponse.json();
|
||||||
|
|
||||||
|
// Convert base64 strings to ArrayBuffers
|
||||||
|
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||||
|
options.user.id = base64ToArrayBuffer(options.user.id);
|
||||||
|
if (options.excludeCredentials) {
|
||||||
|
options.excludeCredentials = options.excludeCredentials.map(cred => ({
|
||||||
|
...cred,
|
||||||
|
id: base64ToArrayBuffer(cred.id)
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Create credential using WebAuthn API
|
||||||
|
const credential = await navigator.credentials.create({
|
||||||
|
publicKey: options
|
||||||
|
});
|
||||||
|
|
||||||
|
// Step 3: Send credential to server
|
||||||
|
const credentialResponse = {
|
||||||
|
id: credential.id,
|
||||||
|
rawId: arrayBufferToBase64(credential.rawId),
|
||||||
|
type: credential.type,
|
||||||
|
response: {
|
||||||
|
clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
|
||||||
|
attestationObject: arrayBufferToBase64(credential.response.attestationObject)
|
||||||
|
},
|
||||||
|
transports: credential.response.getTransports ? credential.response.getTransports() : []
|
||||||
|
};
|
||||||
|
|
||||||
|
const completeResponse = await fetch('/api/passkey/register/complete', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
user_id: userId,
|
||||||
|
response: credentialResponse,
|
||||||
|
credential_name: 'My Device'
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
return await completeResponse.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === AUTHENTICATION ===
|
||||||
|
|
||||||
|
async function loginWithPasskey(username) {
|
||||||
|
// Step 1: Get authentication options from server
|
||||||
|
const optionsResponse = await fetch('/api/passkey/login/begin', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ username })
|
||||||
|
});
|
||||||
|
const options = await optionsResponse.json();
|
||||||
|
|
||||||
|
// Convert base64 strings to ArrayBuffers
|
||||||
|
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||||
|
if (options.allowCredentials) {
|
||||||
|
options.allowCredentials = options.allowCredentials.map(cred => ({
|
||||||
|
...cred,
|
||||||
|
id: base64ToArrayBuffer(cred.id)
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Get credential using WebAuthn API
|
||||||
|
const credential = await navigator.credentials.get({
|
||||||
|
publicKey: options
|
||||||
|
});
|
||||||
|
|
||||||
|
// Step 3: Send assertion to server
|
||||||
|
const assertionResponse = {
|
||||||
|
id: credential.id,
|
||||||
|
rawId: arrayBufferToBase64(credential.rawId),
|
||||||
|
type: credential.type,
|
||||||
|
response: {
|
||||||
|
clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
|
||||||
|
authenticatorData: arrayBufferToBase64(credential.response.authenticatorData),
|
||||||
|
signature: arrayBufferToBase64(credential.response.signature),
|
||||||
|
userHandle: credential.response.userHandle ? arrayBufferToBase64(credential.response.userHandle) : null
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const loginResponse = await fetch('/api/passkey/login/complete', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ response: assertionResponse })
|
||||||
|
});
|
||||||
|
|
||||||
|
return await loginResponse.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
// === USAGE ===
|
||||||
|
|
||||||
|
// Register a new passkey
|
||||||
|
document.getElementById('register-btn').addEventListener('click', async () => {
|
||||||
|
try {
|
||||||
|
const result = await registerPasskey(1, 'alice', 'Alice Smith');
|
||||||
|
console.log('Passkey registered:', result);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Registration failed:', error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Login with passkey
|
||||||
|
document.getElementById('login-btn').addEventListener('click', async () => {
|
||||||
|
try {
|
||||||
|
const result = await loginWithPasskey('alice');
|
||||||
|
console.log('Logged in:', result);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Login failed:', error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
`
|
||||||
|
}
|
||||||
405
pkg/security/passkey_provider.go
Normal file
405
pkg/security/passkey_provider.go
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||||
|
type DatabasePasskeyProvider struct {
|
||||||
|
db *sql.DB
|
||||||
|
rpID string // Relying Party ID (domain)
|
||||||
|
rpName string // Relying Party display name
|
||||||
|
rpOrigin string // Expected origin for WebAuthn
|
||||||
|
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||||
|
type DatabasePasskeyProviderOptions struct {
|
||||||
|
// RPID is the Relying Party ID (typically your domain, e.g., "example.com")
|
||||||
|
RPID string
|
||||||
|
// RPName is the display name for your relying party
|
||||||
|
RPName string
|
||||||
|
// RPOrigin is the expected origin (e.g., "https://example.com")
|
||||||
|
RPOrigin string
|
||||||
|
// Timeout is the timeout for operations in milliseconds (default: 60000)
|
||||||
|
Timeout int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||||
|
func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions) *DatabasePasskeyProvider {
|
||||||
|
if opts.Timeout == 0 {
|
||||||
|
opts.Timeout = 60000 // 60 seconds default
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DatabasePasskeyProvider{
|
||||||
|
db: db,
|
||||||
|
rpID: opts.RPID,
|
||||||
|
rpName: opts.RPName,
|
||||||
|
rpOrigin: opts.RPOrigin,
|
||||||
|
timeout: opts.Timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginRegistration creates registration options for a new passkey
|
||||||
|
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
||||||
|
// Generate challenge
|
||||||
|
challenge := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(challenge); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate challenge: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing credentials to exclude
|
||||||
|
credentials, err := p.GetCredentials(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get existing credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
excludeCredentials := make([]PasskeyCredentialDescriptor, 0, len(credentials))
|
||||||
|
for i := range credentials {
|
||||||
|
excludeCredentials = append(excludeCredentials, PasskeyCredentialDescriptor{
|
||||||
|
Type: "public-key",
|
||||||
|
ID: credentials[i].CredentialID,
|
||||||
|
Transports: credentials[i].Transports,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create user handle (persistent user ID)
|
||||||
|
userHandle := []byte(fmt.Sprintf("user_%d", userID))
|
||||||
|
|
||||||
|
return &PasskeyRegistrationOptions{
|
||||||
|
Challenge: challenge,
|
||||||
|
RelyingParty: PasskeyRelyingParty{
|
||||||
|
ID: p.rpID,
|
||||||
|
Name: p.rpName,
|
||||||
|
},
|
||||||
|
User: PasskeyUser{
|
||||||
|
ID: userHandle,
|
||||||
|
Name: username,
|
||||||
|
DisplayName: displayName,
|
||||||
|
},
|
||||||
|
PubKeyCredParams: []PasskeyCredentialParam{
|
||||||
|
{Type: "public-key", Alg: -7}, // ES256 (ECDSA with SHA-256)
|
||||||
|
{Type: "public-key", Alg: -257}, // RS256 (RSASSA-PKCS1-v1_5 with SHA-256)
|
||||||
|
},
|
||||||
|
Timeout: p.timeout,
|
||||||
|
ExcludeCredentials: excludeCredentials,
|
||||||
|
AuthenticatorSelection: &PasskeyAuthenticatorSelection{
|
||||||
|
RequireResidentKey: false,
|
||||||
|
ResidentKey: "preferred",
|
||||||
|
UserVerification: "preferred",
|
||||||
|
},
|
||||||
|
Attestation: "none",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteRegistration verifies and stores a new passkey credential
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
// like github.com/go-webauthn/webauthn to properly verify attestation and parse credentials.
//
// SECURITY(review): expectedChallenge is accepted but never compared against
// the client data, and no attestation parsing/verification happens here —
// the raw attestation object is stored as-is under "public_key". Do not
// expose this in production until a real WebAuthn verifier is wired in.
func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error) {
	// TODO: Implement full WebAuthn verification
	// 1. Verify clientDataJSON contains correct challenge and origin
	// 2. Parse and verify attestationObject
	// 3. Extract public key and credential ID
	// 4. Verify attestation signature (if not "none")

	// For now, this is a placeholder that stores the credential data
	// In production, you MUST use a proper WebAuthn library

	// Payload for the resolvespec_passkey_store_credential stored procedure.
	// Binary fields are base64 (StdEncoding) strings inside the JSON.
	credData := map[string]any{
		"user_id":       userID,
		"credential_id": base64.StdEncoding.EncodeToString(response.RawID),
		// NOTE(review): this is the whole attestation object, not an
		// extracted COSE public key — a verifier must re-parse it later.
		"public_key":       base64.StdEncoding.EncodeToString(response.Response.AttestationObject),
		"attestation_type": "none",
		"sign_count":       0,
		"transports":       response.Transports,
		"backup_eligible":  false,
		"backup_state":     false,
		"name":             "Passkey", // default display name; renamed later via UpdateCredentialName
	}

	credJSON, err := json.Marshal(credData)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal credential data: %w", err)
	}

	// Outputs of the stored procedure.
	var success bool
	var errorMsg sql.NullString
	var credentialID sql.NullInt64

	query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
	err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
	if err != nil {
		return nil, fmt.Errorf("failed to store credential: %w", err)
	}

	if !success {
		// Surface the database-provided message when there is one.
		if errorMsg.Valid {
			return nil, fmt.Errorf("%s", errorMsg.String)
		}
		return nil, fmt.Errorf("failed to store credential")
	}

	// ID is the numeric database row ID (as a string), distinct from the
	// WebAuthn CredentialID bytes.
	return &PasskeyCredential{
		ID:              fmt.Sprintf("%d", credentialID.Int64),
		UserID:          userID,
		CredentialID:    response.RawID,
		PublicKey:       response.Response.AttestationObject,
		AttestationType: "none",
		Transports:      response.Transports,
		CreatedAt:       time.Now(),
		LastUsedAt:      time.Now(),
	}, nil
}
|
||||||
|
|
||||||
|
// BeginAuthentication creates authentication options for passkey login
|
||||||
|
func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error) {
|
||||||
|
// Generate challenge
|
||||||
|
challenge := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(challenge); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate challenge: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If username is provided, get user's credentials
|
||||||
|
var allowCredentials []PasskeyCredentialDescriptor
|
||||||
|
if username != "" {
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var userID sql.NullInt64
|
||||||
|
var credentialsJSON sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
||||||
|
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse credentials
|
||||||
|
var creds []struct {
|
||||||
|
ID string `json:"credential_id"`
|
||||||
|
Transports []string `json:"transports"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(credentialsJSON.String), &creds); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowCredentials = make([]PasskeyCredentialDescriptor, 0, len(creds))
|
||||||
|
for _, cred := range creds {
|
||||||
|
credID, err := base64.StdEncoding.DecodeString(cred.ID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
allowCredentials = append(allowCredentials, PasskeyCredentialDescriptor{
|
||||||
|
Type: "public-key",
|
||||||
|
ID: credID,
|
||||||
|
Transports: cred.Transports,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PasskeyAuthenticationOptions{
|
||||||
|
Challenge: challenge,
|
||||||
|
Timeout: p.timeout,
|
||||||
|
RelyingPartyID: p.rpID,
|
||||||
|
AllowCredentials: allowCredentials,
|
||||||
|
UserVerification: "preferred",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteAuthentication verifies a passkey assertion and returns the user ID
|
||||||
|
// NOTE: This is a simplified implementation. In production, you should use a WebAuthn library
|
||||||
|
// like github.com/go-webauthn/webauthn to properly verify the assertion signature.
|
||||||
|
func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error) {
|
||||||
|
// TODO: Implement full WebAuthn verification
|
||||||
|
// 1. Verify clientDataJSON contains correct challenge and origin
|
||||||
|
// 2. Verify authenticatorData
|
||||||
|
// 3. Verify signature using stored public key
|
||||||
|
// 4. Update sign counter and check for cloning
|
||||||
|
|
||||||
|
// Get credential from database
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var credentialJSON sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
||||||
|
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return 0, fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("credential not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse credential
|
||||||
|
var cred struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
SignCount uint32 `json:"sign_count"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(credentialJSON.String), &cred); err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to parse credential: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Verify signature here
|
||||||
|
// For now, we'll just update the counter as a placeholder
|
||||||
|
|
||||||
|
// Update counter (in production, this should be done after successful verification)
|
||||||
|
newCounter := cred.SignCount + 1
|
||||||
|
var updateSuccess bool
|
||||||
|
var updateError sql.NullString
|
||||||
|
var cloneWarning sql.NullBool
|
||||||
|
|
||||||
|
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
||||||
|
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cloneWarning.Valid && cloneWarning.Bool {
|
||||||
|
return 0, fmt.Errorf("credential cloning detected")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred.UserID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCredentials returns all passkey credentials for a user
|
||||||
|
func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var credentialsJSON sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
||||||
|
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get credentials")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse credentials
|
||||||
|
var rawCreds []struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
CredentialID string `json:"credential_id"`
|
||||||
|
PublicKey string `json:"public_key"`
|
||||||
|
AttestationType string `json:"attestation_type"`
|
||||||
|
AAGUID string `json:"aaguid"`
|
||||||
|
SignCount uint32 `json:"sign_count"`
|
||||||
|
CloneWarning bool `json:"clone_warning"`
|
||||||
|
Transports []string `json:"transports"`
|
||||||
|
BackupEligible bool `json:"backup_eligible"`
|
||||||
|
BackupState bool `json:"backup_state"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
LastUsedAt time.Time `json:"last_used_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(credentialsJSON.String), &rawCreds); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := make([]PasskeyCredential, 0, len(rawCreds))
|
||||||
|
for i := range rawCreds {
|
||||||
|
raw := rawCreds[i]
|
||||||
|
credID, err := base64.StdEncoding.DecodeString(raw.CredentialID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pubKey, err := base64.StdEncoding.DecodeString(raw.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
aaguid, _ := base64.StdEncoding.DecodeString(raw.AAGUID)
|
||||||
|
|
||||||
|
credentials = append(credentials, PasskeyCredential{
|
||||||
|
ID: fmt.Sprintf("%d", raw.ID),
|
||||||
|
UserID: raw.UserID,
|
||||||
|
CredentialID: credID,
|
||||||
|
PublicKey: pubKey,
|
||||||
|
AttestationType: raw.AttestationType,
|
||||||
|
AAGUID: aaguid,
|
||||||
|
SignCount: raw.SignCount,
|
||||||
|
CloneWarning: raw.CloneWarning,
|
||||||
|
Transports: raw.Transports,
|
||||||
|
BackupEligible: raw.BackupEligible,
|
||||||
|
BackupState: raw.BackupState,
|
||||||
|
Name: raw.Name,
|
||||||
|
CreatedAt: raw.CreatedAt,
|
||||||
|
LastUsedAt: raw.LastUsedAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return credentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteCredential removes a passkey credential
|
||||||
|
func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID int, credentialID string) error {
|
||||||
|
credID, err := base64.StdEncoding.DecodeString(credentialID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid credential ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
||||||
|
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete credential: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to delete credential")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCredentialName updates the friendly name of a credential
|
||||||
|
func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
|
||||||
|
credID, err := base64.StdEncoding.DecodeString(credentialID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid credential ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
||||||
|
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update credential name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to update credential name")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
330
pkg/security/passkey_test.go
Normal file
330
pkg/security/passkey_test.go
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDatabasePasskeyProvider_BeginRegistration(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mock db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
RPOrigin: "https://example.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Mock get credentials query
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||||
|
AddRow(true, nil, "[]")
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
opts, err := provider.BeginRegistration(ctx, 1, "testuser", "Test User")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BeginRegistration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.RelyingParty.ID != "example.com" {
|
||||||
|
t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingParty.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.User.Name != "testuser" {
|
||||||
|
t.Errorf("expected username 'testuser', got '%s'", opts.User.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.Challenge) != 32 {
|
||||||
|
t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.PubKeyCredParams) != 2 {
|
||||||
|
t.Errorf("expected 2 credential params, got %d", len(opts.PubKeyCredParams))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabasePasskeyProvider_BeginAuthentication(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mock db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
RPOrigin: "https://example.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Mock get credentials by username query
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user_id", "p_credentials"}).
|
||||||
|
AddRow(true, nil, 1, `[{"credential_id":"YWJjZGVm","transports":["internal"]}]`)
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username`).
|
||||||
|
WithArgs("testuser").
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
opts, err := provider.BeginAuthentication(ctx, "testuser")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BeginAuthentication failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.RelyingPartyID != "example.com" {
|
||||||
|
t.Errorf("expected RP ID 'example.com', got '%s'", opts.RelyingPartyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.Challenge) != 32 {
|
||||||
|
t.Errorf("expected challenge length 32, got %d", len(opts.Challenge))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.AllowCredentials) != 1 {
|
||||||
|
t.Errorf("expected 1 allowed credential, got %d", len(opts.AllowCredentials))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabasePasskeyProvider_GetCredentials(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mock db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
credentialsJSON := `[{
|
||||||
|
"id": 1,
|
||||||
|
"user_id": 1,
|
||||||
|
"credential_id": "YWJjZGVmMTIzNDU2",
|
||||||
|
"public_key": "cHVibGlja2V5",
|
||||||
|
"attestation_type": "none",
|
||||||
|
"aaguid": "",
|
||||||
|
"sign_count": 5,
|
||||||
|
"clone_warning": false,
|
||||||
|
"transports": ["internal"],
|
||||||
|
"backup_eligible": true,
|
||||||
|
"backup_state": false,
|
||||||
|
"name": "My Phone",
|
||||||
|
"created_at": "2026-01-01T00:00:00Z",
|
||||||
|
"last_used_at": "2026-01-31T00:00:00Z"
|
||||||
|
}]`
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||||
|
AddRow(true, nil, credentialsJSON)
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
credentials, err := provider.GetCredentials(ctx, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredentials failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(credentials) != 1 {
|
||||||
|
t.Fatalf("expected 1 credential, got %d", len(credentials))
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := credentials[0]
|
||||||
|
if cred.UserID != 1 {
|
||||||
|
t.Errorf("expected user ID 1, got %d", cred.UserID)
|
||||||
|
}
|
||||||
|
if cred.Name != "My Phone" {
|
||||||
|
t.Errorf("expected name 'My Phone', got '%s'", cred.Name)
|
||||||
|
}
|
||||||
|
if cred.SignCount != 5 {
|
||||||
|
t.Errorf("expected sign count 5, got %d", cred.SignCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabasePasskeyProvider_DeleteCredential(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mock db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
|
||||||
|
AddRow(true, nil)
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_delete_credential`).
|
||||||
|
WithArgs(1, sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
err = provider.DeleteCredential(ctx, 1, "YWJjZGVmMTIzNDU2")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("DeleteCredential failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabasePasskeyProvider_UpdateCredentialName(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mock db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error"}).
|
||||||
|
AddRow(true, nil)
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error FROM resolvespec_passkey_update_name`).
|
||||||
|
WithArgs(1, sqlmock.AnyArg(), "New Name").
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
err = provider.UpdateCredentialName(ctx, 1, "YWJjZGVmMTIzNDU2", "New Name")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("UpdateCredentialName failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseAuthenticator_PasskeyMethods(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mock db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
})
|
||||||
|
|
||||||
|
auth := NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
|
||||||
|
PasskeyProvider: passkeyProvider,
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("BeginPasskeyRegistration", func(t *testing.T) {
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||||
|
AddRow(true, nil, "[]")
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
opts, err := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
DisplayName: "Test User",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("BeginPasskeyRegistration failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts == nil {
|
||||||
|
t.Error("expected options, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetPasskeyCredentials", func(t *testing.T) {
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_credentials"}).
|
||||||
|
AddRow(true, nil, "[]")
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials`).
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
credentials, err := auth.GetPasskeyCredentials(ctx, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("GetPasskeyCredentials failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if credentials == nil {
|
||||||
|
t.Error("expected credentials slice, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseAuthenticator_WithoutPasskey(t *testing.T) {
|
||||||
|
db, _, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mock db: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
auth := NewDatabaseAuthenticator(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err = auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
DisplayName: "Test User",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when passkey provider not configured, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMsg := "passkey provider not configured"
|
||||||
|
if err.Error() != expectedMsg {
|
||||||
|
t.Errorf("expected error '%s', got '%s'", expectedMsg, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPasskeyProvider_NilDB(t *testing.T) {
|
||||||
|
// This test verifies that the provider can be created with nil DB
|
||||||
|
// but operations will fail. In production, always provide a valid DB.
|
||||||
|
var db *sql.DB
|
||||||
|
provider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||||
|
RPID: "example.com",
|
||||||
|
RPName: "Example App",
|
||||||
|
})
|
||||||
|
|
||||||
|
if provider == nil {
|
||||||
|
t.Error("expected provider to be created even with nil DB")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the provider has the correct configuration
|
||||||
|
if provider.rpID != "example.com" {
|
||||||
|
t.Errorf("expected RP ID 'example.com', got '%s'", provider.rpID)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
@@ -60,10 +61,19 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
|||||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
||||||
// resolvespec_session_update, resolvespec_refresh_token
|
// resolvespec_session_update, resolvespec_refresh_token
|
||||||
// See database_schema.sql for procedure definitions
|
// See database_schema.sql for procedure definitions
|
||||||
|
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||||
|
// Also supports passkey authentication configured with WithPasskey()
|
||||||
type DatabaseAuthenticator struct {
|
type DatabaseAuthenticator struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
cache *cache.Cache
|
cache *cache.Cache
|
||||||
cacheTTL time.Duration
|
cacheTTL time.Duration
|
||||||
|
|
||||||
|
// OAuth2 providers registry (multiple providers supported)
|
||||||
|
oauth2Providers map[string]*OAuth2Provider
|
||||||
|
oauth2ProvidersMutex sync.RWMutex
|
||||||
|
|
||||||
|
// Passkey provider (optional)
|
||||||
|
passkeyProvider PasskeyProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseAuthenticatorOptions configures the database authenticator
|
// DatabaseAuthenticatorOptions configures the database authenticator
|
||||||
@@ -73,6 +83,8 @@ type DatabaseAuthenticatorOptions struct {
|
|||||||
CacheTTL time.Duration
|
CacheTTL time.Duration
|
||||||
// Cache is an optional cache instance. If nil, uses the default cache
|
// Cache is an optional cache instance. If nil, uses the default cache
|
||||||
Cache *cache.Cache
|
Cache *cache.Cache
|
||||||
|
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||||
|
PasskeyProvider PasskeyProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||||
@@ -95,6 +107,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
|||||||
db: db,
|
db: db,
|
||||||
cache: cacheInstance,
|
cache: cacheInstance,
|
||||||
cacheTTL: opts.CacheTTL,
|
cacheTTL: opts.CacheTTL,
|
||||||
|
passkeyProvider: opts.PasskeyProvider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,6 +145,41 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
|||||||
return &response, nil
|
return &response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register implements Registrable interface
|
||||||
|
func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error) {
|
||||||
|
// Convert RegisterRequest to JSON
|
||||||
|
reqJSON, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call resolvespec_register stored procedure
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
||||||
|
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("register query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("registration failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
var response LoginResponse
|
||||||
|
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse register response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||||
// Convert LogoutRequest to JSON
|
// Convert LogoutRequest to JSON
|
||||||
reqJSON, err := json.Marshal(req)
|
reqJSON, err := json.Marshal(req)
|
||||||
@@ -654,3 +702,135 @@ func generateRandomString(length int) string {
|
|||||||
// }
|
// }
|
||||||
// return ""
|
// return ""
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
// Passkey authentication methods
|
||||||
|
// ==============================
|
||||||
|
|
||||||
|
// WithPasskey configures the DatabaseAuthenticator with a passkey provider.
// It overwrites any previously configured provider and returns the receiver
// so calls can be chained after NewDatabaseAuthenticator.
func (a *DatabaseAuthenticator) WithPasskey(provider PasskeyProvider) *DatabaseAuthenticator {
	a.passkeyProvider = provider
	return a
}
|
||||||
|
|
||||||
|
// BeginPasskeyRegistration initiates passkey registration for a user
|
||||||
|
func (a *DatabaseAuthenticator) BeginPasskeyRegistration(ctx context.Context, req PasskeyBeginRegistrationRequest) (*PasskeyRegistrationOptions, error) {
|
||||||
|
if a.passkeyProvider == nil {
|
||||||
|
return nil, fmt.Errorf("passkey provider not configured")
|
||||||
|
}
|
||||||
|
return a.passkeyProvider.BeginRegistration(ctx, req.UserID, req.Username, req.DisplayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompletePasskeyRegistration completes passkey registration
|
||||||
|
func (a *DatabaseAuthenticator) CompletePasskeyRegistration(ctx context.Context, req PasskeyRegisterRequest) (*PasskeyCredential, error) {
|
||||||
|
if a.passkeyProvider == nil {
|
||||||
|
return nil, fmt.Errorf("passkey provider not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
cred, err := a.passkeyProvider.CompleteRegistration(ctx, req.UserID, req.Response, req.ExpectedChallenge)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update credential name if provided
|
||||||
|
if req.CredentialName != "" && cred.ID != "" {
|
||||||
|
_ = a.passkeyProvider.UpdateCredentialName(ctx, req.UserID, cred.ID, req.CredentialName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginPasskeyAuthentication initiates passkey authentication
|
||||||
|
func (a *DatabaseAuthenticator) BeginPasskeyAuthentication(ctx context.Context, req PasskeyBeginAuthenticationRequest) (*PasskeyAuthenticationOptions, error) {
|
||||||
|
if a.passkeyProvider == nil {
|
||||||
|
return nil, fmt.Errorf("passkey provider not configured")
|
||||||
|
}
|
||||||
|
return a.passkeyProvider.BeginAuthentication(ctx, req.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithPasskey authenticates a user using a passkey and creates a session
|
||||||
|
func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req PasskeyLoginRequest) (*LoginResponse, error) {
|
||||||
|
if a.passkeyProvider == nil {
|
||||||
|
return nil, fmt.Errorf("passkey provider not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify passkey assertion
|
||||||
|
userID, err := a.passkeyProvider.CompleteAuthentication(ctx, req.Response, req.ExpectedChallenge)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user data from database
|
||||||
|
var username, email, roles string
|
||||||
|
var userLevel int
|
||||||
|
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
||||||
|
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate session token
|
||||||
|
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||||
|
expiresAt := time.Now().Add(24 * time.Hour)
|
||||||
|
|
||||||
|
// Extract IP and user agent from claims
|
||||||
|
ipAddress := ""
|
||||||
|
userAgent := ""
|
||||||
|
if req.Claims != nil {
|
||||||
|
if ip, ok := req.Claims["ip_address"].(string); ok {
|
||||||
|
ipAddress = ip
|
||||||
|
}
|
||||||
|
if ua, ok := req.Claims["user_agent"].(string); ok {
|
||||||
|
userAgent = ua
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create session
|
||||||
|
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, now())`
|
||||||
|
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update last login
|
||||||
|
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
|
||||||
|
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
|
||||||
|
|
||||||
|
// Return login response
|
||||||
|
return &LoginResponse{
|
||||||
|
Token: sessionToken,
|
||||||
|
User: &UserContext{
|
||||||
|
UserID: userID,
|
||||||
|
UserName: username,
|
||||||
|
Email: email,
|
||||||
|
UserLevel: userLevel,
|
||||||
|
SessionID: sessionToken,
|
||||||
|
Roles: parseRoles(roles),
|
||||||
|
},
|
||||||
|
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPasskeyCredentials returns all passkey credentials for a user
|
||||||
|
func (a *DatabaseAuthenticator) GetPasskeyCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error) {
|
||||||
|
if a.passkeyProvider == nil {
|
||||||
|
return nil, fmt.Errorf("passkey provider not configured")
|
||||||
|
}
|
||||||
|
return a.passkeyProvider.GetCredentials(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePasskeyCredential removes a passkey credential
|
||||||
|
func (a *DatabaseAuthenticator) DeletePasskeyCredential(ctx context.Context, userID int, credentialID string) error {
|
||||||
|
if a.passkeyProvider == nil {
|
||||||
|
return fmt.Errorf("passkey provider not configured")
|
||||||
|
}
|
||||||
|
return a.passkeyProvider.DeleteCredential(ctx, userID, credentialID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePasskeyCredentialName updates the friendly name of a credential
|
||||||
|
func (a *DatabaseAuthenticator) UpdatePasskeyCredentialName(ctx context.Context, userID int, credentialID string, name string) error {
|
||||||
|
if a.passkeyProvider == nil {
|
||||||
|
return fmt.Errorf("passkey provider not configured")
|
||||||
|
}
|
||||||
|
return a.passkeyProvider.UpdateCredentialName(ctx, userID, credentialID, name)
|
||||||
|
}
|
||||||
|
|||||||
@@ -635,6 +635,94 @@ func TestDatabaseAuthenticator(t *testing.T) {
|
|||||||
t.Errorf("unfulfilled expectations: %v", err)
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("successful registration", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
req := RegisterRequest{
|
||||||
|
Username: "newuser",
|
||||||
|
Password: "password123",
|
||||||
|
Email: "newuser@example.com",
|
||||||
|
UserLevel: 1,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||||
|
AddRow(true, nil, `{"token":"abc123","user":{"user_id":1,"user_name":"newuser","email":"newuser@example.com"},"expires_in":86400}`)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||||
|
WithArgs(sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
resp, err := auth.Register(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Token != "abc123" {
|
||||||
|
t.Errorf("expected token abc123, got %s", resp.Token)
|
||||||
|
}
|
||||||
|
if resp.User.UserName != "newuser" {
|
||||||
|
t.Errorf("expected username newuser, got %s", resp.User.UserName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("registration with duplicate username", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
req := RegisterRequest{
|
||||||
|
Username: "existinguser",
|
||||||
|
Password: "password123",
|
||||||
|
Email: "new@example.com",
|
||||||
|
UserLevel: 1,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||||
|
AddRow(false, "Username already exists", nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||||
|
WithArgs(sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
_, err := auth.Register(ctx, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for duplicate username")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("registration with duplicate email", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
req := RegisterRequest{
|
||||||
|
Username: "newuser2",
|
||||||
|
Password: "password123",
|
||||||
|
Email: "existing@example.com",
|
||||||
|
UserLevel: 1,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||||
|
AddRow(false, "Email already exists", nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||||
|
WithArgs(sqlmock.AnyArg()).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
_, err := auth.Register(ctx, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for duplicate email")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test DatabaseAuthenticator RefreshToken
|
// Test DatabaseAuthenticator RefreshToken
|
||||||
|
|||||||
188
pkg/security/totp.go
Normal file
188
pkg/security/totp.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/base32"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"math"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TwoFactorAuthProvider defines interface for 2FA operations
|
||||||
|
type TwoFactorAuthProvider interface {
|
||||||
|
// Generate2FASecret creates a new secret for a user
|
||||||
|
Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error)
|
||||||
|
|
||||||
|
// Validate2FACode verifies a TOTP code
|
||||||
|
Validate2FACode(secret string, code string) (bool, error)
|
||||||
|
|
||||||
|
// Enable2FA activates 2FA for a user (store secret in your database)
|
||||||
|
Enable2FA(userID int, secret string, backupCodes []string) error
|
||||||
|
|
||||||
|
// Disable2FA deactivates 2FA for a user
|
||||||
|
Disable2FA(userID int) error
|
||||||
|
|
||||||
|
// Get2FAStatus checks if user has 2FA enabled
|
||||||
|
Get2FAStatus(userID int) (bool, error)
|
||||||
|
|
||||||
|
// Get2FASecret retrieves the user's 2FA secret
|
||||||
|
Get2FASecret(userID int) (string, error)
|
||||||
|
|
||||||
|
// GenerateBackupCodes creates backup codes for 2FA
|
||||||
|
GenerateBackupCodes(userID int, count int) ([]string, error)
|
||||||
|
|
||||||
|
// ValidateBackupCode checks and consumes a backup code
|
||||||
|
ValidateBackupCode(userID int, code string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TwoFactorSecret contains 2FA setup information
|
||||||
|
type TwoFactorSecret struct {
|
||||||
|
Secret string `json:"secret"` // Base32 encoded secret
|
||||||
|
QRCodeURL string `json:"qr_code_url"` // URL for QR code generation
|
||||||
|
BackupCodes []string `json:"backup_codes"` // One-time backup codes
|
||||||
|
Issuer string `json:"issuer"` // Application name
|
||||||
|
AccountName string `json:"account_name"` // User identifier (email/username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TwoFactorConfig holds TOTP configuration
|
||||||
|
type TwoFactorConfig struct {
|
||||||
|
Algorithm string // SHA1, SHA256, SHA512
|
||||||
|
Digits int // Number of digits in code (6 or 8)
|
||||||
|
Period int // Time step in seconds (default 30)
|
||||||
|
SkewWindow int // Number of time steps to check before/after (default 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultTwoFactorConfig returns standard TOTP configuration
|
||||||
|
func DefaultTwoFactorConfig() *TwoFactorConfig {
|
||||||
|
return &TwoFactorConfig{
|
||||||
|
Algorithm: "SHA1",
|
||||||
|
Digits: 6,
|
||||||
|
Period: 30,
|
||||||
|
SkewWindow: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TOTPGenerator handles TOTP code generation and validation
|
||||||
|
type TOTPGenerator struct {
|
||||||
|
config *TwoFactorConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTOTPGenerator creates a new TOTP generator with config
|
||||||
|
func NewTOTPGenerator(config *TwoFactorConfig) *TOTPGenerator {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultTwoFactorConfig()
|
||||||
|
}
|
||||||
|
return &TOTPGenerator{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateSecret creates a random base32-encoded secret
|
||||||
|
func (t *TOTPGenerator) GenerateSecret() (string, error) {
|
||||||
|
secret := make([]byte, 20)
|
||||||
|
_, err := rand.Read(secret)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate random secret: %w", err)
|
||||||
|
}
|
||||||
|
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secret), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateQRCodeURL creates a URL for QR code generation
|
||||||
|
func (t *TOTPGenerator) GenerateQRCodeURL(secret, issuer, accountName string) string {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("secret", secret)
|
||||||
|
params.Set("issuer", issuer)
|
||||||
|
params.Set("algorithm", t.config.Algorithm)
|
||||||
|
params.Set("digits", fmt.Sprintf("%d", t.config.Digits))
|
||||||
|
params.Set("period", fmt.Sprintf("%d", t.config.Period))
|
||||||
|
|
||||||
|
label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, accountName))
|
||||||
|
return fmt.Sprintf("otpauth://totp/%s?%s", label, params.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCode creates a TOTP code for a given time
|
||||||
|
func (t *TOTPGenerator) GenerateCode(secret string, timestamp time.Time) (string, error) {
|
||||||
|
// Decode secret
|
||||||
|
key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate counter (time steps since Unix epoch)
|
||||||
|
counter := uint64(timestamp.Unix()) / uint64(t.config.Period)
|
||||||
|
|
||||||
|
// Generate HMAC
|
||||||
|
h := t.getHashFunc()
|
||||||
|
mac := hmac.New(h, key)
|
||||||
|
|
||||||
|
// Convert counter to 8-byte array
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint64(buf, counter)
|
||||||
|
mac.Write(buf)
|
||||||
|
|
||||||
|
sum := mac.Sum(nil)
|
||||||
|
|
||||||
|
// Dynamic truncation
|
||||||
|
offset := sum[len(sum)-1] & 0x0f
|
||||||
|
truncated := binary.BigEndian.Uint32(sum[offset:]) & 0x7fffffff
|
||||||
|
|
||||||
|
// Generate code with specified digits
|
||||||
|
code := truncated % uint32(math.Pow10(t.config.Digits))
|
||||||
|
|
||||||
|
format := fmt.Sprintf("%%0%dd", t.config.Digits)
|
||||||
|
return fmt.Sprintf(format, code), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCode checks if a code is valid for the secret
|
||||||
|
func (t *TOTPGenerator) ValidateCode(secret, code string) (bool, error) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Check current time and skew window
|
||||||
|
for i := -t.config.SkewWindow; i <= t.config.SkewWindow; i++ {
|
||||||
|
timestamp := now.Add(time.Duration(i*t.config.Period) * time.Second)
|
||||||
|
expected, err := t.GenerateCode(secret, timestamp)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if code == expected {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHashFunc returns the hash function based on algorithm
|
||||||
|
func (t *TOTPGenerator) getHashFunc() func() hash.Hash {
|
||||||
|
switch strings.ToUpper(t.config.Algorithm) {
|
||||||
|
case "SHA256":
|
||||||
|
return sha256.New
|
||||||
|
case "SHA512":
|
||||||
|
return sha512.New
|
||||||
|
default:
|
||||||
|
return sha1.New
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateBackupCodes creates random backup codes
|
||||||
|
func GenerateBackupCodes(count int) ([]string, error) {
|
||||||
|
codes := make([]string, count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
code := make([]byte, 4)
|
||||||
|
_, err := rand.Read(code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate backup code: %w", err)
|
||||||
|
}
|
||||||
|
codes[i] = fmt.Sprintf("%08X", binary.BigEndian.Uint32(code))
|
||||||
|
}
|
||||||
|
return codes, nil
|
||||||
|
}
|
||||||
399
pkg/security/totp_integration_test.go
Normal file
399
pkg/security/totp_integration_test.go
Normal file
@@ -0,0 +1,399 @@
|
|||||||
|
package security_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrInvalidCredentials = errors.New("invalid credentials")
|
||||||
|
|
||||||
|
// MockAuthenticator is a simple authenticator for testing 2FA
|
||||||
|
type MockAuthenticator struct {
|
||||||
|
users map[string]*security.UserContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockAuthenticator() *MockAuthenticator {
|
||||||
|
return &MockAuthenticator{
|
||||||
|
users: map[string]*security.UserContext{
|
||||||
|
"testuser": {
|
||||||
|
UserID: 1,
|
||||||
|
UserName: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
|
||||||
|
user, exists := m.users[req.Username]
|
||||||
|
if !exists || req.Password != "password" {
|
||||||
|
return nil, ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
return &security.LoginResponse{
|
||||||
|
Token: "mock-token",
|
||||||
|
RefreshToken: "mock-refresh-token",
|
||||||
|
User: user,
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||||
|
return m.users["testuser"], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Setup(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup 2FA
|
||||||
|
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup2FA() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret.Secret == "" {
|
||||||
|
t.Error("Setup2FA() returned empty secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret.QRCodeURL == "" {
|
||||||
|
t.Error("Setup2FA() returned empty QR code URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(secret.BackupCodes) == 0 {
|
||||||
|
t.Error("Setup2FA() returned no backup codes")
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret.Issuer != "TestApp" {
|
||||||
|
t.Errorf("Setup2FA() Issuer = %s, want TestApp", secret.Issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret.AccountName != "test@example.com" {
|
||||||
|
t.Errorf("Setup2FA() AccountName = %s, want test@example.com", secret.AccountName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Enable2FA(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup 2FA
|
||||||
|
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup2FA() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate valid code
|
||||||
|
totp := security.NewTOTPGenerator(nil)
|
||||||
|
code, err := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateCode() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable 2FA with valid code
|
||||||
|
err = tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Enable2FA() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify 2FA is enabled
|
||||||
|
status, err := provider.Get2FAStatus(1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get2FAStatus() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !status {
|
||||||
|
t.Error("Enable2FA() did not enable 2FA")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Enable2FA_InvalidCode(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup 2FA
|
||||||
|
secret, err := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup2FA() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to enable with invalid code
|
||||||
|
err = tfaAuth.Enable2FA(1, secret.Secret, "000000")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Enable2FA() should fail with invalid code")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify 2FA is not enabled
|
||||||
|
status, _ := provider.Get2FAStatus(1)
|
||||||
|
if status {
|
||||||
|
t.Error("Enable2FA() should not enable 2FA with invalid code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Login_Without2FA(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tfaAuth.Login(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Login() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Requires2FA {
|
||||||
|
t.Error("Login() should not require 2FA when not enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Token == "" {
|
||||||
|
t.Error("Login() should return token when 2FA not required")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Login_With2FA_NoCode(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup and enable 2FA
|
||||||
|
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
totp := security.NewTOTPGenerator(nil)
|
||||||
|
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||||
|
|
||||||
|
// Try to login without 2FA code
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tfaAuth.Login(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Login() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.Requires2FA {
|
||||||
|
t.Error("Login() should require 2FA when enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Token != "" {
|
||||||
|
t.Error("Login() should not return token when 2FA required but not provided")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Login_With2FA_ValidCode(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup and enable 2FA
|
||||||
|
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
totp := security.NewTOTPGenerator(nil)
|
||||||
|
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||||
|
|
||||||
|
// Generate new valid code for login
|
||||||
|
newCode, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
|
||||||
|
// Login with 2FA code
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
TwoFactorCode: newCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tfaAuth.Login(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Login() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Requires2FA {
|
||||||
|
t.Error("Login() should not require 2FA when valid code provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Token == "" {
|
||||||
|
t.Error("Login() should return token when 2FA validated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.User.TwoFactorEnabled {
|
||||||
|
t.Error("Login() should set TwoFactorEnabled on user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Login_With2FA_InvalidCode(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup and enable 2FA
|
||||||
|
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
totp := security.NewTOTPGenerator(nil)
|
||||||
|
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||||
|
|
||||||
|
// Try to login with invalid code
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
TwoFactorCode: "000000",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tfaAuth.Login(context.Background(), req)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Login() should fail with invalid 2FA code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Login_WithBackupCode(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup and enable 2FA
|
||||||
|
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
totp := security.NewTOTPGenerator(nil)
|
||||||
|
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||||
|
|
||||||
|
// Get backup codes
|
||||||
|
backupCodes, _ := tfaAuth.RegenerateBackupCodes(1, 10)
|
||||||
|
|
||||||
|
// Login with backup code
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
TwoFactorCode: backupCodes[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tfaAuth.Login(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Login() with backup code error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Token == "" {
|
||||||
|
t.Error("Login() should return token when backup code validated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to use same backup code again
|
||||||
|
req2 := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
TwoFactorCode: backupCodes[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tfaAuth.Login(context.Background(), req2)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Login() should fail when reusing backup code")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_Disable2FA(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup and enable 2FA
|
||||||
|
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
totp := security.NewTOTPGenerator(nil)
|
||||||
|
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||||
|
|
||||||
|
// Disable 2FA
|
||||||
|
err := tfaAuth.Disable2FA(1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Disable2FA() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify 2FA is disabled
|
||||||
|
status, _ := provider.Get2FAStatus(1)
|
||||||
|
if status {
|
||||||
|
t.Error("Disable2FA() did not disable 2FA")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login should not require 2FA
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tfaAuth.Login(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Login() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Requires2FA {
|
||||||
|
t.Error("Login() should not require 2FA after disabling")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoFactorAuthenticator_RegenerateBackupCodes(t *testing.T) {
|
||||||
|
baseAuth := NewMockAuthenticator()
|
||||||
|
provider := security.NewMemoryTwoFactorProvider(nil)
|
||||||
|
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, provider, nil)
|
||||||
|
|
||||||
|
// Setup and enable 2FA
|
||||||
|
secret, _ := tfaAuth.Setup2FA(1, "TestApp", "test@example.com")
|
||||||
|
totp := security.NewTOTPGenerator(nil)
|
||||||
|
code, _ := totp.GenerateCode(secret.Secret, time.Now())
|
||||||
|
tfaAuth.Enable2FA(1, secret.Secret, code)
|
||||||
|
|
||||||
|
// Get initial backup codes
|
||||||
|
codes1, err := tfaAuth.RegenerateBackupCodes(1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RegenerateBackupCodes() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(codes1) != 10 {
|
||||||
|
t.Errorf("RegenerateBackupCodes() returned %d codes, want 10", len(codes1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regenerate backup codes
|
||||||
|
codes2, err := tfaAuth.RegenerateBackupCodes(1, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RegenerateBackupCodes() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old codes should not work
|
||||||
|
req := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
TwoFactorCode: codes1[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tfaAuth.Login(context.Background(), req)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Login() should fail with old backup code after regeneration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// New codes should work
|
||||||
|
req2 := security.LoginRequest{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "password",
|
||||||
|
TwoFactorCode: codes2[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tfaAuth.Login(context.Background(), req2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Login() with new backup code error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Token == "" {
|
||||||
|
t.Error("Login() should return token with new backup code")
|
||||||
|
}
|
||||||
|
}
|
||||||
134
pkg/security/totp_middleware.go
Normal file
134
pkg/security/totp_middleware.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TwoFactorAuthenticator wraps an Authenticator and adds 2FA support
|
||||||
|
type TwoFactorAuthenticator struct {
|
||||||
|
baseAuth Authenticator
|
||||||
|
totp *TOTPGenerator
|
||||||
|
provider TwoFactorAuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTwoFactorAuthenticator creates a new 2FA-enabled authenticator
|
||||||
|
func NewTwoFactorAuthenticator(baseAuth Authenticator, provider TwoFactorAuthProvider, config *TwoFactorConfig) *TwoFactorAuthenticator {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultTwoFactorConfig()
|
||||||
|
}
|
||||||
|
return &TwoFactorAuthenticator{
|
||||||
|
baseAuth: baseAuth,
|
||||||
|
totp: NewTOTPGenerator(config),
|
||||||
|
provider: provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login authenticates with 2FA support
|
||||||
|
func (t *TwoFactorAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||||
|
// First, perform standard authentication
|
||||||
|
resp, err := t.baseAuth.Login(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user has 2FA enabled
|
||||||
|
if resp.User == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
has2FA, err := t.provider.Get2FAStatus(resp.User.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check 2FA status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !has2FA {
|
||||||
|
// User doesn't have 2FA enabled, return normal response
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// User has 2FA enabled
|
||||||
|
if req.TwoFactorCode == "" {
|
||||||
|
// No 2FA code provided, require it
|
||||||
|
resp.Requires2FA = true
|
||||||
|
resp.Token = "" // Don't return token until 2FA is verified
|
||||||
|
resp.RefreshToken = ""
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 2FA code
|
||||||
|
secret, err := t.provider.Get2FASecret(resp.User.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get 2FA secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try TOTP code first
|
||||||
|
valid, err := t.totp.ValidateCode(secret, req.TwoFactorCode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate 2FA code: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
// Try backup code
|
||||||
|
valid, err = t.provider.ValidateBackupCode(resp.User.UserID, req.TwoFactorCode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate backup code: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
return nil, fmt.Errorf("invalid 2FA code")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2FA verified, return full response with token
|
||||||
|
resp.User.TwoFactorEnabled = true
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout delegates to base authenticator. Two-factor state plays no role in
// session teardown.
func (t *TwoFactorAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
	return t.baseAuth.Logout(ctx, req)
}

// Authenticate delegates to base authenticator. Request-level authentication
// (e.g. token checks) is unchanged by 2FA, which only gates Login.
func (t *TwoFactorAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
	return t.baseAuth.Authenticate(r)
}

// Setup2FA initiates 2FA setup for a user. It returns the provider-generated
// secret material (secret, QR URL, backup codes); 2FA is not active until
// Enable2FA is called with a valid verification code.
func (t *TwoFactorAuthenticator) Setup2FA(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
	return t.provider.Generate2FASecret(userID, issuer, accountName)
}
|
||||||
|
|
||||||
|
// Enable2FA completes 2FA setup after user confirms with a valid code
|
||||||
|
func (t *TwoFactorAuthenticator) Enable2FA(userID int, secret, verificationCode string) error {
|
||||||
|
// Verify the code before enabling
|
||||||
|
valid, err := t.totp.ValidateCode(secret, verificationCode)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to validate code: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
return fmt.Errorf("invalid verification code")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate backup codes
|
||||||
|
backupCodes, err := t.provider.GenerateBackupCodes(userID, 10)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate backup codes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable 2FA
|
||||||
|
return t.provider.Enable2FA(userID, secret, backupCodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable2FA removes 2FA from a user account by delegating to the provider,
// which discards the secret and any remaining backup codes.
func (t *TwoFactorAuthenticator) Disable2FA(userID int) error {
	return t.provider.Disable2FA(userID)
}

// RegenerateBackupCodes creates new backup codes for a user. The returned
// plaintext codes are the only copy the caller will ever see; the provider
// stores them (hashed, for the included implementations) and invalidates any
// previous set.
func (t *TwoFactorAuthenticator) RegenerateBackupCodes(userID int, count int) ([]string, error) {
	return t.provider.GenerateBackupCodes(userID, count)
}
|
||||||
229
pkg/security/totp_provider_database.go
Normal file
229
pkg/security/totp_provider_database.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
|
||||||
|
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
|
||||||
|
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
|
||||||
|
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
|
||||||
|
// See totp_database_schema.sql for procedure definitions
|
||||||
|
type DatabaseTwoFactorProvider struct {
|
||||||
|
db *sql.DB
|
||||||
|
totpGen *TOTPGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
||||||
|
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultTwoFactorConfig()
|
||||||
|
}
|
||||||
|
return &DatabaseTwoFactorProvider{
|
||||||
|
db: db,
|
||||||
|
totpGen: NewTOTPGenerator(config),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate2FASecret creates a new secret for a user
|
||||||
|
func (p *DatabaseTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
|
||||||
|
secret, err := p.totpGen.GenerateSecret()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
qrURL := p.totpGen.GenerateQRCodeURL(secret, issuer, accountName)
|
||||||
|
|
||||||
|
backupCodes, err := GenerateBackupCodes(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate backup codes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TwoFactorSecret{
|
||||||
|
Secret: secret,
|
||||||
|
QRCodeURL: qrURL,
|
||||||
|
BackupCodes: backupCodes,
|
||||||
|
Issuer: issuer,
|
||||||
|
AccountName: accountName,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate2FACode verifies a TOTP code
|
||||||
|
func (p *DatabaseTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
|
||||||
|
return p.totpGen.ValidateCode(secret, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable2FA activates 2FA for a user
|
||||||
|
func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
|
||||||
|
// Hash backup codes for secure storage
|
||||||
|
hashedCodes := make([]string, len(backupCodes))
|
||||||
|
for i, code := range backupCodes {
|
||||||
|
hash := sha256.Sum256([]byte(code))
|
||||||
|
hashedCodes[i] = hex.EncodeToString(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to JSON array
|
||||||
|
codesJSON, err := json.Marshal(hashedCodes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal backup codes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call stored procedure
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
|
||||||
|
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("enable 2FA query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to enable 2FA")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable2FA deactivates 2FA for a user
|
||||||
|
func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
|
||||||
|
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("disable 2FA query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to disable 2FA")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAStatus checks if user has 2FA enabled
|
||||||
|
func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var enabled bool
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
|
||||||
|
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return false, fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("failed to get 2FA status")
|
||||||
|
}
|
||||||
|
|
||||||
|
return enabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FASecret retrieves the user's 2FA secret
|
||||||
|
func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var secret sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
|
||||||
|
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return "", fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("failed to get 2FA secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !secret.Valid {
|
||||||
|
return "", fmt.Errorf("2FA secret not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return secret.String, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateBackupCodes creates backup codes for 2FA
|
||||||
|
func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
|
||||||
|
codes, err := GenerateBackupCodes(count)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate backup codes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash backup codes for storage
|
||||||
|
hashedCodes := make([]string, len(codes))
|
||||||
|
for i, code := range codes {
|
||||||
|
hash := sha256.Sum256([]byte(code))
|
||||||
|
hashedCodes[i] = hex.EncodeToString(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to JSON array
|
||||||
|
codesJSON, err := json.Marshal(hashedCodes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal backup codes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call stored procedure
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
||||||
|
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to regenerate backup codes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return unhashed codes to user (only time they see them)
|
||||||
|
return codes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateBackupCode checks and consumes a backup code.
// The code is hashed with SHA-256 (matching how Enable2FA/GenerateBackupCodes
// store codes) and handed to resolvespec_totp_validate_backup_code, which is
// expected to mark a matching code as used so it cannot be replayed.
func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
	// Hash the code; only digests are ever sent to or stored in the database.
	hash := sha256.Sum256([]byte(code))
	codeHash := hex.EncodeToString(hash[:])

	var success bool
	var errorMsg sql.NullString
	var valid bool

	query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
	err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
	if err != nil {
		return false, fmt.Errorf("validate backup code query failed: %w", err)
	}

	if !success {
		if errorMsg.Valid {
			return false, fmt.Errorf("%s", errorMsg.String)
		}
		// NOTE(review): unlike the sibling methods, a procedure failure
		// without a message is reported as (false, nil) — i.e. "code invalid"
		// rather than an error. Presumably the procedure signals "no codes /
		// not enabled" this way; confirm against totp_database_schema.sql
		// before changing.
		return false, nil
	}

	return valid, nil
}
|
||||||
218
pkg/security/totp_provider_database_test.go
Normal file
218
pkg/security/totp_provider_database_test.go
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
package security_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Note: These tests require a PostgreSQL database with the schema from totp_database_schema.sql
// Set TEST_DATABASE_URL environment variable or skip tests

// setupTestDB returns a database handle for the provider tests.
// Currently it unconditionally skips: t.Skip stops the calling test goroutine
// (via runtime.Goexit), so the `return nil` below never executes at runtime —
// it exists only to satisfy the compiler. Callers' `if db == nil` checks are
// defensive dead code for the same reason.
func setupTestDB(t *testing.T) *sql.DB {
	// Skip if no test database configured
	t.Skip("Database tests require TEST_DATABASE_URL environment variable")
	return nil
}

// TestDatabaseTwoFactorProvider_Enable2FA exercises the full enable path:
// generate secret material, enable 2FA, then confirm the status flips on.
func TestDatabaseTwoFactorProvider_Enable2FA(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Generate secret and backup codes
	secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	if err != nil {
		t.Fatalf("Generate2FASecret() error = %v", err)
	}

	// Enable 2FA
	err = provider.Enable2FA(1, secret.Secret, secret.BackupCodes)
	if err != nil {
		t.Errorf("Enable2FA() error = %v", err)
	}

	// Verify enabled
	enabled, err := provider.Get2FAStatus(1)
	if err != nil {
		t.Fatalf("Get2FAStatus() error = %v", err)
	}

	if !enabled {
		t.Error("Get2FAStatus() = false, want true")
	}
}

// TestDatabaseTwoFactorProvider_Disable2FA enables 2FA and then verifies that
// Disable2FA turns the status back off.
func TestDatabaseTwoFactorProvider_Disable2FA(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable first
	// NOTE(review): errors from Generate2FASecret/Enable2FA are ignored here;
	// acceptable only because these are test preconditions on a skipped test.
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Disable
	err := provider.Disable2FA(1)
	if err != nil {
		t.Errorf("Disable2FA() error = %v", err)
	}

	// Verify disabled
	enabled, err := provider.Get2FAStatus(1)
	if err != nil {
		t.Fatalf("Get2FAStatus() error = %v", err)
	}

	if enabled {
		t.Error("Get2FAStatus() = true, want false")
	}
}

// TestDatabaseTwoFactorProvider_GetSecret verifies that the secret stored by
// Enable2FA is returned verbatim by Get2FASecret.
func TestDatabaseTwoFactorProvider_GetSecret(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable 2FA
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Retrieve secret
	retrieved, err := provider.Get2FASecret(1)
	if err != nil {
		t.Errorf("Get2FASecret() error = %v", err)
	}

	if retrieved != secret.Secret {
		t.Errorf("Get2FASecret() = %v, want %v", retrieved, secret.Secret)
	}
}

// TestDatabaseTwoFactorProvider_ValidateBackupCode checks the single-use
// contract: a backup code validates once, errors on reuse, and an unknown
// code is rejected without error.
func TestDatabaseTwoFactorProvider_ValidateBackupCode(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable 2FA
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Validate backup code
	valid, err := provider.ValidateBackupCode(1, secret.BackupCodes[0])
	if err != nil {
		t.Errorf("ValidateBackupCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateBackupCode() = false, want true")
	}

	// Try to use same code again
	// NOTE(review): only the error is asserted here; `valid` from this call
	// is intentionally unchecked (a reuse must error regardless of the flag).
	valid, err = provider.ValidateBackupCode(1, secret.BackupCodes[0])
	if err == nil {
		t.Error("ValidateBackupCode() should error on reuse")
	}

	// Try invalid code
	valid, err = provider.ValidateBackupCode(1, "INVALID")
	if err != nil {
		t.Errorf("ValidateBackupCode() error = %v", err)
	}

	if valid {
		t.Error("ValidateBackupCode() = true for invalid code")
	}
}

// TestDatabaseTwoFactorProvider_RegenerateBackupCodes verifies that
// regeneration invalidates the previous set and that the new codes validate.
func TestDatabaseTwoFactorProvider_RegenerateBackupCodes(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	// Enable 2FA
	secret, _ := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	provider.Enable2FA(1, secret.Secret, secret.BackupCodes)

	// Regenerate codes
	newCodes, err := provider.GenerateBackupCodes(1, 10)
	if err != nil {
		t.Errorf("GenerateBackupCodes() error = %v", err)
	}

	if len(newCodes) != 10 {
		t.Errorf("GenerateBackupCodes() returned %d codes, want 10", len(newCodes))
	}

	// Old codes should not work
	valid, _ := provider.ValidateBackupCode(1, secret.BackupCodes[0])
	if valid {
		t.Error("Old backup code should not work after regeneration")
	}

	// New codes should work
	valid, err = provider.ValidateBackupCode(1, newCodes[0])
	if err != nil {
		t.Errorf("ValidateBackupCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateBackupCode() = false for new code")
	}
}

// TestDatabaseTwoFactorProvider_Generate2FASecret checks the shape of the
// generated secret material: non-empty secret and QR URL, 10 backup codes,
// and issuer/account echoed back.
func TestDatabaseTwoFactorProvider_Generate2FASecret(t *testing.T) {
	db := setupTestDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	provider := security.NewDatabaseTwoFactorProvider(db, nil)

	secret, err := provider.Generate2FASecret(1, "TestApp", "test@example.com")
	if err != nil {
		t.Fatalf("Generate2FASecret() error = %v", err)
	}

	if secret.Secret == "" {
		t.Error("Generate2FASecret() returned empty secret")
	}

	if secret.QRCodeURL == "" {
		t.Error("Generate2FASecret() returned empty QR code URL")
	}

	if len(secret.BackupCodes) != 10 {
		t.Errorf("Generate2FASecret() returned %d backup codes, want 10", len(secret.BackupCodes))
	}

	if secret.Issuer != "TestApp" {
		t.Errorf("Generate2FASecret() Issuer = %v, want TestApp", secret.Issuer)
	}

	if secret.AccountName != "test@example.com" {
		t.Errorf("Generate2FASecret() AccountName = %v, want test@example.com", secret.AccountName)
	}
}
|
||||||
156
pkg/security/totp_provider_memory.go
Normal file
156
pkg/security/totp_provider_memory.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryTwoFactorProvider is an in-memory implementation of TwoFactorAuthProvider for testing/examples
|
||||||
|
type MemoryTwoFactorProvider struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
secrets map[int]string // userID -> secret
|
||||||
|
backupCodes map[int]map[string]bool // userID -> backup codes (code -> used)
|
||||||
|
totpGen *TOTPGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryTwoFactorProvider creates a new in-memory 2FA provider
|
||||||
|
func NewMemoryTwoFactorProvider(config *TwoFactorConfig) *MemoryTwoFactorProvider {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultTwoFactorConfig()
|
||||||
|
}
|
||||||
|
return &MemoryTwoFactorProvider{
|
||||||
|
secrets: make(map[int]string),
|
||||||
|
backupCodes: make(map[int]map[string]bool),
|
||||||
|
totpGen: NewTOTPGenerator(config),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate2FASecret creates a new secret for a user
|
||||||
|
func (m *MemoryTwoFactorProvider) Generate2FASecret(userID int, issuer, accountName string) (*TwoFactorSecret, error) {
|
||||||
|
secret, err := m.totpGen.GenerateSecret()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
qrURL := m.totpGen.GenerateQRCodeURL(secret, issuer, accountName)
|
||||||
|
|
||||||
|
backupCodes, err := GenerateBackupCodes(10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TwoFactorSecret{
|
||||||
|
Secret: secret,
|
||||||
|
QRCodeURL: qrURL,
|
||||||
|
BackupCodes: backupCodes,
|
||||||
|
Issuer: issuer,
|
||||||
|
AccountName: accountName,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate2FACode verifies a TOTP code
|
||||||
|
func (m *MemoryTwoFactorProvider) Validate2FACode(secret string, code string) (bool, error) {
|
||||||
|
return m.totpGen.ValidateCode(secret, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable2FA activates 2FA for a user
|
||||||
|
func (m *MemoryTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.secrets[userID] = secret
|
||||||
|
|
||||||
|
// Store backup codes
|
||||||
|
if m.backupCodes[userID] == nil {
|
||||||
|
m.backupCodes[userID] = make(map[string]bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, code := range backupCodes {
|
||||||
|
// Hash backup codes for security
|
||||||
|
hash := sha256.Sum256([]byte(code))
|
||||||
|
m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable2FA deactivates 2FA for a user
|
||||||
|
func (m *MemoryTwoFactorProvider) Disable2FA(userID int) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
delete(m.secrets, userID)
|
||||||
|
delete(m.backupCodes, userID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAStatus checks if user has 2FA enabled
|
||||||
|
func (m *MemoryTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
_, exists := m.secrets[userID]
|
||||||
|
return exists, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FASecret retrieves the user's 2FA secret
|
||||||
|
func (m *MemoryTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
secret, exists := m.secrets[userID]
|
||||||
|
if !exists {
|
||||||
|
return "", fmt.Errorf("user does not have 2FA enabled")
|
||||||
|
}
|
||||||
|
return secret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateBackupCodes creates backup codes for 2FA
|
||||||
|
func (m *MemoryTwoFactorProvider) GenerateBackupCodes(userID int, count int) ([]string, error) {
|
||||||
|
codes, err := GenerateBackupCodes(count)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Clear old backup codes and store new ones
|
||||||
|
m.backupCodes[userID] = make(map[string]bool)
|
||||||
|
for _, code := range codes {
|
||||||
|
hash := sha256.Sum256([]byte(code))
|
||||||
|
m.backupCodes[userID][hex.EncodeToString(hash[:])] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return codes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateBackupCode checks and consumes a backup code
|
||||||
|
func (m *MemoryTwoFactorProvider) ValidateBackupCode(userID int, code string) (bool, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
userCodes, exists := m.backupCodes[userID]
|
||||||
|
if !exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash the provided code
|
||||||
|
hash := sha256.Sum256([]byte(code))
|
||||||
|
hashStr := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
used, exists := userCodes[hashStr]
|
||||||
|
if !exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if used {
|
||||||
|
return false, fmt.Errorf("backup code already used")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark as used
|
||||||
|
userCodes[hashStr] = true
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
292
pkg/security/totp_test.go
Normal file
292
pkg/security/totp_test.go
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestTOTPGenerator_GenerateSecret checks that a generated secret is
// non-empty and long enough to be a plausible base32 key.
func TestTOTPGenerator_GenerateSecret(t *testing.T) {
	totp := NewTOTPGenerator(nil)

	secret, err := totp.GenerateSecret()
	if err != nil {
		t.Fatalf("GenerateSecret() error = %v", err)
	}

	if secret == "" {
		t.Error("GenerateSecret() returned empty secret")
	}

	// Secret should be base32 encoded
	if len(secret) < 16 {
		t.Error("GenerateSecret() returned secret that is too short")
	}
}

// TestTOTPGenerator_GenerateQRCodeURL checks the otpauth:// provisioning URL
// contains the scheme prefix, the secret, and the issuer.
func TestTOTPGenerator_GenerateQRCodeURL(t *testing.T) {
	totp := NewTOTPGenerator(nil)

	secret := "JBSWY3DPEHPK3PXP"
	issuer := "TestApp"
	accountName := "user@example.com"

	url := totp.GenerateQRCodeURL(secret, issuer, accountName)

	if !strings.HasPrefix(url, "otpauth://totp/") {
		t.Errorf("GenerateQRCodeURL() = %v, want otpauth://totp/ prefix", url)
	}

	if !strings.Contains(url, "secret="+secret) {
		t.Errorf("GenerateQRCodeURL() missing secret parameter")
	}

	if !strings.Contains(url, "issuer="+issuer) {
		t.Errorf("GenerateQRCodeURL() missing issuer parameter")
	}
}

// TestTOTPGenerator_GenerateCode generates a code for a fixed timestamp and
// checks its length and that it is purely numeric.
func TestTOTPGenerator_GenerateCode(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     6,
		Period:     30,
		SkewWindow: 1,
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	// Test with known time
	timestamp := time.Unix(1234567890, 0)
	code, err := totp.GenerateCode(secret, timestamp)
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	if len(code) != 6 {
		t.Errorf("GenerateCode() returned code with length %d, want 6", len(code))
	}

	// Code should be numeric
	for _, c := range code {
		if c < '0' || c > '9' {
			t.Errorf("GenerateCode() returned non-numeric code: %s", code)
			break
		}
	}
}

// TestTOTPGenerator_ValidateCode round-trips a freshly generated code through
// ValidateCode and also checks an (almost certainly) wrong code is rejected.
// Note: generate-then-validate can in principle straddle a period boundary;
// the SkewWindow of 1 makes this non-flaky in practice.
func TestTOTPGenerator_ValidateCode(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     6,
		Period:     30,
		SkewWindow: 1,
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	// Generate a code for current time
	now := time.Now()
	code, err := totp.GenerateCode(secret, now)
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	// Validate the code
	valid, err := totp.ValidateCode(secret, code)
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateCode() = false, want true for current code")
	}

	// Test with invalid code
	valid, err = totp.ValidateCode(secret, "000000")
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	// This might occasionally pass if 000000 is the correct code, but very unlikely
	if valid && code != "000000" {
		t.Error("ValidateCode() = true for invalid code")
	}
}

// TestTOTPGenerator_ValidateCode_WithSkew checks that a code generated one
// period in the past still validates when SkewWindow allows it.
func TestTOTPGenerator_ValidateCode_WithSkew(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     6,
		Period:     30,
		SkewWindow: 2, // Allow 2 periods before/after
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	// Generate code for 1 period ago
	past := time.Now().Add(-30 * time.Second)
	code, err := totp.GenerateCode(secret, past)
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	// Should still validate with skew window
	valid, err := totp.ValidateCode(secret, code)
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateCode() = false, want true for code within skew window")
	}
}

// TestTOTPGenerator_DifferentAlgorithms round-trips generate/validate for
// each supported HMAC algorithm as a subtest.
func TestTOTPGenerator_DifferentAlgorithms(t *testing.T) {
	algorithms := []string{"SHA1", "SHA256", "SHA512"}
	secret := "JBSWY3DPEHPK3PXP"

	for _, algo := range algorithms {
		t.Run(algo, func(t *testing.T) {
			config := &TwoFactorConfig{
				Algorithm:  algo,
				Digits:     6,
				Period:     30,
				SkewWindow: 1,
			}
			totp := NewTOTPGenerator(config)

			code, err := totp.GenerateCode(secret, time.Now())
			if err != nil {
				t.Fatalf("GenerateCode() with %s error = %v", algo, err)
			}

			valid, err := totp.ValidateCode(secret, code)
			if err != nil {
				t.Fatalf("ValidateCode() with %s error = %v", algo, err)
			}

			if !valid {
				t.Errorf("ValidateCode() with %s = false, want true", algo)
			}
		})
	}
}

// TestTOTPGenerator_8Digits checks that an 8-digit configuration produces
// 8-character codes that validate.
func TestTOTPGenerator_8Digits(t *testing.T) {
	config := &TwoFactorConfig{
		Algorithm:  "SHA1",
		Digits:     8,
		Period:     30,
		SkewWindow: 1,
	}
	totp := NewTOTPGenerator(config)

	secret := "JBSWY3DPEHPK3PXP"

	code, err := totp.GenerateCode(secret, time.Now())
	if err != nil {
		t.Fatalf("GenerateCode() error = %v", err)
	}

	if len(code) != 8 {
		t.Errorf("GenerateCode() returned code with length %d, want 8", len(code))
	}

	valid, err := totp.ValidateCode(secret, code)
	if err != nil {
		t.Fatalf("ValidateCode() error = %v", err)
	}

	if !valid {
		t.Error("ValidateCode() = false, want true for 8-digit code")
	}
}
|
||||||
|
|
||||||
|
func TestGenerateBackupCodes(t *testing.T) {
|
||||||
|
count := 10
|
||||||
|
codes, err := GenerateBackupCodes(count)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateBackupCodes() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(codes) != count {
|
||||||
|
t.Errorf("GenerateBackupCodes() returned %d codes, want %d", len(codes), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check uniqueness
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, code := range codes {
|
||||||
|
if seen[code] {
|
||||||
|
t.Errorf("GenerateBackupCodes() generated duplicate code: %s", code)
|
||||||
|
}
|
||||||
|
seen[code] = true
|
||||||
|
|
||||||
|
// Check format (8 hex characters)
|
||||||
|
if len(code) != 8 {
|
||||||
|
t.Errorf("GenerateBackupCodes() code length = %d, want 8", len(code))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultTwoFactorConfig(t *testing.T) {
|
||||||
|
config := DefaultTwoFactorConfig()
|
||||||
|
|
||||||
|
if config.Algorithm != "SHA1" {
|
||||||
|
t.Errorf("DefaultTwoFactorConfig() Algorithm = %s, want SHA1", config.Algorithm)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Digits != 6 {
|
||||||
|
t.Errorf("DefaultTwoFactorConfig() Digits = %d, want 6", config.Digits)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Period != 30 {
|
||||||
|
t.Errorf("DefaultTwoFactorConfig() Period = %d, want 30", config.Period)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.SkewWindow != 1 {
|
||||||
|
t.Errorf("DefaultTwoFactorConfig() SkewWindow = %d, want 1", config.SkewWindow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOTPGenerator_InvalidSecret(t *testing.T) {
|
||||||
|
totp := NewTOTPGenerator(nil)
|
||||||
|
|
||||||
|
// Test with invalid base32 secret
|
||||||
|
_, err := totp.GenerateCode("INVALID!!!", time.Now())
|
||||||
|
if err == nil {
|
||||||
|
t.Error("GenerateCode() with invalid secret should return error")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = totp.ValidateCode("INVALID!!!", "123456")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("ValidateCode() with invalid secret should return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkTOTPGenerator_GenerateCode(b *testing.B) {
|
||||||
|
totp := NewTOTPGenerator(nil)
|
||||||
|
secret := "JBSWY3DPEHPK3PXP"
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = totp.GenerateCode(secret, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkTOTPGenerator_ValidateCode(b *testing.B) {
|
||||||
|
totp := NewTOTPGenerator(nil)
|
||||||
|
secret := "JBSWY3DPEHPK3PXP"
|
||||||
|
code, _ := totp.GenerateCode(secret, time.Now())
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = totp.ValidateCode(secret, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
47
pkg/server/config_helper.go
Normal file
47
pkg/server/config_helper.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FromConfigInstanceToServerConfig converts a config.ServerInstanceConfig to server.Config
|
||||||
|
// The handler must be provided separately as it cannot be serialized
|
||||||
|
func FromConfigInstanceToServerConfig(sic *config.ServerInstanceConfig, handler http.Handler) Config {
|
||||||
|
cfg := Config{
|
||||||
|
Name: sic.Name,
|
||||||
|
Host: sic.Host,
|
||||||
|
Port: sic.Port,
|
||||||
|
Description: sic.Description,
|
||||||
|
Handler: handler,
|
||||||
|
GZIP: sic.GZIP,
|
||||||
|
|
||||||
|
SSLCert: sic.SSLCert,
|
||||||
|
SSLKey: sic.SSLKey,
|
||||||
|
SelfSignedSSL: sic.SelfSignedSSL,
|
||||||
|
AutoTLS: sic.AutoTLS,
|
||||||
|
AutoTLSDomains: sic.AutoTLSDomains,
|
||||||
|
AutoTLSCacheDir: sic.AutoTLSCacheDir,
|
||||||
|
AutoTLSEmail: sic.AutoTLSEmail,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply timeouts (use pointers to override, or use zero values for defaults)
|
||||||
|
if sic.ShutdownTimeout != nil {
|
||||||
|
cfg.ShutdownTimeout = *sic.ShutdownTimeout
|
||||||
|
}
|
||||||
|
if sic.DrainTimeout != nil {
|
||||||
|
cfg.DrainTimeout = *sic.DrainTimeout
|
||||||
|
}
|
||||||
|
if sic.ReadTimeout != nil {
|
||||||
|
cfg.ReadTimeout = *sic.ReadTimeout
|
||||||
|
}
|
||||||
|
if sic.WriteTimeout != nil {
|
||||||
|
cfg.WriteTimeout = *sic.WriteTimeout
|
||||||
|
}
|
||||||
|
if sic.IdleTimeout != nil {
|
||||||
|
cfg.IdleTimeout = *sic.IdleTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user