mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 16:24:26 +00:00
Compare commits
78 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e1547ec65 | ||
|
|
49cdc6f17b | ||
|
|
0bd653820c | ||
|
|
9209193157 | ||
|
|
b8c44c5a99 | ||
|
|
28fd88fff1 | ||
|
|
be38341383 | ||
|
|
fab744b878 | ||
|
|
5ad2bd3a78 | ||
|
|
333fe158e9 | ||
|
|
2a2d351ad4 | ||
|
|
e918c49b84 | ||
|
|
41e4956510 | ||
|
|
8e8c3c6de6 | ||
|
|
aa9b7312f6 | ||
|
|
dca43b0e05 | ||
|
|
6f368bbce5 | ||
|
|
8704cee941 | ||
|
|
4ce5afe0ac | ||
|
|
7b98ea2145 | ||
|
|
897cb2ae0d | ||
|
|
01420e6b63 | ||
|
|
645907d355 | ||
|
|
e81d7b48cc | ||
|
|
8f5a725a09 | ||
|
|
3d5d7b788e | ||
|
|
eaecef686e | ||
|
|
e0d21b17ec | ||
|
|
7e1718e864 | ||
|
|
16d416030e | ||
|
|
bf8500714a | ||
|
|
4f8edd6469 | ||
|
|
ccf8522f88 | ||
|
|
92a83e9cc6 | ||
|
|
4cb35a78b0 | ||
|
|
e10e2e1c27 | ||
|
|
64f56325d4 | ||
|
|
5e6032c91d | ||
|
|
bc2fdc143b | ||
|
|
267e84fd84 | ||
|
|
8adc386863 | ||
|
|
feb023ec48 | ||
|
|
de50141a04 | ||
|
|
c226dc349f | ||
|
|
d4a6f9c4c2 | ||
| 8f83e8fdc1 | |||
|
|
90df4a157c | ||
|
|
2dd404af96 | ||
|
|
17c472b206 | ||
|
|
ed67caf055 | ||
| 4d1b8b6982 | |||
|
|
63ed62a9a3 | ||
|
|
0525323a47 | ||
|
|
c3443f702e | ||
|
|
45c463c117 | ||
|
|
84d673ce14 | ||
|
|
02fbdbd651 | ||
|
|
97988e3b5e | ||
|
|
c9838ad9d2 | ||
|
|
c5c0608f63 | ||
|
|
39c3f05d21 | ||
|
|
4ecd1ac17e | ||
|
|
2b1aea0338 | ||
|
|
1e749efeb3 | ||
|
|
09be676096 | ||
|
|
e8350a70be | ||
|
|
5937b9eab5 | ||
|
|
7c861c708e | ||
|
|
77f39af2f9 | ||
|
|
fbc1471581 | ||
|
|
9351093e2a | ||
|
|
932f12ab0a | ||
|
|
1b2b0d8f0b | ||
|
|
b22792bad6 | ||
|
|
e8111c01aa | ||
|
|
5862016031 | ||
|
|
2f18dde29c | ||
|
|
31ad217818 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -25,3 +25,4 @@ go.work.sum
|
||||
.env
|
||||
bin/
|
||||
test.db
|
||||
testserver
|
||||
|
||||
51
Makefile
51
Makefile
@@ -13,10 +13,55 @@ test-integration:
|
||||
# Run all tests (unit + integration)
|
||||
test: test-unit test-integration
|
||||
|
||||
release-version: ## Create and push a release with specific version (use: make release-version VERSION=v1.2.3)
|
||||
@if [ -z "$(VERSION)" ]; then \
|
||||
echo "Error: VERSION is required. Usage: make release-version VERSION=v1.2.3"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@version="$(VERSION)"; \
|
||||
if ! echo "$$version" | grep -q "^v"; then \
|
||||
version="v$$version"; \
|
||||
fi; \
|
||||
echo "Creating release: $$version"; \
|
||||
latest_tag=$$(git describe --tags --abbrev=0 2>/dev/null || echo ""); \
|
||||
if [ -z "$$latest_tag" ]; then \
|
||||
commit_logs=$$(git log --pretty=format:"- %s" --no-merges); \
|
||||
else \
|
||||
commit_logs=$$(git log "$${latest_tag}..HEAD" --pretty=format:"- %s" --no-merges); \
|
||||
fi; \
|
||||
if [ -z "$$commit_logs" ]; then \
|
||||
tag_message="Release $$version"; \
|
||||
else \
|
||||
tag_message="Release $$version\n\n$$commit_logs"; \
|
||||
fi; \
|
||||
git tag -a "$$version" -m "$$tag_message"; \
|
||||
git push origin "$$version"; \
|
||||
echo "Tag $$version created and pushed to remote repository."
|
||||
|
||||
|
||||
lint: ## Run linter
|
||||
@echo "Running linter..."
|
||||
@if command -v golangci-lint > /dev/null; then \
|
||||
golangci-lint run --config=.golangci.json; \
|
||||
else \
|
||||
echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
lintfix: ## Run linter
|
||||
@echo "Running linter..."
|
||||
@if command -v golangci-lint > /dev/null; then \
|
||||
golangci-lint run --config=.golangci.json --fix; \
|
||||
else \
|
||||
echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
|
||||
# Start PostgreSQL for integration tests
|
||||
docker-up:
|
||||
@echo "Starting PostgreSQL container..."
|
||||
@docker-compose up -d postgres-test
|
||||
@podman compose up -d postgres-test
|
||||
@echo "Waiting for PostgreSQL to be ready..."
|
||||
@sleep 5
|
||||
@echo "PostgreSQL is ready!"
|
||||
@@ -24,12 +69,12 @@ docker-up:
|
||||
# Stop PostgreSQL container
|
||||
docker-down:
|
||||
@echo "Stopping PostgreSQL container..."
|
||||
@docker-compose down
|
||||
@podman compose down
|
||||
|
||||
# Clean up Docker volumes and test data
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@docker-compose down -v
|
||||
@podman compose down -v
|
||||
@echo "Cleanup complete!"
|
||||
|
||||
# Run integration tests with Docker (full workflow)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -67,9 +67,36 @@ func main() {
|
||||
// Setup routes using new SetupMuxRoutes function (without authentication)
|
||||
resolvespec.SetupMuxRoutes(r, handler, nil)
|
||||
|
||||
// Create graceful server with configuration
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: cfg.Server.Addr,
|
||||
// Create server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Parse host and port from addr
|
||||
host := ""
|
||||
port := 8080
|
||||
if cfg.Server.Addr != "" {
|
||||
// Parse addr (format: ":8080" or "localhost:8080")
|
||||
if cfg.Server.Addr[0] == ':' {
|
||||
// Just port
|
||||
_, err := fmt.Sscanf(cfg.Server.Addr, ":%d", &port)
|
||||
if err != nil {
|
||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
// Host and port
|
||||
_, err := fmt.Sscanf(cfg.Server.Addr, "%[^:]:%d", &host, &port)
|
||||
if err != nil {
|
||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add server instance
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "api",
|
||||
Host: host,
|
||||
Port: port,
|
||||
Handler: r,
|
||||
ShutdownTimeout: cfg.Server.ShutdownTimeout,
|
||||
DrainTimeout: cfg.Server.DrainTimeout,
|
||||
@@ -77,11 +104,15 @@ func main() {
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to add server: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Start server with graceful shutdown
|
||||
logger.Info("Starting server on %s", cfg.Server.Addr)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("Server failed to start: %v", err)
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
logger.Error("Server failed: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
79
go.mod
79
go.mod
@@ -7,19 +7,26 @@ toolchain go1.24.6
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
|
||||
github.com/eclipse/paho.mqtt.golang v1.5.1
|
||||
github.com/getsentry/sentry-go v0.40.0
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.6.0
|
||||
github.com/klauspost/compress v1.18.0
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||
github.com/nats-io/nats.go v1.48.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.1
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go v0.40.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/uptrace/bun v1.2.15
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15
|
||||
github.com/uptrace/bun v1.2.16
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16
|
||||
github.com/uptrace/bunrouter v1.0.23
|
||||
go.opentelemetry.io/otel v1.38.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
||||
@@ -27,71 +34,113 @@ require (
|
||||
go.opentelemetry.io/otel/sdk v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/gorm v1.25.12
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.6.0 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/go-archive v0.1.0 // indirect
|
||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||
github.com/moby/sys/sequential v0.6.0 // indirect
|
||||
github.com/moby/sys/user v0.4.0 // indirect
|
||||
github.com/moby/sys/userns v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/nats-io/nkeys v0.4.11 // indirect
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rs/xid v1.4.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.41.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||
golang.org/x/net v0.45.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/libc v1.67.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.0 // indirect
|
||||
modernc.org/sqlite v1.40.1 // indirect
|
||||
)
|
||||
|
||||
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
||||
|
||||
192
go.sum
192
go.sum
@@ -1,5 +1,13 @@
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
||||
@@ -8,17 +16,45 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE=
|
||||
github.com/eclipse/paho.mqtt.golang v1.5.1/go.mod h1:1/yJCneuyOoCOzKSsOTUc0AJfpsItBGWvYpBLimhArU=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
@@ -36,10 +72,13 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
@@ -48,8 +87,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -58,6 +101,8 @@ github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
||||
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
|
||||
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -71,14 +116,48 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
|
||||
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
|
||||
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9 h1:y0g4vrSLAag7T07l2oCzOa/+nKVLoazKEWAArwqBNYI=
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9/go.mod h1:lZD3j35AVNqJL5cezlnSkuG05c0FCHSsfAKSPBOSbqc=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
|
||||
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
|
||||
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
|
||||
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
|
||||
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
@@ -87,6 +166,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
@@ -103,8 +184,14 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
@@ -116,12 +203,16 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
|
||||
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
@@ -131,28 +222,38 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16/go.mod h1:iKdJ06P3XS+pwKcONjSIK07bbhksH3lWsw3mpfr0+bY=
|
||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
|
||||
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
@@ -173,25 +274,34 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
|
||||
@@ -210,20 +320,26 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
||||
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
||||
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
|
||||
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
||||
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
||||
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
|
||||
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
@@ -232,8 +348,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
20
pkg/cache/cache_manager.go
vendored
20
pkg/cache/cache_manager.go
vendored
@@ -57,11 +57,31 @@ func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time
|
||||
return c.provider.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// SetWithTags serializes and stores a value in the cache with the specified TTL and tags.
|
||||
func (c *Cache) SetWithTags(ctx context.Context, key string, value interface{}, ttl time.Duration, tags []string) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize: %w", err)
|
||||
}
|
||||
|
||||
return c.provider.SetWithTags(ctx, key, data, ttl, tags)
|
||||
}
|
||||
|
||||
// SetBytesWithTags stores raw bytes in the cache with the specified TTL and tags.
|
||||
func (c *Cache) SetBytesWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
return c.provider.SetWithTags(ctx, key, value, ttl, tags)
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (c *Cache) Delete(ctx context.Context, key string) error {
|
||||
return c.provider.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (c *Cache) DeleteByTag(ctx context.Context, tag string) error {
|
||||
return c.provider.DeleteByTag(ctx, tag)
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
return c.provider.DeleteByPattern(ctx, pattern)
|
||||
|
||||
8
pkg/cache/provider.go
vendored
8
pkg/cache/provider.go
vendored
@@ -15,9 +15,17 @@ type Provider interface {
|
||||
// If ttl is 0, the item never expires.
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
// Tags can be used to invalidate groups of related keys.
|
||||
// If ttl is 0, the item never expires.
|
||||
SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
DeleteByTag(ctx context.Context, tag string) error
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
// Pattern syntax depends on the provider implementation.
|
||||
DeleteByPattern(ctx context.Context, pattern string) error
|
||||
|
||||
140
pkg/cache/provider_memcache.go
vendored
140
pkg/cache/provider_memcache.go
vendored
@@ -2,6 +2,7 @@ package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -97,8 +98,115 @@ func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, tt
|
||||
return m.client.Set(item)
|
||||
}
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
// Note: Tag support in Memcache is limited and less efficient than Redis.
|
||||
func (m *MemcacheProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
if ttl == 0 {
|
||||
ttl = m.options.DefaultTTL
|
||||
}
|
||||
|
||||
expiration := int32(ttl.Seconds())
|
||||
|
||||
// Set the main value
|
||||
item := &memcache.Item{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Expiration: expiration,
|
||||
}
|
||||
if err := m.client.Set(item); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store tags for this key
|
||||
if len(tags) > 0 {
|
||||
tagsData, err := json.Marshal(tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||
}
|
||||
|
||||
tagsItem := &memcache.Item{
|
||||
Key: fmt.Sprintf("cache:tags:%s", key),
|
||||
Value: tagsData,
|
||||
Expiration: expiration,
|
||||
}
|
||||
if err := m.client.Set(tagsItem); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add key to each tag's key list
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
|
||||
// Get existing keys for this tag
|
||||
var keys []string
|
||||
if item, err := m.client.Get(tagKey); err == nil {
|
||||
_ = json.Unmarshal(item.Value, &keys)
|
||||
}
|
||||
|
||||
// Add current key if not already present
|
||||
found := false
|
||||
for _, k := range keys {
|
||||
if k == key {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
// Store updated key list
|
||||
keysData, err := json.Marshal(keys)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tagItem := &memcache.Item{
|
||||
Key: tagKey,
|
||||
Value: keysData,
|
||||
Expiration: expiration + 3600, // Give tag lists longer TTL
|
||||
}
|
||||
_ = m.client.Set(tagItem)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||
// Get tags for this key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
if item, err := m.client.Get(tagsKey); err == nil {
|
||||
var tags []string
|
||||
if err := json.Unmarshal(item.Value, &tags); err == nil {
|
||||
// Remove key from each tag's key list
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
if tagItem, err := m.client.Get(tagKey); err == nil {
|
||||
var keys []string
|
||||
if err := json.Unmarshal(tagItem.Value, &keys); err == nil {
|
||||
// Remove current key from the list
|
||||
newKeys := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
if k != key {
|
||||
newKeys = append(newKeys, k)
|
||||
}
|
||||
}
|
||||
// Update the tag's key list
|
||||
if keysData, err := json.Marshal(newKeys); err == nil {
|
||||
tagItem.Value = keysData
|
||||
_ = m.client.Set(tagItem)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Delete the tags key
|
||||
_ = m.client.Delete(tagsKey)
|
||||
}
|
||||
|
||||
// Delete the actual key
|
||||
err := m.client.Delete(key)
|
||||
if err == memcache.ErrCacheMiss {
|
||||
return nil
|
||||
@@ -106,6 +214,38 @@ func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (m *MemcacheProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
|
||||
// Get all keys associated with this tag
|
||||
item, err := m.client.Get(tagKey)
|
||||
if err == memcache.ErrCacheMiss {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
if err := json.Unmarshal(item.Value, &keys); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal tag keys: %w", err)
|
||||
}
|
||||
|
||||
// Delete all keys
|
||||
for _, key := range keys {
|
||||
_ = m.client.Delete(key)
|
||||
// Also delete the tags key for this cache key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
_ = m.client.Delete(tagsKey)
|
||||
}
|
||||
|
||||
// Delete the tag key itself
|
||||
_ = m.client.Delete(tagKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
// Note: Memcache does not support pattern-based deletion natively.
|
||||
// This is a no-op for memcache and returns an error.
|
||||
|
||||
118
pkg/cache/provider_memory.go
vendored
118
pkg/cache/provider_memory.go
vendored
@@ -15,6 +15,7 @@ type memoryItem struct {
|
||||
Expiration time.Time
|
||||
LastAccess time.Time
|
||||
HitCount int64
|
||||
Tags []string
|
||||
}
|
||||
|
||||
// isExpired checks if the item has expired.
|
||||
@@ -27,11 +28,12 @@ func (m *memoryItem) isExpired() bool {
|
||||
|
||||
// MemoryProvider is an in-memory implementation of the Provider interface.
|
||||
type MemoryProvider struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryItem
|
||||
options *Options
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryItem
|
||||
tagToKeys map[string]map[string]struct{} // tag -> set of keys
|
||||
options *Options
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
}
|
||||
|
||||
// NewMemoryProvider creates a new in-memory cache provider.
|
||||
@@ -44,8 +46,9 @@ func NewMemoryProvider(opts *Options) *MemoryProvider {
|
||||
}
|
||||
|
||||
return &MemoryProvider{
|
||||
items: make(map[string]*memoryItem),
|
||||
options: opts,
|
||||
items: make(map[string]*memoryItem),
|
||||
tagToKeys: make(map[string]map[string]struct{}),
|
||||
options: opts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,15 +117,116 @@ func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
func (m *MemoryProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if ttl == 0 {
|
||||
ttl = m.options.DefaultTTL
|
||||
}
|
||||
|
||||
var expiration time.Time
|
||||
if ttl > 0 {
|
||||
expiration = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
// Check max size and evict if necessary
|
||||
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
|
||||
if _, exists := m.items[key]; !exists {
|
||||
m.evictOne()
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old tag associations if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
for _, tag := range oldItem.Tags {
|
||||
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||
delete(keySet, key)
|
||||
if len(keySet) == 0 {
|
||||
delete(m.tagToKeys, tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store the item
|
||||
m.items[key] = &memoryItem{
|
||||
Value: value,
|
||||
Expiration: expiration,
|
||||
LastAccess: time.Now(),
|
||||
Tags: tags,
|
||||
}
|
||||
|
||||
// Add new tag associations
|
||||
for _, tag := range tags {
|
||||
if m.tagToKeys[tag] == nil {
|
||||
m.tagToKeys[tag] = make(map[string]struct{})
|
||||
}
|
||||
m.tagToKeys[tag][key] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Remove tag associations
|
||||
if item, exists := m.items[key]; exists {
|
||||
for _, tag := range item.Tags {
|
||||
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||
delete(keySet, key)
|
||||
if len(keySet) == 0 {
|
||||
delete(m.tagToKeys, tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete(m.items, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (m *MemoryProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Get all keys associated with this tag
|
||||
keySet, exists := m.tagToKeys[tag]
|
||||
if !exists {
|
||||
return nil // No keys with this tag
|
||||
}
|
||||
|
||||
// Delete all items with this tag
|
||||
for key := range keySet {
|
||||
if item, ok := m.items[key]; ok {
|
||||
// Remove this tag from the item's tag list
|
||||
newTags := make([]string, 0, len(item.Tags))
|
||||
for _, t := range item.Tags {
|
||||
if t != tag {
|
||||
newTags = append(newTags, t)
|
||||
}
|
||||
}
|
||||
|
||||
// If item has no more tags, delete it
|
||||
// Otherwise update its tags
|
||||
if len(newTags) == 0 {
|
||||
delete(m.items, key)
|
||||
} else {
|
||||
item.Tags = newTags
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the tag mapping
|
||||
delete(m.tagToKeys, tag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
m.mu.Lock()
|
||||
|
||||
86
pkg/cache/provider_redis.go
vendored
86
pkg/cache/provider_redis.go
vendored
@@ -103,9 +103,93 @@ func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl t
|
||||
return r.client.Set(ctx, key, value, ttl).Err()
|
||||
}
|
||||
|
||||
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||
func (r *RedisProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||
if ttl == 0 {
|
||||
ttl = r.options.DefaultTTL
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Set the value
|
||||
pipe.Set(ctx, key, value, ttl)
|
||||
|
||||
// Add key to each tag's set
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
pipe.SAdd(ctx, tagKey, key)
|
||||
// Set expiration on tag set (longer than cache items to ensure cleanup)
|
||||
if ttl > 0 {
|
||||
pipe.Expire(ctx, tagKey, ttl+time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
// Store tags for this key for later cleanup
|
||||
if len(tags) > 0 {
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
pipe.SAdd(ctx, tagsKey, tags)
|
||||
if ttl > 0 {
|
||||
pipe.Expire(ctx, tagsKey, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
||||
return r.client.Del(ctx, key).Err()
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Get tags for this key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
tags, err := r.client.SMembers(ctx, tagsKey).Result()
|
||||
if err == nil && len(tags) > 0 {
|
||||
// Remove key from each tag set
|
||||
for _, tag := range tags {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
pipe.SRem(ctx, tagKey, key)
|
||||
}
|
||||
// Delete the tags key
|
||||
pipe.Del(ctx, tagsKey)
|
||||
}
|
||||
|
||||
// Delete the actual key
|
||||
pipe.Del(ctx, key)
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByTag removes all keys associated with the given tag.
|
||||
func (r *RedisProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||
|
||||
// Get all keys associated with this tag
|
||||
keys, err := r.client.SMembers(ctx, tagKey).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Delete all keys and their tag associations
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
// Also delete the tags key for this cache key
|
||||
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||
pipe.Del(ctx, tagsKey)
|
||||
}
|
||||
|
||||
// Delete the tag set itself
|
||||
pipe.Del(ctx, tagKey)
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
|
||||
151
pkg/cache/query_cache_test.go
vendored
151
pkg/cache/query_cache_test.go
vendored
@@ -1,151 +0,0 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestBuildQueryCacheKey(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
{Column: "age", Operator: "gt", Value: 25},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "name", Direction: "asc"},
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||
|
||||
// Same parameters should generate same key
|
||||
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
|
||||
}
|
||||
|
||||
// Different parameters should generate different key
|
||||
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
|
||||
|
||||
if key1 == key3 {
|
||||
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildExtendedQueryCacheKey(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "name", Direction: "asc"},
|
||||
}
|
||||
expandOpts := []interface{}{
|
||||
map[string]interface{}{
|
||||
"relation": "posts",
|
||||
"where": "status = 'published'",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||
|
||||
// Same parameters should generate same key
|
||||
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("Expected same cache keys for identical parameters")
|
||||
}
|
||||
|
||||
// Different distinct value should generate different key
|
||||
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
|
||||
|
||||
if key1 == key3 {
|
||||
t.Errorf("Expected different cache keys for different distinct values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetQueryTotalCacheKey(t *testing.T) {
|
||||
hash := "abc123"
|
||||
key := GetQueryTotalCacheKey(hash)
|
||||
|
||||
expected := "query_total:abc123"
|
||||
if key != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedTotalIntegration(t *testing.T) {
|
||||
// Initialize cache with memory provider for testing
|
||||
UseMemory(&Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 100,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
filters := []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "created_at", Direction: "desc"},
|
||||
}
|
||||
|
||||
// Build cache key
|
||||
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
|
||||
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
|
||||
|
||||
// Store a total count in cache
|
||||
totalToCache := CachedTotal{Total: 42}
|
||||
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set cache: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve from cache
|
||||
var cachedTotal CachedTotal
|
||||
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get from cache: %v", err)
|
||||
}
|
||||
|
||||
if cachedTotal.Total != 42 {
|
||||
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
|
||||
}
|
||||
|
||||
// Test cache miss
|
||||
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
|
||||
var missedTotal CachedTotal
|
||||
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for cache miss, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashString(t *testing.T) {
|
||||
input1 := "test string"
|
||||
input2 := "test string"
|
||||
input3 := "different string"
|
||||
|
||||
hash1 := hashString(input1)
|
||||
hash2 := hashString(input2)
|
||||
hash3 := hashString(input3)
|
||||
|
||||
// Same input should produce same hash
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Expected same hash for identical inputs")
|
||||
}
|
||||
|
||||
// Different input should produce different hash
|
||||
if hash1 == hash3 {
|
||||
t.Errorf("Expected different hash for different inputs")
|
||||
}
|
||||
|
||||
// Hash should be hex encoded SHA256 (64 characters)
|
||||
if len(hash1) != 64 {
|
||||
t.Errorf("Expected hash length of 64, got %d", len(hash1))
|
||||
}
|
||||
}
|
||||
@@ -196,6 +196,10 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
})
|
||||
}
|
||||
|
||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||
return b.db
|
||||
}
|
||||
|
||||
// BunSelectQuery implements SelectQuery for Bun
|
||||
type BunSelectQuery struct {
|
||||
query *bun.SelectQuery
|
||||
@@ -687,6 +691,11 @@ func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.OrderExpr(order, args...)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Limit(n int) common.SelectQuery {
|
||||
b.query = b.query.Limit(n)
|
||||
return b
|
||||
@@ -1208,3 +1217,7 @@ func (b *BunTxAdapter) RollbackTx(ctx context.Context) error {
|
||||
func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||
return fn(b) // Already in transaction
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) GetUnderlyingDB() interface{} {
|
||||
return b.tx
|
||||
}
|
||||
|
||||
@@ -102,6 +102,10 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
})
|
||||
}
|
||||
|
||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
}
|
||||
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
@@ -382,6 +386,12 @@ func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
// GORM's Order can handle expressions directly
|
||||
g.db = g.db.Order(gorm.Expr(order, args...))
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Limit(n int) common.SelectQuery {
|
||||
g.db = g.db.Limit(n)
|
||||
return g
|
||||
|
||||
@@ -137,6 +137,10 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
return fn(adapter)
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
|
||||
return p.db
|
||||
}
|
||||
|
||||
// preloadConfig represents a relationship to be preloaded
|
||||
type preloadConfig struct {
|
||||
relation string
|
||||
@@ -277,6 +281,13 @@ func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
// For PgSQL, expressions are passed directly without quoting
|
||||
// If there are args, we would need to format them, but for now just append the expression
|
||||
p.orderBy = append(p.orderBy, order)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery {
|
||||
p.limit = n
|
||||
return p
|
||||
@@ -897,6 +908,10 @@ func (p *PgSQLTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Da
|
||||
return fn(p)
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
|
||||
return p.tx
|
||||
}
|
||||
|
||||
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
|
||||
func (p *PgSQLSelectQuery) applyJoinPreloads() {
|
||||
for _, preload := range p.preloads {
|
||||
|
||||
@@ -24,6 +24,12 @@ type Database interface {
|
||||
CommitTx(ctx context.Context) error
|
||||
RollbackTx(ctx context.Context) error
|
||||
RunInTransaction(ctx context.Context, fn func(Database) error) error
|
||||
|
||||
// GetUnderlyingDB returns the underlying database connection
|
||||
// For GORM, this returns *gorm.DB
|
||||
// For Bun, this returns *bun.DB
|
||||
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
|
||||
GetUnderlyingDB() interface{}
|
||||
}
|
||||
|
||||
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
||||
@@ -40,6 +46,7 @@ type SelectQuery interface {
|
||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
Order(order string) SelectQuery
|
||||
OrderExpr(order string, args ...interface{}) SelectQuery
|
||||
Limit(n int) SelectQuery
|
||||
Offset(n int) SelectQuery
|
||||
Group(group string) SelectQuery
|
||||
|
||||
@@ -2,6 +2,7 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -208,6 +209,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
}
|
||||
}
|
||||
}
|
||||
// Note: We no longer add prefixes to unqualified columns here.
|
||||
// Use AddTablePrefixToColumns() separately if you need to add prefixes.
|
||||
|
||||
validConditions = append(validConditions, cond)
|
||||
}
|
||||
@@ -231,35 +234,52 @@ func stripOuterParentheses(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
for {
|
||||
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||
stripped, wasStripped := stripOneMatchingOuterParen(s)
|
||||
if !wasStripped {
|
||||
return s
|
||||
}
|
||||
s = stripped
|
||||
}
|
||||
}
|
||||
|
||||
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||
depth := 0
|
||||
matched := false
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
if depth == 0 && i == len(s)-1 {
|
||||
matched = true
|
||||
} else if depth == 0 {
|
||||
// Found a closing paren before the end, so outer parens don't match
|
||||
return s
|
||||
}
|
||||
// stripOneOuterParentheses removes only one level of matching outer parentheses from a string
|
||||
// Unlike stripOuterParentheses, this only strips once, preserving nested parentheses
|
||||
func stripOneOuterParentheses(s string) string {
|
||||
stripped, _ := stripOneMatchingOuterParen(strings.TrimSpace(s))
|
||||
return stripped
|
||||
}
|
||||
|
||||
// stripOneMatchingOuterParen is a helper that strips one matching pair of outer parentheses
|
||||
// Returns the stripped string and a boolean indicating if stripping occurred
|
||||
func stripOneMatchingOuterParen(s string) (string, bool) {
|
||||
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||
return s, false
|
||||
}
|
||||
|
||||
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||
depth := 0
|
||||
matched := false
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
if depth == 0 && i == len(s)-1 {
|
||||
matched = true
|
||||
} else if depth == 0 {
|
||||
// Found a closing paren before the end, so outer parens don't match
|
||||
return s, false
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
return s
|
||||
}
|
||||
|
||||
// Strip the outer parentheses and continue
|
||||
s = strings.TrimSpace(s[1 : len(s)-1])
|
||||
}
|
||||
|
||||
if !matched {
|
||||
return s, false
|
||||
}
|
||||
|
||||
// Strip the outer parentheses
|
||||
return strings.TrimSpace(s[1 : len(s)-1]), true
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
@@ -483,6 +503,87 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Unused: extractUnqualifiedColumnName extracts the column name from an unqualified condition
|
||||
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
|
||||
// "status = 'active'" returns "status"
|
||||
// nolint:unused
|
||||
func extractUnqualifiedColumnName(cond string) string {
|
||||
// Common SQL operators
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||
|
||||
// Find the column reference (left side of the operator)
|
||||
minIdx := -1
|
||||
for _, op := range operators {
|
||||
idx := strings.Index(cond, op)
|
||||
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||
minIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
var columnRef string
|
||||
if minIdx > 0 {
|
||||
columnRef = strings.TrimSpace(cond[:minIdx])
|
||||
} else {
|
||||
// No operator found, might be a single column reference
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef = parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
if columnRef == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// Return empty if it contains a dot (already qualified) or function call
|
||||
if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return columnRef
|
||||
}
|
||||
|
||||
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
|
||||
// Uses word boundaries to avoid partial matches
|
||||
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
|
||||
// returns "table.rid_item is null"
|
||||
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
|
||||
// Use word boundary matching with Go's supported regex syntax
|
||||
// \b matches word boundaries
|
||||
escapedOld := regexp.QuoteMeta(oldRef)
|
||||
pattern := `\b` + escapedOld + `\b`
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
// If regex fails, fall back to simple string replacement
|
||||
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
|
||||
return strings.Replace(cond, oldRef, newRef, 1)
|
||||
}
|
||||
|
||||
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
|
||||
result := cond
|
||||
matches := re.FindAllStringIndex(cond, -1)
|
||||
|
||||
// Process matches in reverse order to maintain correct indices
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
match := matches[i]
|
||||
start := match[0]
|
||||
|
||||
// Check if preceded by a dot (already qualified)
|
||||
if start > 0 && cond[start-1] == '.' {
|
||||
continue
|
||||
}
|
||||
|
||||
// Replace this occurrence
|
||||
result = result[:start] + newRef + result[match[1]:]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
||||
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
||||
func findOperatorOutsideParentheses(s string, operator string) int {
|
||||
@@ -538,3 +639,173 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
}
|
||||
return validColumns[strings.ToLower(columnName)]
|
||||
}
|
||||
|
||||
// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause.
|
||||
// This function only prefixes simple column references and skips:
|
||||
// - Columns already having a table prefix (containing a dot)
|
||||
// - Columns inside function calls or expressions (inside parentheses)
|
||||
// - Columns inside subqueries
|
||||
// - Columns that don't exist in the table (validation via model registry)
|
||||
//
|
||||
// Examples:
|
||||
// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table)
|
||||
// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function)
|
||||
// - "users.status = 'active'" -> unchanged (already has prefix)
|
||||
// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK)
|
||||
// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table)
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause to process
|
||||
// - tableName: The table name to use as prefix
|
||||
//
|
||||
// Returns:
|
||||
// - The WHERE clause with table prefixes added to appropriate and valid columns
|
||||
func AddTablePrefixToColumns(where string, tableName string) string {
|
||||
if where == "" || tableName == "" {
|
||||
return where
|
||||
}
|
||||
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Get valid columns from the model registry for validation
|
||||
validColumns := getValidColumnsForTable(tableName)
|
||||
|
||||
// Split by AND to handle multiple conditions (parenthesis-aware)
|
||||
conditions := splitByAND(where)
|
||||
prefixedConditions := make([]string, 0, len(conditions))
|
||||
|
||||
for _, cond := range conditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process this condition to add table prefix if appropriate
|
||||
processedCond := addPrefixToSingleCondition(cond, tableName, validColumns)
|
||||
prefixedConditions = append(prefixedConditions, processedCond)
|
||||
}
|
||||
|
||||
if len(prefixedConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.Join(prefixedConditions, " AND ")
|
||||
}
|
||||
|
||||
// addPrefixToSingleCondition adds table prefix to a single condition if appropriate
|
||||
// Returns the condition unchanged if:
|
||||
// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.)
|
||||
// - The column reference is inside a function call
|
||||
// - The column already has a table prefix
|
||||
// - No valid column reference is found
|
||||
// - The column doesn't exist in the table (when validColumns is provided)
|
||||
func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string {
|
||||
// Strip one level of outer grouping parentheses to get to the actual condition
|
||||
strippedCond := stripOneOuterParentheses(cond)
|
||||
|
||||
// Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.)
|
||||
if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) {
|
||||
logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond)
|
||||
return cond
|
||||
}
|
||||
|
||||
// After stripping outer parentheses, check if there are multiple AND-separated conditions
|
||||
// at the top level. If so, split and process each separately to avoid incorrectly
|
||||
// treating "true AND status" as a single column name.
|
||||
subConditions := splitByAND(strippedCond)
|
||||
if len(subConditions) > 1 {
|
||||
// Multiple conditions found - process each separately
|
||||
logger.Debug("Found %d sub-conditions after stripping parentheses, processing separately", len(subConditions))
|
||||
processedConditions := make([]string, 0, len(subConditions))
|
||||
for _, subCond := range subConditions {
|
||||
// Recursively process each sub-condition
|
||||
processed := addPrefixToSingleCondition(subCond, tableName, validColumns)
|
||||
processedConditions = append(processedConditions, processed)
|
||||
}
|
||||
result := strings.Join(processedConditions, " AND ")
|
||||
// Preserve original outer parentheses if they existed
|
||||
if cond != strippedCond {
|
||||
result = "(" + result + ")"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// If we stripped parentheses and still have more parentheses, recursively process
|
||||
if cond != strippedCond && strings.HasPrefix(strippedCond, "(") && strings.HasSuffix(strippedCond, ")") {
|
||||
// Recursively handle nested parentheses
|
||||
processed := addPrefixToSingleCondition(strippedCond, tableName, validColumns)
|
||||
return "(" + processed + ")"
|
||||
}
|
||||
|
||||
// Extract the left side of the comparison (before the operator)
|
||||
columnRef := extractLeftSideOfComparison(strippedCond)
|
||||
if columnRef == "" {
|
||||
return cond
|
||||
}
|
||||
|
||||
// Skip if it already has a prefix (contains a dot)
|
||||
if strings.Contains(columnRef, ".") {
|
||||
logger.Debug("Skipping column '%s' - already has table prefix", columnRef)
|
||||
return cond
|
||||
}
|
||||
|
||||
// Skip if it's a function call or expression (contains parentheses)
|
||||
if strings.Contains(columnRef, "(") {
|
||||
logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef)
|
||||
return cond
|
||||
}
|
||||
|
||||
// Validate that the column exists in the table (if we have column info)
|
||||
if !isValidColumn(columnRef, validColumns) {
|
||||
logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName)
|
||||
return cond
|
||||
}
|
||||
|
||||
// It's a simple unqualified column reference that exists in the table - add the table prefix
|
||||
newRef := tableName + "." + columnRef
|
||||
result := qualifyColumnInCondition(cond, columnRef, newRef)
|
||||
logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition.
|
||||
// This is used to identify the column reference that may need a table prefix.
|
||||
//
|
||||
// Examples:
|
||||
// - "status = 'active'" returns "status"
|
||||
// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')"
|
||||
// - "priority > 5" returns "priority"
|
||||
//
|
||||
// Returns empty string if no operator is found.
|
||||
func extractLeftSideOfComparison(cond string) string {
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||
|
||||
// Find the first operator outside of parentheses and quotes
|
||||
minIdx := -1
|
||||
for _, op := range operators {
|
||||
idx := findOperatorOutsideParentheses(cond, op)
|
||||
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||
minIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if minIdx > 0 {
|
||||
leftSide := strings.TrimSpace(cond[:minIdx])
|
||||
// Remove any surrounding quotes
|
||||
leftSide = strings.Trim(leftSide, "`\"'")
|
||||
return leftSide
|
||||
}
|
||||
|
||||
// No operator found - might be a boolean column
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef := strings.Trim(parts[0], "`\"'")
|
||||
// Make sure it's not a SQL keyword
|
||||
if !IsSQLKeyword(strings.ToLower(columnRef)) {
|
||||
return columnRef
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -33,16 +33,16 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "valid condition with parentheses - no prefix added",
|
||||
name: "valid condition with parentheses - prefix added to prevent ambiguity",
|
||||
where: "(status = 'active')",
|
||||
tableName: "users",
|
||||
expected: "status = 'active'",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "mixed trivial and valid conditions - no prefix added",
|
||||
name: "mixed trivial and valid conditions - prefix added",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "users",
|
||||
expected: "status = 'active'",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "condition with correct table prefix - unchanged",
|
||||
@@ -63,10 +63,10 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "multiple valid conditions without prefix - no prefix added",
|
||||
name: "multiple valid conditions without prefix - prefixes added",
|
||||
where: "status = 'active' AND age > 18",
|
||||
tableName: "users",
|
||||
expected: "status = 'active' AND age > 18",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "no table name provided",
|
||||
@@ -90,13 +90,13 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
name: "mixed case AND operators",
|
||||
where: "status = 'active' AND age > 18 and name = 'John'",
|
||||
tableName: "users",
|
||||
expected: "status = 'active' AND age > 18 AND name = 'John'",
|
||||
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
|
||||
},
|
||||
{
|
||||
name: "subquery with ORDER BY and LIMIT - allowed",
|
||||
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
tableName: "users",
|
||||
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||
},
|
||||
{
|
||||
name: "dangerous DELETE keyword - blocked",
|
||||
@@ -138,7 +138,10 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
// Then sanitize the where clause
|
||||
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
@@ -348,6 +351,7 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
||||
tableName string
|
||||
options *RequestOptions
|
||||
expected string
|
||||
addPrefix bool
|
||||
}{
|
||||
{
|
||||
name: "preload relation prefix is preserved",
|
||||
@@ -416,15 +420,30 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
||||
options: &RequestOptions{Preload: []PreloadOption{}},
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
|
||||
{
|
||||
name: "complex where clause with subquery and preload",
|
||||
where: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (rid_parentmastertaskitem is null)`,
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: `("mastertaskitem"."rid_mastertask" IN (6, 173, 157, 172, 174, 171, 170, 169, 167, 168, 166, 145, 161, 164, 146, 160, 147, 159, 148, 150, 152, 175, 151, 8, 153, 149, 155, 154, 165)) AND (mastertaskitem.rid_parentmastertaskitem is null)`,
|
||||
addPrefix: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result string
|
||||
prefixedWhere := tt.where
|
||||
if tt.addPrefix {
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedWhere = AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
}
|
||||
// Then sanitize the where clause
|
||||
if tt.options != nil {
|
||||
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
|
||||
result = SanitizeWhereClause(prefixedWhere, tt.tableName, tt.options)
|
||||
} else {
|
||||
result = SanitizeWhereClause(tt.where, tt.tableName)
|
||||
result = SanitizeWhereClause(prefixedWhere, tt.tableName)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
@@ -639,3 +658,76 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Parentheses with true AND condition - should not prefix true",
|
||||
where: "(true AND status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "(true AND mastertask.status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "Parentheses with multiple conditions including true",
|
||||
where: "(true AND status = 'active' AND id > 5)",
|
||||
tableName: "mastertask",
|
||||
expected: "(true AND mastertask.status = 'active' AND mastertask.id > 5)",
|
||||
},
|
||||
{
|
||||
name: "Nested parentheses with true",
|
||||
where: "((true AND status = 'active'))",
|
||||
tableName: "mastertask",
|
||||
expected: "((true AND mastertask.status = 'active'))",
|
||||
},
|
||||
{
|
||||
name: "Mixed: false AND valid conditions",
|
||||
where: "(false AND name = 'test')",
|
||||
tableName: "mastertask",
|
||||
expected: "(false AND mastertask.name = 'test')",
|
||||
},
|
||||
{
|
||||
name: "Mixed: null AND valid conditions",
|
||||
where: "(null AND status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "(null AND mastertask.status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "Multiple true conditions in parentheses",
|
||||
where: "(true AND true AND status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "(true AND true AND mastertask.status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "Simple true without parens - should not prefix",
|
||||
where: "true",
|
||||
tableName: "mastertask",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "Simple condition without parens - should prefix",
|
||||
where: "status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "Unregistered table with true - should not prefix true",
|
||||
where: "(true AND status = 'active')",
|
||||
tableName: "unregistered_table",
|
||||
expected: "(true AND unregistered_table.status = 'active')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,6 +237,13 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
for _, sort := range options.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
}
|
||||
@@ -262,6 +269,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
}
|
||||
filteredPreload.Filters = validPreloadFilters
|
||||
|
||||
// Filter preload sort columns
|
||||
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
||||
for _, sort := range preload.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validPreloadSorts = append(validPreloadSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validPreloadSorts = append(validPreloadSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression in preload '%s' removed: '%s'", preload.Relation, sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' sort '%s' removed", preload.Relation, sort.Column)
|
||||
}
|
||||
}
|
||||
filteredPreload.Sort = validPreloadSorts
|
||||
|
||||
validPreloads = append(validPreloads, filteredPreload)
|
||||
}
|
||||
filtered.Preload = validPreloads
|
||||
@@ -269,6 +294,56 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
return filtered
|
||||
}
|
||||
|
||||
// IsSafeSortExpression validates that a sort expression (enclosed in brackets) is safe
|
||||
// and doesn't contain SQL injection attempts or dangerous commands
|
||||
func IsSafeSortExpression(expr string) bool {
|
||||
if expr == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Expression must be enclosed in brackets
|
||||
expr = strings.TrimSpace(expr)
|
||||
if !strings.HasPrefix(expr, "(") || !strings.HasSuffix(expr, ")") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Remove outer brackets for content validation
|
||||
expr = expr[1 : len(expr)-1]
|
||||
expr = strings.TrimSpace(expr)
|
||||
|
||||
// Convert to lowercase for checking dangerous keywords
|
||||
exprLower := strings.ToLower(expr)
|
||||
|
||||
// Check for dangerous SQL commands that should never be in a sort expression
|
||||
dangerousKeywords := []string{
|
||||
"drop ", "delete ", "insert ", "update ", "alter ", "create ",
|
||||
"truncate ", "exec ", "execute ", "grant ", "revoke ",
|
||||
"into ", "values ", "set ", "shutdown", "xp_",
|
||||
}
|
||||
|
||||
for _, keyword := range dangerousKeywords {
|
||||
if strings.Contains(exprLower, keyword) {
|
||||
logger.Warn("Dangerous SQL keyword '%s' detected in sort expression: %s", keyword, expr)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check for SQL comment attempts
|
||||
if strings.Contains(expr, "--") || strings.Contains(expr, "/*") || strings.Contains(expr, "*/") {
|
||||
logger.Warn("SQL comment detected in sort expression: %s", expr)
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for semicolon (command separator)
|
||||
if strings.Contains(expr, ";") {
|
||||
logger.Warn("Command separator (;) detected in sort expression: %s", expr)
|
||||
return false
|
||||
}
|
||||
|
||||
// Expression appears safe
|
||||
return true
|
||||
}
|
||||
|
||||
// GetValidColumns returns a list of all valid column names for debugging purposes
|
||||
func (v *ColumnValidator) GetValidColumns() []string {
|
||||
columns := make([]string, 0, len(v.validColumns))
|
||||
|
||||
@@ -361,3 +361,83 @@ func TestFilterRequestOptions(t *testing.T) {
|
||||
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeSortExpression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
shouldPass bool
|
||||
}{
|
||||
// Safe expressions
|
||||
{"Valid subquery", "(SELECT MAX(price) FROM products)", true},
|
||||
{"Valid CASE expression", "(CASE WHEN status = 'active' THEN 1 ELSE 0 END)", true},
|
||||
{"Valid aggregate", "(COUNT(*) OVER (PARTITION BY category))", true},
|
||||
{"Valid function", "(COALESCE(discount, 0))", true},
|
||||
|
||||
// Dangerous expressions - SQL injection attempts
|
||||
{"DROP TABLE attempt", "(id); DROP TABLE users; --", false},
|
||||
{"DELETE attempt", "(id WHERE 1=1); DELETE FROM users; --", false},
|
||||
{"INSERT attempt", "(id); INSERT INTO admin VALUES ('hacker'); --", false},
|
||||
{"UPDATE attempt", "(id); UPDATE users SET role='admin'; --", false},
|
||||
{"EXEC attempt", "(id); EXEC sp_executesql 'DROP TABLE users'; --", false},
|
||||
{"XP_ stored proc", "(id); xp_cmdshell 'dir'; --", false},
|
||||
|
||||
// Comment injection
|
||||
{"SQL comment dash", "(id) -- malicious comment", false},
|
||||
{"SQL comment block start", "(id) /* comment", false},
|
||||
{"SQL comment block end", "(id) comment */", false},
|
||||
|
||||
// Semicolon attempts
|
||||
{"Semicolon separator", "(id); SELECT * FROM passwords", false},
|
||||
|
||||
// Empty/invalid
|
||||
{"Empty string", "", false},
|
||||
{"Just brackets", "()", true}, // Empty but technically valid structure
|
||||
{"No brackets", "id", false}, // Must have brackets for expressions
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsSafeSortExpression(tt.expression)
|
||||
if result != tt.shouldPass {
|
||||
t.Errorf("IsSafeSortExpression(%q) = %v, want %v", tt.expression, result, tt.shouldPass)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
options := RequestOptions{
|
||||
Sort: []SortOption{
|
||||
{Column: "id", Direction: "ASC"}, // Valid column
|
||||
{Column: "(SELECT MAX(age) FROM users)", Direction: "DESC"}, // Safe expression
|
||||
{Column: "name", Direction: "ASC"}, // Valid column
|
||||
{Column: "(id); DROP TABLE users; --", Direction: "DESC"}, // Dangerous expression
|
||||
{Column: "invalid_col", Direction: "ASC"}, // Invalid column
|
||||
{Column: "(CASE WHEN age > 18 THEN 1 ELSE 0 END)", Direction: "ASC"}, // Safe expression
|
||||
},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
// Should keep: id, safe expression, name, another safe expression
|
||||
// Should remove: dangerous expression, invalid column
|
||||
expectedCount := 4
|
||||
if len(filtered.Sort) != expectedCount {
|
||||
t.Errorf("Expected %d sort options, got %d", expectedCount, len(filtered.Sort))
|
||||
}
|
||||
|
||||
// Verify the kept options
|
||||
if filtered.Sort[0].Column != "id" {
|
||||
t.Errorf("Expected first sort to be 'id', got '%s'", filtered.Sort[0].Column)
|
||||
}
|
||||
if filtered.Sort[1].Column != "(SELECT MAX(age) FROM users)" {
|
||||
t.Errorf("Expected second sort to be safe expression, got '%s'", filtered.Sort[1].Column)
|
||||
}
|
||||
if filtered.Sort[2].Column != "name" {
|
||||
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ type Config struct {
|
||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
||||
}
|
||||
|
||||
// ServerConfig holds server-related configuration
|
||||
@@ -91,3 +92,52 @@ type ErrorTrackingConfig struct {
|
||||
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
|
||||
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
|
||||
}
|
||||
|
||||
// EventBrokerConfig contains configuration for the event broker
|
||||
type EventBrokerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Provider string `mapstructure:"provider"` // memory, redis, nats, database
|
||||
Mode string `mapstructure:"mode"` // sync, async
|
||||
WorkerCount int `mapstructure:"worker_count"`
|
||||
BufferSize int `mapstructure:"buffer_size"`
|
||||
InstanceID string `mapstructure:"instance_id"`
|
||||
Redis EventBrokerRedisConfig `mapstructure:"redis"`
|
||||
NATS EventBrokerNATSConfig `mapstructure:"nats"`
|
||||
Database EventBrokerDatabaseConfig `mapstructure:"database"`
|
||||
RetryPolicy EventBrokerRetryPolicyConfig `mapstructure:"retry_policy"`
|
||||
}
|
||||
|
||||
// EventBrokerRedisConfig contains Redis-specific configuration
|
||||
type EventBrokerRedisConfig struct {
|
||||
StreamName string `mapstructure:"stream_name"`
|
||||
ConsumerGroup string `mapstructure:"consumer_group"`
|
||||
MaxLen int64 `mapstructure:"max_len"`
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
}
|
||||
|
||||
// EventBrokerNATSConfig contains NATS-specific configuration
|
||||
type EventBrokerNATSConfig struct {
|
||||
URL string `mapstructure:"url"`
|
||||
StreamName string `mapstructure:"stream_name"`
|
||||
Subjects []string `mapstructure:"subjects"`
|
||||
Storage string `mapstructure:"storage"` // file, memory
|
||||
MaxAge time.Duration `mapstructure:"max_age"`
|
||||
}
|
||||
|
||||
// EventBrokerDatabaseConfig contains database provider configuration
|
||||
type EventBrokerDatabaseConfig struct {
|
||||
TableName string `mapstructure:"table_name"`
|
||||
Channel string `mapstructure:"channel"` // PostgreSQL NOTIFY channel name
|
||||
PollInterval time.Duration `mapstructure:"poll_interval"`
|
||||
}
|
||||
|
||||
// EventBrokerRetryPolicyConfig contains retry policy configuration
|
||||
type EventBrokerRetryPolicyConfig struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
InitialDelay time.Duration `mapstructure:"initial_delay"`
|
||||
MaxDelay time.Duration `mapstructure:"max_delay"`
|
||||
BackoffFactor float64 `mapstructure:"backoff_factor"`
|
||||
}
|
||||
|
||||
@@ -165,4 +165,39 @@ func setDefaults(v *viper.Viper) {
|
||||
|
||||
// Database defaults
|
||||
v.SetDefault("database.url", "")
|
||||
|
||||
// Event Broker defaults
|
||||
v.SetDefault("event_broker.enabled", false)
|
||||
v.SetDefault("event_broker.provider", "memory")
|
||||
v.SetDefault("event_broker.mode", "async")
|
||||
v.SetDefault("event_broker.worker_count", 10)
|
||||
v.SetDefault("event_broker.buffer_size", 1000)
|
||||
v.SetDefault("event_broker.instance_id", "")
|
||||
|
||||
// Event Broker - Redis defaults
|
||||
v.SetDefault("event_broker.redis.stream_name", "resolvespec:events")
|
||||
v.SetDefault("event_broker.redis.consumer_group", "resolvespec-workers")
|
||||
v.SetDefault("event_broker.redis.max_len", 10000)
|
||||
v.SetDefault("event_broker.redis.host", "localhost")
|
||||
v.SetDefault("event_broker.redis.port", 6379)
|
||||
v.SetDefault("event_broker.redis.password", "")
|
||||
v.SetDefault("event_broker.redis.db", 0)
|
||||
|
||||
// Event Broker - NATS defaults
|
||||
v.SetDefault("event_broker.nats.url", "nats://localhost:4222")
|
||||
v.SetDefault("event_broker.nats.stream_name", "RESOLVESPEC_EVENTS")
|
||||
v.SetDefault("event_broker.nats.subjects", []string{"events.>"})
|
||||
v.SetDefault("event_broker.nats.storage", "file")
|
||||
v.SetDefault("event_broker.nats.max_age", "24h")
|
||||
|
||||
// Event Broker - Database defaults
|
||||
v.SetDefault("event_broker.database.table_name", "events")
|
||||
v.SetDefault("event_broker.database.channel", "resolvespec_events")
|
||||
v.SetDefault("event_broker.database.poll_interval", "1s")
|
||||
|
||||
// Event Broker - Retry Policy defaults
|
||||
v.SetDefault("event_broker.retry_policy.max_retries", 3)
|
||||
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
|
||||
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
|
||||
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
|
||||
}
|
||||
|
||||
@@ -90,12 +90,12 @@ Panics are automatically captured when using the logger's panic handlers:
|
||||
|
||||
```go
|
||||
// Using CatchPanic
|
||||
defer logger.CatchPanic("MyFunction")
|
||||
defer logger.CatchPanic("MyFunction")()
|
||||
|
||||
// Using CatchPanicCallback
|
||||
defer logger.CatchPanicCallback("MyFunction", func(err any) {
|
||||
// Custom cleanup
|
||||
})
|
||||
})()
|
||||
|
||||
// Using HandlePanic
|
||||
defer func() {
|
||||
|
||||
353
pkg/eventbroker/IMPLEMENTATION_PLAN.md
Normal file
353
pkg/eventbroker/IMPLEMENTATION_PLAN.md
Normal file
@@ -0,0 +1,353 @@
|
||||
# Event Broker System Implementation Plan
|
||||
|
||||
## Overview
|
||||
Implement a comprehensive event handler/broker system for ResolveSpec that follows existing architectural patterns (Provider interface, Hook system, Config management, Graceful shutdown).
|
||||
|
||||
## Requirements Met
|
||||
- ✅ Events with sources (database, websocket, frontend, system)
|
||||
- ✅ Event statuses (pending, processing, completed, failed)
|
||||
- ✅ Timestamps, JSON payloads, user IDs, session IDs
|
||||
- ✅ Program instance IDs for tracking server instances
|
||||
- ✅ Both sync and async processing modes
|
||||
- ✅ Multiple provider backends (in-memory, Redis, NATS, database)
|
||||
- ✅ Cross-instance pub/sub support
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
**Event Structure** (with full metadata):
|
||||
```go
|
||||
type Event struct {
|
||||
ID string // UUID
|
||||
Source EventSource // database, websocket, system, frontend
|
||||
Type string // Pattern: schema.entity.operation
|
||||
Status EventStatus // pending, processing, completed, failed
|
||||
Payload json.RawMessage // JSON payload
|
||||
UserID int
|
||||
SessionID string
|
||||
InstanceID string // Server instance identifier
|
||||
Schema string
|
||||
Entity string
|
||||
Operation string // create, update, delete, read
|
||||
CreatedAt time.Time
|
||||
ProcessedAt *time.Time
|
||||
CompletedAt *time.Time
|
||||
Error string
|
||||
Metadata map[string]interface{}
|
||||
RetryCount int
|
||||
}
|
||||
```
|
||||
|
||||
**Provider Pattern** (like cache.Provider):
|
||||
```go
|
||||
type Provider interface {
|
||||
Store(ctx context.Context, event *Event) error
|
||||
Get(ctx context.Context, id string) (*Event, error)
|
||||
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
|
||||
UpdateStatus(ctx context.Context, id string, status EventStatus, error string) error
|
||||
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
|
||||
Publish(ctx context.Context, event *Event) error
|
||||
Close() error
|
||||
Stats(ctx context.Context) (*ProviderStats, error)
|
||||
}
|
||||
```
|
||||
|
||||
**Broker Interface**:
|
||||
```go
|
||||
type Broker interface {
|
||||
Publish(ctx context.Context, event *Event) error // Mode-dependent
|
||||
PublishSync(ctx context.Context, event *Event) error // Blocks
|
||||
PublishAsync(ctx context.Context, event *Event) error // Non-blocking
|
||||
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
|
||||
Unsubscribe(id SubscriptionID) error
|
||||
Start(ctx context.Context) error
|
||||
Stop(ctx context.Context) error
|
||||
Stats(ctx context.Context) (*BrokerStats, error)
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### Phase 1: Core Foundation (Files: 1-5)
|
||||
|
||||
**1. Create `pkg/eventbroker/event.go`**
|
||||
- Event struct with all required fields (status, timestamps, user, instance ID, etc.)
|
||||
- EventSource enum (database, websocket, frontend, system, internal)
|
||||
- EventStatus enum (pending, processing, completed, failed)
|
||||
- Helper: `EventType(schema, entity, operation string) string`
|
||||
- Helper: `NewEvent()` constructor with UUID generation
|
||||
|
||||
**2. Create `pkg/eventbroker/provider.go`**
|
||||
- Provider interface definition
|
||||
- EventFilter struct for queries
|
||||
- ProviderStats struct
|
||||
|
||||
**3. Create `pkg/eventbroker/handler.go`**
|
||||
- EventHandler interface
|
||||
- EventHandlerFunc adapter type
|
||||
|
||||
**4. Create `pkg/eventbroker/broker.go`**
|
||||
- Broker interface definition
|
||||
- EventBroker struct implementation
|
||||
- ProcessingMode enum (sync, async)
|
||||
- Options struct with functional options (WithProvider, WithMode, WithWorkerCount, etc.)
|
||||
- NewBroker() constructor
|
||||
- Sync processing implementation
|
||||
|
||||
**5. Create `pkg/eventbroker/subscription.go`**
|
||||
- Pattern matching using glob syntax (e.g., "public.users.*", "*.*.create")
|
||||
- subscriptionManager struct
|
||||
- SubscriptionID type
|
||||
- Subscribe/Unsubscribe logic
|
||||
|
||||
### Phase 2: Configuration & Integration (Files: 6-8)
|
||||
|
||||
**6. Create `pkg/eventbroker/config.go`**
|
||||
- EventBrokerConfig struct
|
||||
- RedisConfig, NATSConfig, DatabaseConfig structs
|
||||
- RetryPolicyConfig struct
|
||||
|
||||
**7. Update `pkg/config/config.go`**
|
||||
- Add `EventBroker EventBrokerConfig` field to Config struct
|
||||
|
||||
**8. Update `pkg/config/manager.go`**
|
||||
- Add event broker defaults to `setDefaults()`:
|
||||
```go
|
||||
v.SetDefault("event_broker.enabled", false)
|
||||
v.SetDefault("event_broker.provider", "memory")
|
||||
v.SetDefault("event_broker.mode", "async")
|
||||
v.SetDefault("event_broker.worker_count", 10)
|
||||
v.SetDefault("event_broker.buffer_size", 1000)
|
||||
```
|
||||
|
||||
### Phase 3: Memory Provider (Files: 9)
|
||||
|
||||
**9. Create `pkg/eventbroker/provider_memory.go`**
|
||||
- MemoryProvider struct with mutex-protected map
|
||||
- In-memory event storage
|
||||
- Pattern matching for subscriptions
|
||||
- Channel-based streaming for real-time events
|
||||
- LRU eviction when max size reached
|
||||
- Cleanup goroutine for old completed events
|
||||
- **Note**: Single-instance only (no cross-instance pub/sub)
|
||||
|
||||
### Phase 4: Async Processing (Update File: 4)
|
||||
|
||||
**10. Update `pkg/eventbroker/broker.go`** (add async support)
|
||||
- workerPool struct with configurable worker count
|
||||
- Buffered channel for event queue
|
||||
- Worker goroutines that process events
|
||||
- PublishAsync() queues to channel
|
||||
- Graceful shutdown: stop accepting events, drain queue, wait for workers
|
||||
- Retry logic with exponential backoff
|
||||
|
||||
### Phase 5: Hook Integration (Files: 11)
|
||||
|
||||
**11. Create `pkg/eventbroker/hooks.go`**
|
||||
- `RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry)`
|
||||
- Registers AfterCreate, AfterUpdate, AfterDelete, AfterRead hooks
|
||||
- Extracts UserContext from hook context
|
||||
- Creates Event with proper metadata
|
||||
- Calls `broker.PublishAsync()` to not block CRUD operations
|
||||
|
||||
### Phase 6: Global Singleton & Factory (Files: 12-13)
|
||||
|
||||
**12. Create `pkg/eventbroker/eventbroker.go`**
|
||||
- Global `defaultBroker` variable
|
||||
- `Initialize(config *config.Config) error` - creates broker from config
|
||||
- `SetDefaultBroker(broker Broker)`
|
||||
- `GetDefaultBroker() Broker`
|
||||
- Helper functions: `Publish()`, `PublishAsync()`, `PublishSync()`, `Subscribe()`
|
||||
- `RegisterShutdown(broker Broker)` - registers with server.RegisterShutdownCallback()
|
||||
|
||||
**13. Create `pkg/eventbroker/factory.go`**
|
||||
- `NewProviderFromConfig(config EventBrokerConfig) (Provider, error)`
|
||||
- Provider selection logic (memory, redis, nats, database)
|
||||
- Returns appropriate provider based on config
|
||||
|
||||
### Phase 7: Redis Provider (Files: 14)
|
||||
|
||||
**14. Create `pkg/eventbroker/provider_redis.go`**
|
||||
- RedisProvider using Redis Streams (XADD, XREAD, XGROUP)
|
||||
- Consumer group for distributed processing
|
||||
- Cross-instance pub/sub support
|
||||
- Stream(pattern) subscribes to consumer group
|
||||
- Publish() uses XADD to append to stream
|
||||
- Graceful shutdown: acknowledge pending messages
|
||||
|
||||
**Dependencies**: `github.com/redis/go-redis/v9`
|
||||
|
||||
### Phase 8: NATS Provider (Files: 15)
|
||||
|
||||
**15. Create `pkg/eventbroker/provider_nats.go`**
|
||||
- NATSProvider using NATS JetStream
|
||||
- Subject-based routing: `events.{source}.{type}`
|
||||
- Wildcard subscriptions support
|
||||
- Durable consumers for replay
|
||||
- At-least-once delivery semantics
|
||||
|
||||
**Dependencies**: `github.com/nats-io/nats.go`
|
||||
|
||||
### Phase 9: Database Provider (Files: 16)
|
||||
|
||||
**16. Create `pkg/eventbroker/provider_database.go`**
|
||||
- DatabaseProvider using `common.Database` interface
|
||||
- Table schema creation (events table with indexes)
|
||||
- Polling-based event consumption (configurable interval)
|
||||
- Full SQL query support via List(filter)
|
||||
- Transaction support for atomic operations
|
||||
- Good for audit trails and debugging
|
||||
|
||||
### Phase 10: Metrics Integration (Files: 17)
|
||||
|
||||
**17. Create `pkg/eventbroker/metrics.go`**
|
||||
- Integrate with existing `metrics.Provider`
|
||||
- Record metrics:
|
||||
- `eventbroker_events_published_total{source, type}`
|
||||
- `eventbroker_events_processed_total{source, type, status}`
|
||||
- `eventbroker_event_processing_duration_seconds{source, type}`
|
||||
- `eventbroker_queue_size`
|
||||
- `eventbroker_workers_active`
|
||||
|
||||
**18. Update `pkg/metrics/interfaces.go`**
|
||||
- Add methods to Provider interface:
|
||||
```go
|
||||
RecordEventPublished(source, eventType string)
|
||||
RecordEventProcessed(source, eventType, status string, duration time.Duration)
|
||||
UpdateEventQueueSize(size int64)
|
||||
```
|
||||
|
||||
### Phase 11: Testing & Examples (Files: 19-20)
|
||||
|
||||
**19. Create `pkg/eventbroker/eventbroker_test.go`**
|
||||
- Unit tests for Event marshaling
|
||||
- Pattern matching tests
|
||||
- MemoryProvider tests
|
||||
- Sync vs async mode tests
|
||||
- Concurrent publish/subscribe tests
|
||||
- Graceful shutdown tests
|
||||
|
||||
**20. Create `pkg/eventbroker/example_usage.go`**
|
||||
- Basic publish example
|
||||
- Subscribe with patterns example
|
||||
- Hook integration example
|
||||
- Multiple handlers example
|
||||
- Error handling example
|
||||
|
||||
## Integration Points
|
||||
|
||||
### Hook System Integration
|
||||
```go
|
||||
// In application initialization (e.g., main.go)
|
||||
eventbroker.RegisterCRUDHooks(broker, handler.Hooks())
|
||||
```
|
||||
|
||||
This automatically publishes events for all CRUD operations:
|
||||
- `schema.entity.create` after inserts
|
||||
- `schema.entity.update` after updates
|
||||
- `schema.entity.delete` after deletes
|
||||
- `schema.entity.read` after reads
|
||||
|
||||
### Shutdown Integration
|
||||
```go
|
||||
// In application initialization
|
||||
eventbroker.RegisterShutdown(broker)
|
||||
```
|
||||
|
||||
Ensures event broker flushes pending events before shutdown.
|
||||
|
||||
### Configuration Example
|
||||
```yaml
|
||||
event_broker:
|
||||
enabled: true
|
||||
provider: redis # memory, redis, nats, database
|
||||
mode: async # sync, async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
instance_id: "${HOSTNAME}"
|
||||
|
||||
redis:
|
||||
stream_name: "resolvespec:events"
|
||||
consumer_group: "resolvespec-workers"
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Publishing Custom Events
|
||||
```go
|
||||
// WebSocket event
|
||||
event := &eventbroker.Event{
|
||||
Source: eventbroker.EventSourceWebSocket,
|
||||
Type: "chat.message",
|
||||
Payload: json.RawMessage(`{"room": "lobby", "msg": "Hello"}`),
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
eventbroker.PublishAsync(ctx, event)
|
||||
```
|
||||
|
||||
### Subscribing to Events
|
||||
```go
|
||||
// Subscribe to all user creation events
|
||||
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
log.Printf("New user created: %s", event.Payload)
|
||||
// Send welcome email, update cache, etc.
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Subscribe to all events from database
|
||||
eventbroker.Subscribe("*", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
if event.Source == eventbroker.EventSourceDatabase {
|
||||
// Audit logging
|
||||
}
|
||||
return nil
|
||||
},
|
||||
))
|
||||
```
|
||||
|
||||
## Critical Files Reference
|
||||
|
||||
**Patterns to Follow**:
|
||||
- `pkg/cache/provider.go` - Provider interface pattern
|
||||
- `pkg/restheadspec/hooks.go` - Hook system integration
|
||||
- `pkg/config/manager.go` - Configuration pattern
|
||||
- `pkg/server/shutdown.go` - Shutdown callbacks
|
||||
|
||||
**Files to Modify**:
|
||||
- `pkg/config/config.go` - Add EventBroker field
|
||||
- `pkg/config/manager.go` - Add defaults
|
||||
- `pkg/metrics/interfaces.go` - Add event broker metrics
|
||||
|
||||
**New Package**:
|
||||
- `pkg/eventbroker/` (20 files total)
|
||||
|
||||
## Provider Feature Matrix
|
||||
|
||||
| Feature | Memory | Redis | NATS | Database |
|
||||
|---------|--------|-------|------|----------|
|
||||
| Persistence | ❌ | ✅ | ✅ | ✅ |
|
||||
| Cross-instance | ❌ | ✅ | ✅ | ✅ |
|
||||
| Real-time | ✅ | ✅ | ✅ | ⚠️ (polling) |
|
||||
| Query history | Limited | Limited | ✅ (replay) | ✅ (SQL) |
|
||||
| External deps | None | Redis | NATS | None |
|
||||
| Complexity | Low | Medium | Medium | Low |
|
||||
|
||||
## Implementation Order Priority
|
||||
|
||||
1. **Core + Memory Provider** (Phase 1-3) - Functional in-process event system
|
||||
2. **Async + Hooks** (Phase 4-5) - Non-blocking event dispatch integrated with CRUD
|
||||
3. **Config + Singleton** (Phase 6) - Easy initialization and usage
|
||||
4. **Redis Provider** (Phase 7) - Production-ready distributed events
|
||||
5. **Metrics** (Phase 10) - Observability
|
||||
6. **NATS/Database** (Phase 8-9) - Alternative backends
|
||||
7. **Tests + Examples** (Phase 11) - Documentation and reliability
|
||||
|
||||
## Next Steps
|
||||
|
||||
After approval, implement in order of phases. Each phase builds on previous phases and can be tested independently.
|
||||
347
pkg/eventbroker/README.md
Normal file
347
pkg/eventbroker/README.md
Normal file
@@ -0,0 +1,347 @@
|
||||
# Event Broker System
|
||||
|
||||
A comprehensive event handler/broker system for ResolveSpec that provides real-time event publishing, subscription, and cross-instance communication.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Sources**: Events from database, websockets, frontend, system, and internal sources
|
||||
- **Event Status Tracking**: Pending, processing, completed, failed states with timestamps
|
||||
- **Rich Metadata**: User IDs, session IDs, instance IDs, JSON payloads, and custom metadata
|
||||
- **Sync & Async Modes**: Choose between synchronous or asynchronous event processing
|
||||
- **Pattern Matching**: Subscribe to events using glob-style patterns
|
||||
- **Multiple Providers**: In-memory, Redis Streams, NATS JetStream, PostgreSQL with NOTIFY
|
||||
- **Hook Integration**: Automatic CRUD event capture via restheadspec hooks
|
||||
- **Retry Logic**: Configurable retry policy with exponential backoff
|
||||
- **Metrics**: Prometheus-compatible metrics for monitoring
|
||||
- **Graceful Shutdown**: Proper cleanup and event flushing on shutdown
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Configuration
|
||||
|
||||
Add to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
enabled: true
|
||||
provider: memory # memory, redis, nats, database
|
||||
mode: async # sync, async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
instance_id: "${HOSTNAME}"
|
||||
```
|
||||
|
||||
### 2. Initialize
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
cfgMgr := config.NewManager()
|
||||
cfg, _ := cfgMgr.GetConfig()
|
||||
|
||||
// Initialize event broker
|
||||
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Subscribe to Events
|
||||
|
||||
```go
|
||||
// Subscribe to specific events
|
||||
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
log.Printf("New user created: %s", event.Payload)
|
||||
// Send welcome email, update cache, etc.
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Subscribe with patterns
|
||||
eventbroker.Subscribe("*.*.delete", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
log.Printf("Deleted: %s.%s", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
```
|
||||
|
||||
### 4. Publish Events
|
||||
|
||||
```go
|
||||
// Create and publish an event
|
||||
event := eventbroker.NewEvent(eventbroker.EventSourceDatabase, "public.users.update")
|
||||
event.InstanceID = eventbroker.GetDefaultBroker().InstanceID()
|
||||
event.UserID = 123
|
||||
event.SessionID = "session-456"
|
||||
event.Schema = "public"
|
||||
event.Entity = "users"
|
||||
event.Operation = "update"
|
||||
|
||||
event.SetPayload(map[string]interface{}{
|
||||
"id": 123,
|
||||
"name": "John Doe",
|
||||
})
|
||||
|
||||
// Async (non-blocking)
|
||||
eventbroker.PublishAsync(ctx, event)
|
||||
|
||||
// Sync (blocking)
|
||||
eventbroker.PublishSync(ctx, event)
|
||||
```
|
||||
|
||||
## Automatic CRUD Event Capture
|
||||
|
||||
Automatically capture database CRUD operations:
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
)
|
||||
|
||||
func setupHooks(handler *restheadspec.Handler) {
|
||||
broker := eventbroker.GetDefaultBroker()
|
||||
|
||||
// Configure which operations to capture
|
||||
config := eventbroker.DefaultCRUDHookConfig()
|
||||
config.EnableRead = false // Disable read events for performance
|
||||
|
||||
// Register hooks
|
||||
eventbroker.RegisterCRUDHooks(broker, handler.Hooks(), config)
|
||||
|
||||
// Now all create/update/delete operations automatically publish events!
|
||||
}
|
||||
```
|
||||
|
||||
## Event Structure
|
||||
|
||||
Every event contains:
|
||||
|
||||
```go
|
||||
type Event struct {
|
||||
ID string // UUID
|
||||
Source EventSource // database, websocket, system, frontend, internal
|
||||
Type string // Pattern: schema.entity.operation
|
||||
Status EventStatus // pending, processing, completed, failed
|
||||
Payload json.RawMessage // JSON payload
|
||||
UserID int // User who triggered the event
|
||||
SessionID string // Session identifier
|
||||
InstanceID string // Server instance identifier
|
||||
Schema string // Database schema
|
||||
Entity string // Database entity/table
|
||||
Operation string // create, update, delete, read
|
||||
CreatedAt time.Time // When event was created
|
||||
ProcessedAt *time.Time // When processing started
|
||||
CompletedAt *time.Time // When processing completed
|
||||
Error string // Error message if failed
|
||||
Metadata map[string]interface{} // Additional context
|
||||
RetryCount int // Number of retry attempts
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern Matching
|
||||
|
||||
Subscribe to events using glob-style patterns:
|
||||
|
||||
| Pattern | Matches | Example |
|
||||
|---------|---------|---------|
|
||||
| `*` | All events | Any event |
|
||||
| `public.users.*` | All user operations | `public.users.create`, `public.users.update` |
|
||||
| `*.*.create` | All create operations | `public.users.create`, `auth.sessions.create` |
|
||||
| `public.*.*` | All events in public schema | `public.users.create`, `public.posts.delete` |
|
||||
| `public.users.create` | Exact match | Only `public.users.create` |
|
||||
|
||||
## Providers
|
||||
|
||||
### Memory Provider (Default)
|
||||
|
||||
Best for: Development, single-instance deployments
|
||||
|
||||
- **Pros**: Fast, no dependencies, simple
|
||||
- **Cons**: Events lost on restart, single-instance only
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: memory
|
||||
```
|
||||
|
||||
### Redis Provider
|
||||
|
||||
Best for: Production, multi-instance deployments
|
||||
|
||||
- **Pros**: Persistent, cross-instance pub/sub, reliable, Redis Streams support
|
||||
- **Cons**: Requires Redis server
|
||||
- **Status**: ✅ Implemented
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: redis
|
||||
redis:
|
||||
stream_name: "resolvespec:events"
|
||||
consumer_group: "resolvespec-workers"
|
||||
max_len: 10000
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
```
|
||||
|
||||
### NATS Provider
|
||||
|
||||
Best for: High-performance, low-latency requirements
|
||||
|
||||
- **Pros**: Very fast, built-in clustering, durable, JetStream support
|
||||
- **Cons**: Requires NATS server
|
||||
- **Status**: ✅ Implemented
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: nats
|
||||
nats:
|
||||
url: "nats://localhost:4222"
|
||||
stream_name: "RESOLVESPEC_EVENTS"
|
||||
storage: "file" # or "memory"
|
||||
max_age: "24h"
|
||||
```
|
||||
|
||||
### Database Provider
|
||||
|
||||
Best for: Audit trails, event replay, SQL queries
|
||||
|
||||
- **Pros**: No additional infrastructure, full SQL query support, PostgreSQL NOTIFY for real-time
|
||||
- **Cons**: Slower than Redis/NATS, requires database connection
|
||||
- **Status**: ✅ Implemented
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
provider: database
|
||||
database:
|
||||
table_name: "events"
|
||||
channel: "resolvespec_events"
|
||||
poll_interval: "1s"
|
||||
```
|
||||
|
||||
## Processing Modes
|
||||
|
||||
### Async Mode (Recommended)
|
||||
|
||||
Events are queued and processed by worker pool:
|
||||
|
||||
- Non-blocking event publishing
|
||||
- Configurable worker count
|
||||
- Better throughput
|
||||
- Events may be processed out of order
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
mode: async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
```
|
||||
|
||||
### Sync Mode
|
||||
|
||||
Events are processed immediately:
|
||||
|
||||
- Blocking event publishing
|
||||
- Guaranteed ordering
|
||||
- Immediate error feedback
|
||||
- Lower throughput
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
mode: sync
|
||||
```
|
||||
|
||||
## Retry Policy
|
||||
|
||||
Configure automatic retries for failed handlers:
|
||||
|
||||
```yaml
|
||||
event_broker:
|
||||
retry_policy:
|
||||
max_retries: 3
|
||||
initial_delay: 1s
|
||||
max_delay: 30s
|
||||
backoff_factor: 2.0 # Exponential backoff
|
||||
```
|
||||
|
||||
## Metrics
|
||||
|
||||
The event broker exposes Prometheus metrics:
|
||||
|
||||
- `eventbroker_events_published_total{source, type}` - Total events published
|
||||
- `eventbroker_events_processed_total{source, type, status}` - Total events processed
|
||||
- `eventbroker_event_processing_duration_seconds{source, type}` - Event processing duration
|
||||
- `eventbroker_queue_size` - Current queue size (async mode)
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Async Mode**: For better performance, use async mode in production
|
||||
2. **Disable Read Events**: Read events can be high volume; disable if not needed
|
||||
3. **Pattern Matching**: Use specific patterns to avoid processing unnecessary events
|
||||
4. **Error Handling**: Always handle errors in event handlers; they won't fail the original operation
|
||||
5. **Idempotency**: Make handlers idempotent as events may be retried
|
||||
6. **Payload Size**: Keep payloads reasonable; avoid large objects
|
||||
7. **Monitoring**: Monitor metrics to detect issues early
|
||||
|
||||
## Examples
|
||||
|
||||
See `example_usage.go` for comprehensive examples including:
|
||||
- Basic event publishing and subscription
|
||||
- Hook integration
|
||||
- Error handling
|
||||
- Configuration
|
||||
- Pattern matching
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────┐
|
||||
│ Application │
|
||||
└────────┬────────┘
|
||||
│
|
||||
├─ Publish Events
|
||||
│
|
||||
┌────────▼────────┐ ┌──────────────┐
|
||||
│ Event Broker │◄────►│ Subscribers │
|
||||
└────────┬────────┘ └──────────────┘
|
||||
│
|
||||
├─ Store Events
|
||||
│
|
||||
┌────────▼────────┐
|
||||
│ Provider │
|
||||
│ (Memory/Redis │
|
||||
│ /NATS/DB) │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
## Implemented Features
|
||||
|
||||
- [x] Memory Provider (in-process, single-instance)
|
||||
- [x] Redis Streams Provider (distributed, persistent)
|
||||
- [x] NATS JetStream Provider (distributed, high-performance)
|
||||
- [x] Database Provider with PostgreSQL NOTIFY (SQL-queryable, audit-friendly)
|
||||
- [x] Sync and Async processing modes
|
||||
- [x] Pattern-based subscriptions
|
||||
- [x] Hook integration for automatic CRUD events
|
||||
- [x] Retry policy with exponential backoff
|
||||
- [x] Graceful shutdown
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] Event replay functionality from specific timestamp
|
||||
- [ ] Dead letter queue for failed events
|
||||
- [ ] Event filtering at provider level for performance
|
||||
- [ ] Batch publishing support
|
||||
- [ ] Event compression for large payloads
|
||||
- [ ] Schema versioning and migration
|
||||
- [ ] Event streaming to external systems (Kafka, RabbitMQ)
|
||||
- [ ] Event aggregation and analytics
|
||||
453
pkg/eventbroker/broker.go
Normal file
453
pkg/eventbroker/broker.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Broker is the main interface for event publishing and subscription
|
||||
type Broker interface {
|
||||
// Publish publishes an event (mode-dependent: sync or async)
|
||||
Publish(ctx context.Context, event *Event) error
|
||||
|
||||
// PublishSync publishes an event synchronously (blocks until all handlers complete)
|
||||
PublishSync(ctx context.Context, event *Event) error
|
||||
|
||||
// PublishAsync publishes an event asynchronously (returns immediately)
|
||||
PublishAsync(ctx context.Context, event *Event) error
|
||||
|
||||
// Subscribe registers a handler for events matching the pattern
|
||||
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
|
||||
|
||||
// Unsubscribe removes a subscription
|
||||
Unsubscribe(id SubscriptionID) error
|
||||
|
||||
// Start starts the broker (begins processing events)
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop stops the broker gracefully (flushes pending events)
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// Stats returns broker statistics
|
||||
Stats(ctx context.Context) (*BrokerStats, error)
|
||||
|
||||
// InstanceID returns the instance ID of this broker
|
||||
InstanceID() string
|
||||
}
|
||||
|
||||
// ProcessingMode determines how events are processed
|
||||
type ProcessingMode string
|
||||
|
||||
const (
|
||||
ProcessingModeSync ProcessingMode = "sync"
|
||||
ProcessingModeAsync ProcessingMode = "async"
|
||||
)
|
||||
|
||||
// BrokerStats contains broker statistics
|
||||
type BrokerStats struct {
|
||||
InstanceID string `json:"instance_id"`
|
||||
Mode ProcessingMode `json:"mode"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
TotalPublished int64 `json:"total_published"`
|
||||
TotalProcessed int64 `json:"total_processed"`
|
||||
TotalFailed int64 `json:"total_failed"`
|
||||
ActiveSubscribers int `json:"active_subscribers"`
|
||||
QueueSize int `json:"queue_size,omitempty"` // For async mode
|
||||
ActiveWorkers int `json:"active_workers,omitempty"` // For async mode
|
||||
ProviderStats *ProviderStats `json:"provider_stats,omitempty"`
|
||||
AdditionalStats map[string]interface{} `json:"additional_stats,omitempty"`
|
||||
}
|
||||
|
||||
// EventBroker implements the Broker interface
|
||||
type EventBroker struct {
|
||||
provider Provider
|
||||
subscriptions *subscriptionManager
|
||||
mode ProcessingMode
|
||||
instanceID string
|
||||
retryPolicy *RetryPolicy
|
||||
|
||||
// Async mode fields (initialized in Phase 4)
|
||||
workerPool *workerPool
|
||||
|
||||
// Runtime state
|
||||
isRunning atomic.Bool
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Statistics
|
||||
statsPublished atomic.Int64
|
||||
statsProcessed atomic.Int64
|
||||
statsFailed atomic.Int64
|
||||
}
|
||||
|
||||
// RetryPolicy defines how failed events should be retried
|
||||
type RetryPolicy struct {
|
||||
MaxRetries int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
BackoffFactor float64
|
||||
}
|
||||
|
||||
// DefaultRetryPolicy returns a sensible default retry policy
|
||||
func DefaultRetryPolicy() *RetryPolicy {
|
||||
return &RetryPolicy{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
}
|
||||
}
|
||||
|
||||
// Options for creating a new broker
|
||||
type Options struct {
|
||||
Provider Provider
|
||||
Mode ProcessingMode
|
||||
WorkerCount int // For async mode
|
||||
BufferSize int // For async mode
|
||||
RetryPolicy *RetryPolicy
|
||||
InstanceID string
|
||||
}
|
||||
|
||||
// NewBroker creates a new event broker with the given options
|
||||
func NewBroker(opts Options) (*EventBroker, error) {
|
||||
if opts.Provider == nil {
|
||||
return nil, fmt.Errorf("provider is required")
|
||||
}
|
||||
if opts.InstanceID == "" {
|
||||
return nil, fmt.Errorf("instance ID is required")
|
||||
}
|
||||
if opts.Mode == "" {
|
||||
opts.Mode = ProcessingModeAsync // Default to async
|
||||
}
|
||||
if opts.RetryPolicy == nil {
|
||||
opts.RetryPolicy = DefaultRetryPolicy()
|
||||
}
|
||||
|
||||
broker := &EventBroker{
|
||||
provider: opts.Provider,
|
||||
subscriptions: newSubscriptionManager(),
|
||||
mode: opts.Mode,
|
||||
instanceID: opts.InstanceID,
|
||||
retryPolicy: opts.RetryPolicy,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Worker pool will be initialized in Phase 4 for async mode
|
||||
if opts.Mode == ProcessingModeAsync {
|
||||
if opts.WorkerCount == 0 {
|
||||
opts.WorkerCount = 10 // Default
|
||||
}
|
||||
if opts.BufferSize == 0 {
|
||||
opts.BufferSize = 1000 // Default
|
||||
}
|
||||
broker.workerPool = newWorkerPool(opts.WorkerCount, opts.BufferSize, broker.processEvent)
|
||||
}
|
||||
|
||||
return broker, nil
|
||||
}
|
||||
|
||||
// Functional option pattern helpers
|
||||
func WithProvider(p Provider) func(*Options) {
|
||||
return func(o *Options) { o.Provider = p }
|
||||
}
|
||||
|
||||
func WithMode(m ProcessingMode) func(*Options) {
|
||||
return func(o *Options) { o.Mode = m }
|
||||
}
|
||||
|
||||
func WithWorkerCount(count int) func(*Options) {
|
||||
return func(o *Options) { o.WorkerCount = count }
|
||||
}
|
||||
|
||||
func WithBufferSize(size int) func(*Options) {
|
||||
return func(o *Options) { o.BufferSize = size }
|
||||
}
|
||||
|
||||
func WithRetryPolicy(policy *RetryPolicy) func(*Options) {
|
||||
return func(o *Options) { o.RetryPolicy = policy }
|
||||
}
|
||||
|
||||
func WithInstanceID(id string) func(*Options) {
|
||||
return func(o *Options) { o.InstanceID = id }
|
||||
}
|
||||
|
||||
// Start starts the broker
|
||||
func (b *EventBroker) Start(ctx context.Context) error {
|
||||
if b.isRunning.Load() {
|
||||
return fmt.Errorf("broker already running")
|
||||
}
|
||||
|
||||
b.isRunning.Store(true)
|
||||
|
||||
// Start worker pool for async mode
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
b.workerPool.Start()
|
||||
}
|
||||
|
||||
logger.Info("Event broker started (mode: %s, instance: %s)", b.mode, b.instanceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the broker gracefully
|
||||
func (b *EventBroker) Stop(ctx context.Context) error {
|
||||
var stopErr error
|
||||
|
||||
b.stopOnce.Do(func() {
|
||||
logger.Info("Stopping event broker...")
|
||||
|
||||
// Mark as not running
|
||||
b.isRunning.Store(false)
|
||||
|
||||
// Close the stop channel
|
||||
close(b.stopCh)
|
||||
|
||||
// Stop worker pool for async mode
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
if err := b.workerPool.Stop(ctx); err != nil {
|
||||
logger.Error("Error stopping worker pool: %v", err)
|
||||
stopErr = err
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
b.wg.Wait()
|
||||
|
||||
// Close provider
|
||||
if err := b.provider.Close(); err != nil {
|
||||
logger.Error("Error closing provider: %v", err)
|
||||
if stopErr == nil {
|
||||
stopErr = err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Event broker stopped")
|
||||
})
|
||||
|
||||
return stopErr
|
||||
}
|
||||
|
||||
// Publish publishes an event based on the broker's mode
|
||||
func (b *EventBroker) Publish(ctx context.Context, event *Event) error {
|
||||
if b.mode == ProcessingModeSync {
|
||||
return b.PublishSync(ctx, event)
|
||||
}
|
||||
return b.PublishAsync(ctx, event)
|
||||
}
|
||||
|
||||
// PublishSync publishes an event synchronously
|
||||
func (b *EventBroker) PublishSync(ctx context.Context, event *Event) error {
|
||||
if !b.isRunning.Load() {
|
||||
return fmt.Errorf("broker is not running")
|
||||
}
|
||||
|
||||
// Validate event
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid event: %w", err)
|
||||
}
|
||||
|
||||
// Store event in provider
|
||||
if err := b.provider.Publish(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to publish event: %w", err)
|
||||
}
|
||||
|
||||
b.statsPublished.Add(1)
|
||||
|
||||
// Record metrics
|
||||
recordEventPublished(event)
|
||||
|
||||
// Process event synchronously
|
||||
if err := b.processEvent(ctx, event); err != nil {
|
||||
logger.Error("Failed to process event %s: %v", event.ID, err)
|
||||
b.statsFailed.Add(1)
|
||||
return err
|
||||
}
|
||||
|
||||
b.statsProcessed.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishAsync publishes an event asynchronously
|
||||
func (b *EventBroker) PublishAsync(ctx context.Context, event *Event) error {
|
||||
if !b.isRunning.Load() {
|
||||
return fmt.Errorf("broker is not running")
|
||||
}
|
||||
|
||||
// Validate event
|
||||
if err := event.Validate(); err != nil {
|
||||
return fmt.Errorf("invalid event: %w", err)
|
||||
}
|
||||
|
||||
// Store event in provider
|
||||
if err := b.provider.Publish(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to publish event: %w", err)
|
||||
}
|
||||
|
||||
b.statsPublished.Add(1)
|
||||
|
||||
// Record metrics
|
||||
recordEventPublished(event)
|
||||
|
||||
// Queue for async processing
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
// Update queue size metrics
|
||||
updateQueueSize(int64(b.workerPool.QueueSize()))
|
||||
return b.workerPool.Submit(ctx, event)
|
||||
}
|
||||
|
||||
// Fallback to sync if async not configured
|
||||
return b.processEvent(ctx, event)
|
||||
}
|
||||
|
||||
// Subscribe adds a subscription for events matching the pattern
|
||||
func (b *EventBroker) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||
return b.subscriptions.Subscribe(pattern, handler)
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription
|
||||
func (b *EventBroker) Unsubscribe(id SubscriptionID) error {
|
||||
return b.subscriptions.Unsubscribe(id)
|
||||
}
|
||||
|
||||
// processEvent processes an event by calling all matching handlers
|
||||
func (b *EventBroker) processEvent(ctx context.Context, event *Event) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// Get all handlers matching this event type
|
||||
handlers := b.subscriptions.GetMatching(event.Type)
|
||||
|
||||
if len(handlers) == 0 {
|
||||
logger.Debug("No handlers for event type: %s", event.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Processing event %s with %d handler(s)", event.ID, len(handlers))
|
||||
|
||||
// Mark event as processing
|
||||
event.MarkProcessing()
|
||||
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusProcessing, ""); err != nil {
|
||||
logger.Warn("Failed to update event status: %v", err)
|
||||
}
|
||||
|
||||
// Execute all handlers
|
||||
var lastErr error
|
||||
for i, handler := range handlers {
|
||||
if err := b.executeHandlerWithRetry(ctx, handler, event); err != nil {
|
||||
logger.Error("Handler %d failed for event %s: %v", i+1, event.ID, err)
|
||||
lastErr = err
|
||||
// Continue processing other handlers
|
||||
}
|
||||
}
|
||||
|
||||
// Update final status
|
||||
if lastErr != nil {
|
||||
event.MarkFailed(lastErr)
|
||||
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusFailed, lastErr.Error()); err != nil {
|
||||
logger.Warn("Failed to update event status: %v", err)
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
recordEventProcessed(event, time.Since(startTime))
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
event.MarkCompleted()
|
||||
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusCompleted, ""); err != nil {
|
||||
logger.Warn("Failed to update event status: %v", err)
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
recordEventProcessed(event, time.Since(startTime))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeHandlerWithRetry executes a handler with retry logic
|
||||
func (b *EventBroker) executeHandlerWithRetry(ctx context.Context, handler EventHandler, event *Event) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= b.retryPolicy.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Calculate backoff delay
|
||||
delay := b.calculateBackoff(attempt)
|
||||
logger.Debug("Retrying event %s (attempt %d/%d) after %v",
|
||||
event.ID, attempt, b.retryPolicy.MaxRetries, delay)
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
event.IncrementRetry()
|
||||
}
|
||||
|
||||
// Execute handler
|
||||
if err := handler.Handle(ctx, event); err != nil {
|
||||
lastErr = err
|
||||
logger.Warn("Handler failed for event %s (attempt %d): %v", event.ID, attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Success
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("handler failed after %d attempts: %w", b.retryPolicy.MaxRetries+1, lastErr)
|
||||
}
|
||||
|
||||
// calculateBackoff calculates the backoff delay for a retry attempt
|
||||
func (b *EventBroker) calculateBackoff(attempt int) time.Duration {
|
||||
delay := float64(b.retryPolicy.InitialDelay) * pow(b.retryPolicy.BackoffFactor, float64(attempt-1))
|
||||
if delay > float64(b.retryPolicy.MaxDelay) {
|
||||
delay = float64(b.retryPolicy.MaxDelay)
|
||||
}
|
||||
return time.Duration(delay)
|
||||
}
|
||||
|
||||
// pow is a simple integer power function
|
||||
func pow(base float64, exp float64) float64 {
|
||||
result := 1.0
|
||||
for i := 0.0; i < exp; i++ {
|
||||
result *= base
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Stats returns broker statistics
|
||||
func (b *EventBroker) Stats(ctx context.Context) (*BrokerStats, error) {
|
||||
providerStats, err := b.provider.Stats(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to get provider stats: %v", err)
|
||||
}
|
||||
|
||||
stats := &BrokerStats{
|
||||
InstanceID: b.instanceID,
|
||||
Mode: b.mode,
|
||||
IsRunning: b.isRunning.Load(),
|
||||
TotalPublished: b.statsPublished.Load(),
|
||||
TotalProcessed: b.statsProcessed.Load(),
|
||||
TotalFailed: b.statsFailed.Load(),
|
||||
ActiveSubscribers: b.subscriptions.Count(),
|
||||
ProviderStats: providerStats,
|
||||
}
|
||||
|
||||
// Add async-specific stats
|
||||
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||
stats.QueueSize = b.workerPool.QueueSize()
|
||||
stats.ActiveWorkers = b.workerPool.ActiveWorkers()
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// InstanceID returns the instance ID
|
||||
func (b *EventBroker) InstanceID() string {
|
||||
return b.instanceID
|
||||
}
|
||||
524
pkg/eventbroker/broker_test.go
Normal file
524
pkg/eventbroker/broker_test.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewBroker(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 1000,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts Options
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid options",
|
||||
opts: Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider",
|
||||
opts: Options{
|
||||
InstanceID: "test-instance",
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "missing instance ID",
|
||||
opts: Options{
|
||||
Provider: provider,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "async mode with defaults",
|
||||
opts: Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
broker, err := NewBroker(tt.opts)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("NewBroker() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
if err == nil && broker == nil {
|
||||
t.Error("Expected non-nil broker")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerStartStop(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, err := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create broker: %v", err)
|
||||
}
|
||||
|
||||
// Test Start
|
||||
if err := broker.Start(context.Background()); err != nil {
|
||||
t.Fatalf("Failed to start broker: %v", err)
|
||||
}
|
||||
|
||||
// Test double start (should fail)
|
||||
if err := broker.Start(context.Background()); err == nil {
|
||||
t.Error("Expected error on double start")
|
||||
}
|
||||
|
||||
// Test Stop
|
||||
if err := broker.Stop(context.Background()); err != nil {
|
||||
t.Fatalf("Failed to stop broker: %v", err)
|
||||
}
|
||||
|
||||
// Test double stop (should not fail)
|
||||
if err := broker.Stop(context.Background()); err != nil {
|
||||
t.Error("Double stop should not fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerPublishSync(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe to events
|
||||
called := false
|
||||
var receivedEvent *Event
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
receivedEvent = event
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
err := broker.PublishSync(context.Background(), event)
|
||||
if err != nil {
|
||||
t.Fatalf("PublishSync failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify handler was called
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
if receivedEvent == nil || receivedEvent.ID != event.ID {
|
||||
t.Error("Expected to receive the published event")
|
||||
}
|
||||
|
||||
// Verify event status
|
||||
if event.Status != EventStatusCompleted {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerPublishAsync(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 2,
|
||||
BufferSize: 10,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe to events
|
||||
var callCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish multiple events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
if err := broker.PublishAsync(context.Background(), event); err != nil {
|
||||
t.Fatalf("PublishAsync failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for events to be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if callCount.Load() != 5 {
|
||||
t.Errorf("Expected 5 handler calls, got %d", callCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerPublishBeforeStart(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
err := broker.Publish(context.Background(), event)
|
||||
if err == nil {
|
||||
t.Error("Expected error when publishing before start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerHandlerError(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
RetryPolicy: &RetryPolicy{
|
||||
MaxRetries: 2,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
},
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe with failing handler
|
||||
var callCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
callCount.Add(1)
|
||||
return errors.New("handler error")
|
||||
}))
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
err := broker.PublishSync(context.Background(), event)
|
||||
|
||||
// Should fail after retries
|
||||
if err == nil {
|
||||
t.Error("Expected error from handler")
|
||||
}
|
||||
|
||||
// Should have been called MaxRetries+1 times (initial + retries)
|
||||
if callCount.Load() != 3 {
|
||||
t.Errorf("Expected 3 calls (1 initial + 2 retries), got %d", callCount.Load())
|
||||
}
|
||||
|
||||
// Event should be marked as failed
|
||||
if event.Status != EventStatusFailed {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerMultipleHandlers(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe multiple handlers
|
||||
var called1, called2, called3 bool
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called1 = true
|
||||
return nil
|
||||
}))
|
||||
broker.Subscribe("test.event", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called2 = true
|
||||
return nil
|
||||
}))
|
||||
broker.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called3 = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishSync(context.Background(), event)
|
||||
|
||||
// All handlers should be called
|
||||
if !called1 || !called2 || !called3 {
|
||||
t.Error("Expected all handlers to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerUnsubscribe(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe
|
||||
called := false
|
||||
id, _ := broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Unsubscribe
|
||||
if err := broker.Unsubscribe(id); err != nil {
|
||||
t.Fatalf("Unsubscribe failed: %v", err)
|
||||
}
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishSync(context.Background(), event)
|
||||
|
||||
// Handler should not be called
|
||||
if called {
|
||||
t.Error("Expected handler not to be called after unsubscribe")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerStats(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeSync,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
// Subscribe
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish events
|
||||
for i := 0; i < 3; i++ {
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishSync(context.Background(), event)
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats, err := broker.Stats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Stats failed: %v", err)
|
||||
}
|
||||
|
||||
if stats.InstanceID != "test-instance" {
|
||||
t.Errorf("Expected instance ID 'test-instance', got %s", stats.InstanceID)
|
||||
}
|
||||
if stats.TotalPublished != 3 {
|
||||
t.Errorf("Expected 3 published events, got %d", stats.TotalPublished)
|
||||
}
|
||||
if stats.TotalProcessed != 3 {
|
||||
t.Errorf("Expected 3 processed events, got %d", stats.TotalProcessed)
|
||||
}
|
||||
if stats.ActiveSubscribers != 1 {
|
||||
t.Errorf("Expected 1 active subscriber, got %d", stats.ActiveSubscribers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerInstanceID(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "my-instance",
|
||||
})
|
||||
|
||||
if broker.InstanceID() != "my-instance" {
|
||||
t.Errorf("Expected instance ID 'my-instance', got %s", broker.InstanceID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerConcurrentPublish(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 5,
|
||||
BufferSize: 100,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
var callCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish concurrently
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishAsync(context.Background(), event)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(200 * time.Millisecond) // Wait for async processing
|
||||
|
||||
if callCount.Load() != 50 {
|
||||
t.Errorf("Expected 50 handler calls, got %d", callCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerGracefulShutdown(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 2,
|
||||
BufferSize: 10,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
|
||||
var processedCount atomic.Int32
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||
processedCount.Add(1)
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Publish events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.PublishAsync(context.Background(), event)
|
||||
}
|
||||
|
||||
// Stop broker (should wait for events to be processed)
|
||||
if err := broker.Stop(context.Background()); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
|
||||
// All events should be processed
|
||||
if processedCount.Load() != 5 {
|
||||
t.Errorf("Expected 5 processed events, got %d", processedCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerDefaultRetryPolicy(t *testing.T) {
|
||||
policy := DefaultRetryPolicy()
|
||||
|
||||
if policy.MaxRetries != 3 {
|
||||
t.Errorf("Expected MaxRetries 3, got %d", policy.MaxRetries)
|
||||
}
|
||||
if policy.InitialDelay != 1*time.Second {
|
||||
t.Errorf("Expected InitialDelay 1s, got %v", policy.InitialDelay)
|
||||
}
|
||||
if policy.BackoffFactor != 2.0 {
|
||||
t.Errorf("Expected BackoffFactor 2.0, got %f", policy.BackoffFactor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerProcessingModes(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode ProcessingMode
|
||||
}{
|
||||
{"sync mode", ProcessingModeSync},
|
||||
{"async mode", ProcessingModeAsync},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
broker, _ := NewBroker(Options{
|
||||
Provider: provider,
|
||||
InstanceID: "test-instance",
|
||||
Mode: tt.mode,
|
||||
})
|
||||
broker.Start(context.Background())
|
||||
defer broker.Stop(context.Background())
|
||||
|
||||
called := false
|
||||
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
event := NewEvent(EventSourceSystem, "test.event")
|
||||
event.InstanceID = "test-instance"
|
||||
broker.Publish(context.Background(), event)
|
||||
|
||||
if tt.mode == ProcessingModeAsync {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
175
pkg/eventbroker/event.go
Normal file
175
pkg/eventbroker/event.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// EventSource represents where an event originated from
|
||||
type EventSource string
|
||||
|
||||
const (
|
||||
EventSourceDatabase EventSource = "database"
|
||||
EventSourceWebSocket EventSource = "websocket"
|
||||
EventSourceFrontend EventSource = "frontend"
|
||||
EventSourceSystem EventSource = "system"
|
||||
EventSourceInternal EventSource = "internal"
|
||||
)
|
||||
|
||||
// EventStatus represents the current state of an event
|
||||
type EventStatus string
|
||||
|
||||
const (
|
||||
EventStatusPending EventStatus = "pending"
|
||||
EventStatusProcessing EventStatus = "processing"
|
||||
EventStatusCompleted EventStatus = "completed"
|
||||
EventStatusFailed EventStatus = "failed"
|
||||
)
|
||||
|
||||
// Event represents a single event in the system with complete metadata
|
||||
type Event struct {
|
||||
// Identification
|
||||
ID string `json:"id" db:"id"`
|
||||
|
||||
// Source & Classification
|
||||
Source EventSource `json:"source" db:"source"`
|
||||
Type string `json:"type" db:"type"` // Pattern: schema.entity.operation
|
||||
|
||||
// Status Tracking
|
||||
Status EventStatus `json:"status" db:"status"`
|
||||
RetryCount int `json:"retry_count" db:"retry_count"`
|
||||
Error string `json:"error,omitempty" db:"error"`
|
||||
|
||||
// Payload
|
||||
Payload json.RawMessage `json:"payload" db:"payload"`
|
||||
|
||||
// Context Information
|
||||
UserID int `json:"user_id" db:"user_id"`
|
||||
SessionID string `json:"session_id" db:"session_id"`
|
||||
InstanceID string `json:"instance_id" db:"instance_id"`
|
||||
|
||||
// Database Context
|
||||
Schema string `json:"schema" db:"schema"`
|
||||
Entity string `json:"entity" db:"entity"`
|
||||
Operation string `json:"operation" db:"operation"` // create, update, delete, read
|
||||
|
||||
// Timestamps
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
ProcessedAt *time.Time `json:"processed_at,omitempty" db:"processed_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
|
||||
|
||||
// Extensibility
|
||||
Metadata map[string]interface{} `json:"metadata" db:"metadata"`
|
||||
}
|
||||
|
||||
// NewEvent creates a new event with defaults
|
||||
func NewEvent(source EventSource, eventType string) *Event {
|
||||
return &Event{
|
||||
ID: uuid.New().String(),
|
||||
Source: source,
|
||||
Type: eventType,
|
||||
Status: EventStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
Metadata: make(map[string]interface{}),
|
||||
RetryCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// EventType generates a type string from schema, entity, and operation
|
||||
// Pattern: schema.entity.operation (e.g., "public.users.create")
|
||||
func EventType(schema, entity, operation string) string {
|
||||
return fmt.Sprintf("%s.%s.%s", schema, entity, operation)
|
||||
}
|
||||
|
||||
// MarkProcessing marks the event as being processed
|
||||
func (e *Event) MarkProcessing() {
|
||||
e.Status = EventStatusProcessing
|
||||
now := time.Now()
|
||||
e.ProcessedAt = &now
|
||||
}
|
||||
|
||||
// MarkCompleted marks the event as successfully completed
|
||||
func (e *Event) MarkCompleted() {
|
||||
e.Status = EventStatusCompleted
|
||||
now := time.Now()
|
||||
e.CompletedAt = &now
|
||||
}
|
||||
|
||||
// MarkFailed marks the event as failed with an error message
|
||||
func (e *Event) MarkFailed(err error) {
|
||||
e.Status = EventStatusFailed
|
||||
e.Error = err.Error()
|
||||
now := time.Now()
|
||||
e.CompletedAt = &now
|
||||
}
|
||||
|
||||
// IncrementRetry increments the retry counter
|
||||
func (e *Event) IncrementRetry() {
|
||||
e.RetryCount++
|
||||
}
|
||||
|
||||
// SetPayload sets the event payload from any value by marshaling to JSON
|
||||
func (e *Event) SetPayload(v interface{}) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
e.Payload = data
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPayload unmarshals the payload into the provided value
|
||||
func (e *Event) GetPayload(v interface{}) error {
|
||||
if len(e.Payload) == 0 {
|
||||
return fmt.Errorf("payload is empty")
|
||||
}
|
||||
if err := json.Unmarshal(e.Payload, v); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal payload: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of the event
|
||||
func (e *Event) Clone() *Event {
|
||||
clone := *e
|
||||
|
||||
// Deep copy metadata
|
||||
if e.Metadata != nil {
|
||||
clone.Metadata = make(map[string]interface{})
|
||||
for k, v := range e.Metadata {
|
||||
clone.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy timestamps
|
||||
if e.ProcessedAt != nil {
|
||||
t := *e.ProcessedAt
|
||||
clone.ProcessedAt = &t
|
||||
}
|
||||
if e.CompletedAt != nil {
|
||||
t := *e.CompletedAt
|
||||
clone.CompletedAt = &t
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// Validate performs basic validation on the event
|
||||
func (e *Event) Validate() error {
|
||||
if e.ID == "" {
|
||||
return fmt.Errorf("event ID is required")
|
||||
}
|
||||
if e.Source == "" {
|
||||
return fmt.Errorf("event source is required")
|
||||
}
|
||||
if e.Type == "" {
|
||||
return fmt.Errorf("event type is required")
|
||||
}
|
||||
if e.InstanceID == "" {
|
||||
return fmt.Errorf("instance ID is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
314
pkg/eventbroker/event_test.go
Normal file
314
pkg/eventbroker/event_test.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewEvent(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
if event.ID == "" {
|
||||
t.Error("Expected event ID to be generated")
|
||||
}
|
||||
if event.Source != EventSourceDatabase {
|
||||
t.Errorf("Expected source %s, got %s", EventSourceDatabase, event.Source)
|
||||
}
|
||||
if event.Type != "public.users.create" {
|
||||
t.Errorf("Expected type 'public.users.create', got %s", event.Type)
|
||||
}
|
||||
if event.Status != EventStatusPending {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusPending, event.Status)
|
||||
}
|
||||
if event.CreatedAt.IsZero() {
|
||||
t.Error("Expected CreatedAt to be set")
|
||||
}
|
||||
if event.Metadata == nil {
|
||||
t.Error("Expected Metadata to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventType(t *testing.T) {
|
||||
tests := []struct {
|
||||
schema string
|
||||
entity string
|
||||
operation string
|
||||
expected string
|
||||
}{
|
||||
{"public", "users", "create", "public.users.create"},
|
||||
{"admin", "roles", "update", "admin.roles.update"},
|
||||
{"", "system", "start", ".system.start"}, // Empty schema results in leading dot
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := EventType(tt.schema, tt.entity, tt.operation)
|
||||
if result != tt.expected {
|
||||
t.Errorf("EventType(%q, %q, %q) = %q, expected %q",
|
||||
tt.schema, tt.entity, tt.operation, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
event *Event
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "valid event",
|
||||
event: func() *Event {
|
||||
e := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
e.InstanceID = "test-instance"
|
||||
return e
|
||||
}(),
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "missing ID",
|
||||
event: &Event{
|
||||
Source: EventSourceDatabase,
|
||||
Type: "public.users.create",
|
||||
Status: EventStatusPending,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "missing source",
|
||||
event: &Event{
|
||||
ID: "test-id",
|
||||
Type: "public.users.create",
|
||||
Status: EventStatusPending,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "missing type",
|
||||
event: &Event{
|
||||
ID: "test-id",
|
||||
Source: EventSourceDatabase,
|
||||
Status: EventStatusPending,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.event.Validate()
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("Event.Validate() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetPayload(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"id": 1,
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
}
|
||||
|
||||
err := event.SetPayload(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("SetPayload failed: %v", err)
|
||||
}
|
||||
|
||||
if event.Payload == nil {
|
||||
t.Fatal("Expected payload to be set")
|
||||
}
|
||||
|
||||
// Verify payload can be unmarshaled
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(event.Payload, &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal payload: %v", err)
|
||||
}
|
||||
|
||||
if result["name"] != "John Doe" {
|
||||
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventGetPayload(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"id": float64(1), // JSON unmarshals numbers as float64
|
||||
"name": "John Doe",
|
||||
}
|
||||
|
||||
if err := event.SetPayload(payload); err != nil {
|
||||
t.Fatalf("SetPayload failed: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := event.GetPayload(&result); err != nil {
|
||||
t.Fatalf("GetPayload failed: %v", err)
|
||||
}
|
||||
|
||||
if result["name"] != "John Doe" {
|
||||
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMarkProcessing(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.MarkProcessing()
|
||||
|
||||
if event.Status != EventStatusProcessing {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusProcessing, event.Status)
|
||||
}
|
||||
if event.ProcessedAt == nil {
|
||||
t.Error("Expected ProcessedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMarkCompleted(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.MarkCompleted()
|
||||
|
||||
if event.Status != EventStatusCompleted {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
|
||||
}
|
||||
if event.CompletedAt == nil {
|
||||
t.Error("Expected CompletedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMarkFailed(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
testErr := errors.New("test error")
|
||||
event.MarkFailed(testErr)
|
||||
|
||||
if event.Status != EventStatusFailed {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
|
||||
}
|
||||
if event.Error != "test error" {
|
||||
t.Errorf("Expected error %q, got %q", "test error", event.Error)
|
||||
}
|
||||
if event.CompletedAt == nil {
|
||||
t.Error("Expected CompletedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventIncrementRetry(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
initialCount := event.RetryCount
|
||||
event.IncrementRetry()
|
||||
|
||||
if event.RetryCount != initialCount+1 {
|
||||
t.Errorf("Expected retry count %d, got %d", initialCount+1, event.RetryCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventJSONMarshaling(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.UserID = 123
|
||||
event.SessionID = "session-123"
|
||||
event.InstanceID = "instance-1"
|
||||
event.Schema = "public"
|
||||
event.Entity = "users"
|
||||
event.Operation = "create"
|
||||
event.SetPayload(map[string]interface{}{"name": "Test"})
|
||||
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal event: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded Event
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal event: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.ID != event.ID {
|
||||
t.Errorf("Expected ID %s, got %s", event.ID, decoded.ID)
|
||||
}
|
||||
if decoded.Source != event.Source {
|
||||
t.Errorf("Expected source %s, got %s", event.Source, decoded.Source)
|
||||
}
|
||||
if decoded.UserID != event.UserID {
|
||||
t.Errorf("Expected UserID %d, got %d", event.UserID, decoded.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventStatusString(t *testing.T) {
|
||||
statuses := []EventStatus{
|
||||
EventStatusPending,
|
||||
EventStatusProcessing,
|
||||
EventStatusCompleted,
|
||||
EventStatusFailed,
|
||||
}
|
||||
|
||||
for _, status := range statuses {
|
||||
if string(status) == "" {
|
||||
t.Errorf("EventStatus %v has empty string representation", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSourceString(t *testing.T) {
|
||||
sources := []EventSource{
|
||||
EventSourceDatabase,
|
||||
EventSourceWebSocket,
|
||||
EventSourceFrontend,
|
||||
EventSourceSystem,
|
||||
EventSourceInternal,
|
||||
}
|
||||
|
||||
for _, source := range sources {
|
||||
if string(source) == "" {
|
||||
t.Errorf("EventSource %v has empty string representation", source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMetadata(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
|
||||
// Test setting metadata
|
||||
event.Metadata["key1"] = "value1"
|
||||
event.Metadata["key2"] = 123
|
||||
|
||||
if event.Metadata["key1"] != "value1" {
|
||||
t.Errorf("Expected metadata key1 to be 'value1', got %v", event.Metadata["key1"])
|
||||
}
|
||||
if event.Metadata["key2"] != 123 {
|
||||
t.Errorf("Expected metadata key2 to be 123, got %v", event.Metadata["key2"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventTimestamps(t *testing.T) {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
createdAt := event.CreatedAt
|
||||
|
||||
// Wait a tiny bit to ensure timestamps differ
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
event.MarkProcessing()
|
||||
if event.ProcessedAt == nil {
|
||||
t.Fatal("ProcessedAt should be set")
|
||||
}
|
||||
if !event.ProcessedAt.After(createdAt) {
|
||||
t.Error("ProcessedAt should be after CreatedAt")
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
event.MarkCompleted()
|
||||
if event.CompletedAt == nil {
|
||||
t.Fatal("CompletedAt should be set")
|
||||
}
|
||||
if !event.CompletedAt.After(*event.ProcessedAt) {
|
||||
t.Error("CompletedAt should be after ProcessedAt")
|
||||
}
|
||||
}
|
||||
158
pkg/eventbroker/eventbroker.go
Normal file
158
pkg/eventbroker/eventbroker.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultBroker Broker
|
||||
brokerMu sync.RWMutex
|
||||
)
|
||||
|
||||
// Initialize initializes the global event broker from configuration
|
||||
func Initialize(cfg config.EventBrokerConfig) error {
|
||||
if !cfg.Enabled {
|
||||
logger.Info("Event broker is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create provider
|
||||
provider, err := NewProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create provider: %w", err)
|
||||
}
|
||||
|
||||
// Parse mode
|
||||
mode := ProcessingModeAsync
|
||||
if cfg.Mode == "sync" {
|
||||
mode = ProcessingModeSync
|
||||
}
|
||||
|
||||
// Convert retry policy
|
||||
retryPolicy := &RetryPolicy{
|
||||
MaxRetries: cfg.RetryPolicy.MaxRetries,
|
||||
InitialDelay: cfg.RetryPolicy.InitialDelay,
|
||||
MaxDelay: cfg.RetryPolicy.MaxDelay,
|
||||
BackoffFactor: cfg.RetryPolicy.BackoffFactor,
|
||||
}
|
||||
if retryPolicy.MaxRetries == 0 {
|
||||
retryPolicy = DefaultRetryPolicy()
|
||||
}
|
||||
|
||||
// Create broker options
|
||||
opts := Options{
|
||||
Provider: provider,
|
||||
Mode: mode,
|
||||
WorkerCount: cfg.WorkerCount,
|
||||
BufferSize: cfg.BufferSize,
|
||||
RetryPolicy: retryPolicy,
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
}
|
||||
|
||||
// Create broker
|
||||
broker, err := NewBroker(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create broker: %w", err)
|
||||
}
|
||||
|
||||
// Start broker
|
||||
if err := broker.Start(context.Background()); err != nil {
|
||||
return fmt.Errorf("failed to start broker: %w", err)
|
||||
}
|
||||
|
||||
// Set as default
|
||||
SetDefaultBroker(broker)
|
||||
|
||||
logger.Info("Event broker initialized successfully (provider: %s, mode: %s, instance: %s)",
|
||||
cfg.Provider, cfg.Mode, opts.InstanceID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultBroker sets the default global broker
|
||||
func SetDefaultBroker(broker Broker) {
|
||||
brokerMu.Lock()
|
||||
defer brokerMu.Unlock()
|
||||
defaultBroker = broker
|
||||
}
|
||||
|
||||
// GetDefaultBroker returns the default global broker
|
||||
func GetDefaultBroker() Broker {
|
||||
brokerMu.RLock()
|
||||
defer brokerMu.RUnlock()
|
||||
return defaultBroker
|
||||
}
|
||||
|
||||
// IsInitialized returns true if the default broker is initialized
|
||||
func IsInitialized() bool {
|
||||
return GetDefaultBroker() != nil
|
||||
}
|
||||
|
||||
// Publish publishes an event using the default broker
|
||||
func Publish(ctx context.Context, event *Event) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Publish(ctx, event)
|
||||
}
|
||||
|
||||
// PublishSync publishes an event synchronously using the default broker
|
||||
func PublishSync(ctx context.Context, event *Event) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.PublishSync(ctx, event)
|
||||
}
|
||||
|
||||
// PublishAsync publishes an event asynchronously using the default broker
|
||||
func PublishAsync(ctx context.Context, event *Event) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.PublishAsync(ctx, event)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to events using the default broker
|
||||
func Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return "", fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Subscribe(pattern, handler)
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes from events using the default broker
|
||||
func Unsubscribe(id SubscriptionID) error {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Unsubscribe(id)
|
||||
}
|
||||
|
||||
// Stats returns statistics from the default broker
|
||||
func Stats(ctx context.Context) (*BrokerStats, error) {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return nil, fmt.Errorf("event broker not initialized")
|
||||
}
|
||||
return broker.Stats(ctx)
|
||||
}
|
||||
|
||||
// RegisterShutdown registers the broker's shutdown with a server manager
|
||||
// Call this from your application initialization code
|
||||
// Example: serverMgr.RegisterShutdownCallback(eventbroker.MakeShutdownCallback(broker))
|
||||
func MakeShutdownCallback(broker Broker) func(context.Context) error {
|
||||
return func(ctx context.Context) error {
|
||||
logger.Info("Shutting down event broker...")
|
||||
return broker.Stop(ctx)
|
||||
}
|
||||
}
|
||||
266
pkg/eventbroker/example_usage.go
Normal file
266
pkg/eventbroker/example_usage.go
Normal file
@@ -0,0 +1,266 @@
|
||||
// nolint
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Example demonstrates basic usage of the event broker
|
||||
func Example() {
|
||||
// 1. Create a memory provider
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "example-instance",
|
||||
MaxEvents: 1000,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
MaxAge: 1 * time.Hour,
|
||||
})
|
||||
|
||||
// 2. Create a broker
|
||||
broker, err := NewBroker(Options{
|
||||
Provider: provider,
|
||||
Mode: ProcessingModeAsync,
|
||||
WorkerCount: 5,
|
||||
BufferSize: 100,
|
||||
RetryPolicy: DefaultRetryPolicy(),
|
||||
InstanceID: "example-instance",
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to create broker: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Start the broker
|
||||
if err := broker.Start(context.Background()); err != nil {
|
||||
logger.Error("Failed to start broker: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err := broker.Stop(context.Background())
|
||||
if err != nil {
|
||||
logger.Error("Failed to stop broker: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 4. Subscribe to events
|
||||
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
logger.Info("User event: %s (operation: %s)", event.Type, event.Operation)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
logger.Info("Create event: %s.%s", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// 5. Publish events
|
||||
ctx := context.Background()
|
||||
|
||||
// Database event
|
||||
dbEvent := NewEvent(EventSourceDatabase, EventType("public", "users", "create"))
|
||||
dbEvent.InstanceID = "example-instance"
|
||||
dbEvent.UserID = 123
|
||||
dbEvent.SessionID = "session-456"
|
||||
dbEvent.Schema = "public"
|
||||
dbEvent.Entity = "users"
|
||||
dbEvent.Operation = "create"
|
||||
dbEvent.SetPayload(map[string]interface{}{
|
||||
"id": 123,
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
})
|
||||
|
||||
if err := broker.PublishAsync(ctx, dbEvent); err != nil {
|
||||
logger.Error("Failed to publish event: %v", err)
|
||||
}
|
||||
|
||||
// WebSocket event
|
||||
wsEvent := NewEvent(EventSourceWebSocket, "chat.message")
|
||||
wsEvent.InstanceID = "example-instance"
|
||||
wsEvent.UserID = 123
|
||||
wsEvent.SessionID = "session-456"
|
||||
wsEvent.SetPayload(map[string]interface{}{
|
||||
"room": "general",
|
||||
"message": "Hello, World!",
|
||||
})
|
||||
|
||||
if err := broker.PublishAsync(ctx, wsEvent); err != nil {
|
||||
logger.Error("Failed to publish event: %v", err)
|
||||
}
|
||||
|
||||
// 6. Get statistics
|
||||
time.Sleep(1 * time.Second) // Wait for processing
|
||||
stats, _ := broker.Stats(ctx)
|
||||
logger.Info("Broker stats: %d published, %d processed", stats.TotalPublished, stats.TotalProcessed)
|
||||
}
|
||||
|
||||
// ExampleWithHooks demonstrates integration with the hook system
|
||||
func ExampleWithHooks() {
|
||||
// This would typically be called in your main.go or initialization code
|
||||
// after setting up your restheadspec.Handler
|
||||
|
||||
// Pseudo-code (actual implementation would use real handler):
|
||||
/*
|
||||
broker := eventbroker.GetDefaultBroker()
|
||||
hookRegistry := handler.Hooks()
|
||||
|
||||
// Register CRUD hooks
|
||||
config := eventbroker.DefaultCRUDHookConfig()
|
||||
config.EnableRead = false // Disable read events for performance
|
||||
|
||||
if err := eventbroker.RegisterCRUDHooks(broker, hookRegistry, config); err != nil {
|
||||
logger.Error("Failed to register CRUD hooks: %v", err)
|
||||
}
|
||||
|
||||
// Now all CRUD operations will automatically publish events
|
||||
*/
|
||||
}
|
||||
|
||||
// ExampleSubscriptionPatterns demonstrates different subscription patterns
|
||||
func ExampleSubscriptionPatterns() {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Pattern 1: Subscribe to all events from a specific entity
|
||||
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("User event: %s\n", event.Operation)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Pattern 2: Subscribe to a specific operation across all entities
|
||||
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("Create event: %s.%s\n", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Pattern 3: Subscribe to all events in a schema
|
||||
broker.Subscribe("public.*.*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("Public schema event: %s.%s\n", event.Entity, event.Operation)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
// Pattern 4: Subscribe to everything (use with caution)
|
||||
broker.Subscribe("*", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
fmt.Printf("Any event: %s\n", event.Type)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
// ExampleErrorHandling demonstrates error handling in event handlers
|
||||
func ExampleErrorHandling() {
|
||||
broker := GetDefaultBroker()
|
||||
if broker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Handler that may fail
|
||||
broker.Subscribe("public.users.create", EventHandlerFunc(
|
||||
func(ctx context.Context, event *Event) error {
|
||||
// Simulate processing
|
||||
var user struct {
|
||||
ID int `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
if err := event.GetPayload(&user); err != nil {
|
||||
return fmt.Errorf("invalid payload: %w", err)
|
||||
}
|
||||
|
||||
// Validate
|
||||
if user.Email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
// Process (e.g., send email)
|
||||
logger.Info("Sending welcome email to %s", user.Email)
|
||||
|
||||
return nil
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
// ExampleConfiguration demonstrates initializing from configuration
|
||||
func ExampleConfiguration() {
|
||||
// This would typically be in your main.go
|
||||
|
||||
// Pseudo-code:
|
||||
/*
|
||||
// Load configuration
|
||||
cfgMgr := config.NewManager()
|
||||
if err := cfgMgr.Load(); err != nil {
|
||||
logger.Fatal("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := cfgMgr.GetConfig()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Initialize event broker
|
||||
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
|
||||
logger.Fatal("Failed to initialize event broker: %v", err)
|
||||
}
|
||||
|
||||
// Use the default broker
|
||||
eventbroker.Subscribe("*.*.create", eventbroker.EventHandlerFunc(
|
||||
func(ctx context.Context, event *eventbroker.Event) error {
|
||||
logger.Info("Created: %s.%s", event.Schema, event.Entity)
|
||||
return nil
|
||||
},
|
||||
))
|
||||
*/
|
||||
}
|
||||
|
||||
// ExampleYAMLConfiguration shows example YAML configuration
|
||||
const ExampleYAMLConfiguration = `
|
||||
event_broker:
|
||||
enabled: true
|
||||
provider: memory # memory, redis, nats, database
|
||||
mode: async # sync, async
|
||||
worker_count: 10
|
||||
buffer_size: 1000
|
||||
instance_id: "${HOSTNAME}"
|
||||
|
||||
# Memory provider is default, no additional config needed
|
||||
|
||||
# Redis provider (when provider: redis)
|
||||
redis:
|
||||
stream_name: "resolvespec:events"
|
||||
consumer_group: "resolvespec-workers"
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
|
||||
# NATS provider (when provider: nats)
|
||||
nats:
|
||||
url: "nats://localhost:4222"
|
||||
stream_name: "RESOLVESPEC_EVENTS"
|
||||
|
||||
# Database provider (when provider: database)
|
||||
database:
|
||||
table_name: "events"
|
||||
channel: "resolvespec_events"
|
||||
|
||||
# Retry policy
|
||||
retry_policy:
|
||||
max_retries: 3
|
||||
initial_delay: 1s
|
||||
max_delay: 30s
|
||||
backoff_factor: 2.0
|
||||
`
|
||||
74
pkg/eventbroker/factory.go
Normal file
74
pkg/eventbroker/factory.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
)
|
||||
|
||||
// NewProviderFromConfig creates a provider based on configuration
|
||||
func NewProviderFromConfig(cfg config.EventBrokerConfig) (Provider, error) {
|
||||
switch cfg.Provider {
|
||||
case "memory":
|
||||
cleanupInterval := 5 * time.Minute
|
||||
if cfg.Database.PollInterval > 0 {
|
||||
cleanupInterval = cfg.Database.PollInterval
|
||||
}
|
||||
|
||||
return NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
MaxEvents: 10000,
|
||||
CleanupInterval: cleanupInterval,
|
||||
}), nil
|
||||
|
||||
case "redis":
|
||||
return NewRedisProvider(RedisProviderConfig{
|
||||
Host: cfg.Redis.Host,
|
||||
Port: cfg.Redis.Port,
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
StreamName: cfg.Redis.StreamName,
|
||||
ConsumerGroup: cfg.Redis.ConsumerGroup,
|
||||
ConsumerName: getInstanceID(cfg.InstanceID),
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
MaxLen: cfg.Redis.MaxLen,
|
||||
})
|
||||
|
||||
case "nats":
|
||||
// NATS provider initialization
|
||||
// Note: Requires github.com/nats-io/nats.go dependency
|
||||
return NewNATSProvider(NATSProviderConfig{
|
||||
URL: cfg.NATS.URL,
|
||||
StreamName: cfg.NATS.StreamName,
|
||||
SubjectPrefix: "events",
|
||||
InstanceID: getInstanceID(cfg.InstanceID),
|
||||
MaxAge: cfg.NATS.MaxAge,
|
||||
Storage: cfg.NATS.Storage, // "file" or "memory"
|
||||
})
|
||||
|
||||
case "database":
|
||||
// Database provider requires a database connection
|
||||
// This should be provided externally
|
||||
return nil, fmt.Errorf("database provider requires a database connection to be configured separately")
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// getInstanceID returns the instance ID, defaulting to hostname if not specified
|
||||
func getInstanceID(configID string) string {
|
||||
if configID != "" {
|
||||
return configID
|
||||
}
|
||||
|
||||
// Try to get hostname
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
return hostname
|
||||
}
|
||||
|
||||
// Fallback to a default
|
||||
return "resolvespec-instance"
|
||||
}
|
||||
17
pkg/eventbroker/handler.go
Normal file
17
pkg/eventbroker/handler.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package eventbroker
|
||||
|
||||
import "context"
|
||||
|
||||
// EventHandler processes an event
|
||||
type EventHandler interface {
|
||||
Handle(ctx context.Context, event *Event) error
|
||||
}
|
||||
|
||||
// EventHandlerFunc is a function adapter for EventHandler
|
||||
// This allows using regular functions as event handlers
|
||||
type EventHandlerFunc func(ctx context.Context, event *Event) error
|
||||
|
||||
// Handle implements EventHandler
|
||||
func (f EventHandlerFunc) Handle(ctx context.Context, event *Event) error {
|
||||
return f(ctx, event)
|
||||
}
|
||||
137
pkg/eventbroker/hooks.go
Normal file
137
pkg/eventbroker/hooks.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// CRUDHookConfig configures which CRUD operations should trigger events
|
||||
type CRUDHookConfig struct {
|
||||
EnableCreate bool
|
||||
EnableRead bool
|
||||
EnableUpdate bool
|
||||
EnableDelete bool
|
||||
}
|
||||
|
||||
// DefaultCRUDHookConfig returns default configuration (all enabled)
|
||||
func DefaultCRUDHookConfig() *CRUDHookConfig {
|
||||
return &CRUDHookConfig{
|
||||
EnableCreate: true,
|
||||
EnableRead: false, // Typically disabled for performance
|
||||
EnableUpdate: true,
|
||||
EnableDelete: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterCRUDHooks registers event hooks for CRUD operations
|
||||
// This integrates with the restheadspec.HookRegistry to automatically
|
||||
// capture database events
|
||||
func RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry, config *CRUDHookConfig) error {
|
||||
if broker == nil {
|
||||
return fmt.Errorf("broker cannot be nil")
|
||||
}
|
||||
if hookRegistry == nil {
|
||||
return fmt.Errorf("hookRegistry cannot be nil")
|
||||
}
|
||||
if config == nil {
|
||||
config = DefaultCRUDHookConfig()
|
||||
}
|
||||
|
||||
// Create hook handler factory
|
||||
createHookHandler := func(operation string) restheadspec.HookFunc {
|
||||
return func(hookCtx *restheadspec.HookContext) error {
|
||||
// Get user context from Go context
|
||||
userCtx, ok := security.GetUserContext(hookCtx.Context)
|
||||
if !ok || userCtx == nil {
|
||||
logger.Debug("No user context found in hook")
|
||||
userCtx = &security.UserContext{} // Empty user context
|
||||
}
|
||||
|
||||
// Create event
|
||||
event := NewEvent(EventSourceDatabase, EventType(hookCtx.Schema, hookCtx.Entity, operation))
|
||||
event.InstanceID = broker.InstanceID()
|
||||
event.UserID = userCtx.UserID
|
||||
event.SessionID = userCtx.SessionID
|
||||
event.Schema = hookCtx.Schema
|
||||
event.Entity = hookCtx.Entity
|
||||
event.Operation = operation
|
||||
|
||||
// Set payload based on operation
|
||||
var payload interface{}
|
||||
switch operation {
|
||||
case "create":
|
||||
payload = hookCtx.Result
|
||||
case "read":
|
||||
payload = hookCtx.Result
|
||||
case "update":
|
||||
payload = map[string]interface{}{
|
||||
"id": hookCtx.ID,
|
||||
"data": hookCtx.Data,
|
||||
}
|
||||
case "delete":
|
||||
payload = map[string]interface{}{
|
||||
"id": hookCtx.ID,
|
||||
}
|
||||
}
|
||||
|
||||
if payload != nil {
|
||||
if err := event.SetPayload(payload); err != nil {
|
||||
logger.Error("Failed to set event payload: %v", err)
|
||||
payload = map[string]interface{}{"error": "failed to serialize payload"}
|
||||
event.Payload, _ = json.Marshal(payload)
|
||||
}
|
||||
}
|
||||
|
||||
// Add metadata
|
||||
if userCtx.UserName != "" {
|
||||
event.Metadata["user_name"] = userCtx.UserName
|
||||
}
|
||||
if userCtx.Email != "" {
|
||||
event.Metadata["user_email"] = userCtx.Email
|
||||
}
|
||||
if len(userCtx.Roles) > 0 {
|
||||
event.Metadata["user_roles"] = userCtx.Roles
|
||||
}
|
||||
event.Metadata["table_name"] = hookCtx.TableName
|
||||
|
||||
// Publish asynchronously to not block CRUD operation
|
||||
if err := broker.PublishAsync(hookCtx.Context, event); err != nil {
|
||||
logger.Error("Failed to publish %s event for %s.%s: %v",
|
||||
operation, hookCtx.Schema, hookCtx.Entity, err)
|
||||
// Don't fail the CRUD operation if event publishing fails
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Published %s event for %s.%s (ID: %s)",
|
||||
operation, hookCtx.Schema, hookCtx.Entity, event.ID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Register hooks based on configuration
|
||||
if config.EnableCreate {
|
||||
hookRegistry.Register(restheadspec.AfterCreate, createHookHandler("create"))
|
||||
logger.Info("Registered event hook for CREATE operations")
|
||||
}
|
||||
|
||||
if config.EnableRead {
|
||||
hookRegistry.Register(restheadspec.AfterRead, createHookHandler("read"))
|
||||
logger.Info("Registered event hook for READ operations")
|
||||
}
|
||||
|
||||
if config.EnableUpdate {
|
||||
hookRegistry.Register(restheadspec.AfterUpdate, createHookHandler("update"))
|
||||
logger.Info("Registered event hook for UPDATE operations")
|
||||
}
|
||||
|
||||
if config.EnableDelete {
|
||||
hookRegistry.Register(restheadspec.AfterDelete, createHookHandler("delete"))
|
||||
logger.Info("Registered event hook for DELETE operations")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
28
pkg/eventbroker/metrics.go
Normal file
28
pkg/eventbroker/metrics.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
// recordEventPublished records an event publication metric
|
||||
func recordEventPublished(event *Event) {
|
||||
if mp := metrics.GetProvider(); mp != nil {
|
||||
mp.RecordEventPublished(string(event.Source), event.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// recordEventProcessed records an event processing metric
|
||||
func recordEventProcessed(event *Event, duration time.Duration) {
|
||||
if mp := metrics.GetProvider(); mp != nil {
|
||||
mp.RecordEventProcessed(string(event.Source), event.Type, string(event.Status), duration)
|
||||
}
|
||||
}
|
||||
|
||||
// updateQueueSize updates the event queue size metric
|
||||
func updateQueueSize(size int64) {
|
||||
if mp := metrics.GetProvider(); mp != nil {
|
||||
mp.UpdateEventQueueSize(size)
|
||||
}
|
||||
}
|
||||
70
pkg/eventbroker/provider.go
Normal file
70
pkg/eventbroker/provider.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider defines the storage backend interface for events
|
||||
// Implementations: MemoryProvider, RedisProvider, NATSProvider, DatabaseProvider
|
||||
type Provider interface {
|
||||
// Store stores an event
|
||||
Store(ctx context.Context, event *Event) error
|
||||
|
||||
// Get retrieves an event by ID
|
||||
Get(ctx context.Context, id string) (*Event, error)
|
||||
|
||||
// List lists events with optional filters
|
||||
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error
|
||||
|
||||
// Delete deletes an event by ID
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
// Used for cross-instance pub/sub
|
||||
// The channel is closed when the context is canceled or an error occurs
|
||||
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
|
||||
|
||||
// Publish publishes an event to all subscribers (for distributed providers)
|
||||
// For in-memory provider, this is the same as Store
|
||||
// For Redis/NATS/Database, this triggers cross-instance delivery
|
||||
Publish(ctx context.Context, event *Event) error
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
Close() error
|
||||
|
||||
// Stats returns provider statistics
|
||||
Stats(ctx context.Context) (*ProviderStats, error)
|
||||
}
|
||||
|
||||
// EventFilter defines filter criteria for listing events
|
||||
type EventFilter struct {
|
||||
Source *EventSource
|
||||
Status *EventStatus
|
||||
UserID *int
|
||||
Schema string
|
||||
Entity string
|
||||
Operation string
|
||||
InstanceID string
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ProviderStats contains statistics about the provider
|
||||
type ProviderStats struct {
|
||||
ProviderType string `json:"provider_type"`
|
||||
TotalEvents int64 `json:"total_events"`
|
||||
PendingEvents int64 `json:"pending_events"`
|
||||
ProcessingEvents int64 `json:"processing_events"`
|
||||
CompletedEvents int64 `json:"completed_events"`
|
||||
FailedEvents int64 `json:"failed_events"`
|
||||
EventsPublished int64 `json:"events_published"`
|
||||
EventsConsumed int64 `json:"events_consumed"`
|
||||
ActiveSubscribers int `json:"active_subscribers"`
|
||||
ProviderSpecific map[string]interface{} `json:"provider_specific,omitempty"`
|
||||
}
|
||||
653
pkg/eventbroker/provider_database.go
Normal file
653
pkg/eventbroker/provider_database.go
Normal file
@@ -0,0 +1,653 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// DatabaseProvider implements Provider interface using SQL database
|
||||
// Features:
|
||||
// - Persistent event storage in database table
|
||||
// - Full SQL query support for event history
|
||||
// - PostgreSQL NOTIFY/LISTEN for real-time updates (optional)
|
||||
// - Polling-based consumption with configurable interval
|
||||
// - Good for audit trails and event replay
|
||||
type DatabaseProvider struct {
|
||||
db common.Database
|
||||
tableName string
|
||||
channel string // PostgreSQL NOTIFY channel name
|
||||
pollInterval time.Duration
|
||||
instanceID string
|
||||
useNotify bool // Whether to use PostgreSQL NOTIFY
|
||||
|
||||
// Subscriptions
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*dbSubscription
|
||||
|
||||
// Statistics
|
||||
stats DatabaseProviderStats
|
||||
|
||||
// Lifecycle
|
||||
stopPolling chan struct{}
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// DatabaseProviderStats contains statistics for the database provider
|
||||
type DatabaseProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
PollErrors atomic.Int64
|
||||
}
|
||||
|
||||
// dbSubscription represents a single database subscription
|
||||
type dbSubscription struct {
|
||||
pattern string
|
||||
ch chan *Event
|
||||
lastSeenID string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// DatabaseProviderConfig configures the database provider
|
||||
type DatabaseProviderConfig struct {
|
||||
DB common.Database
|
||||
TableName string
|
||||
Channel string // PostgreSQL NOTIFY channel (optional)
|
||||
PollInterval time.Duration
|
||||
InstanceID string
|
||||
UseNotify bool // Enable PostgreSQL NOTIFY/LISTEN
|
||||
}
|
||||
|
||||
// NewDatabaseProvider creates a new database event provider
|
||||
func NewDatabaseProvider(cfg DatabaseProviderConfig) (*DatabaseProvider, error) {
|
||||
// Apply defaults
|
||||
if cfg.TableName == "" {
|
||||
cfg.TableName = "events"
|
||||
}
|
||||
if cfg.Channel == "" {
|
||||
cfg.Channel = "resolvespec_events"
|
||||
}
|
||||
if cfg.PollInterval == 0 {
|
||||
cfg.PollInterval = 1 * time.Second
|
||||
}
|
||||
|
||||
dp := &DatabaseProvider{
|
||||
db: cfg.DB,
|
||||
tableName: cfg.TableName,
|
||||
channel: cfg.Channel,
|
||||
pollInterval: cfg.PollInterval,
|
||||
instanceID: cfg.InstanceID,
|
||||
useNotify: cfg.UseNotify,
|
||||
subscribers: make(map[string]*dbSubscription),
|
||||
stopPolling: make(chan struct{}),
|
||||
}
|
||||
|
||||
dp.isRunning.Store(true)
|
||||
|
||||
// Create table if it doesn't exist
|
||||
ctx := context.Background()
|
||||
if err := dp.createTable(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to create events table: %w", err)
|
||||
}
|
||||
|
||||
// Start polling goroutine for subscriptions
|
||||
dp.wg.Add(1)
|
||||
go dp.pollLoop()
|
||||
|
||||
logger.Info("Database provider initialized (table: %s, poll_interval: %v, notify: %v)",
|
||||
cfg.TableName, cfg.PollInterval, cfg.UseNotify)
|
||||
|
||||
return dp, nil
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (dp *DatabaseProvider) Store(ctx context.Context, event *Event) error {
|
||||
// Marshal metadata to JSON
|
||||
metadataJSON, err := json.Marshal(event.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
// Insert event
|
||||
query := fmt.Sprintf(`
|
||||
INSERT INTO %s (
|
||||
id, source, type, status, retry_count, error,
|
||||
payload, user_id, session_id, instance_id,
|
||||
schema, entity, operation,
|
||||
created_at, processed_at, completed_at, metadata
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6,
|
||||
$7, $8, $9, $10,
|
||||
$11, $12, $13,
|
||||
$14, $15, $16, $17
|
||||
)
|
||||
`, dp.tableName)
|
||||
|
||||
_, err = dp.db.Exec(ctx, query,
|
||||
event.ID, event.Source, event.Type, event.Status, event.RetryCount, event.Error,
|
||||
event.Payload, event.UserID, event.SessionID, event.InstanceID,
|
||||
event.Schema, event.Entity, event.Operation,
|
||||
event.CreatedAt, event.ProcessedAt, event.CompletedAt, metadataJSON,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert event: %w", err)
|
||||
}
|
||||
|
||||
dp.stats.TotalEvents.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
func (dp *DatabaseProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
event := &Event{}
|
||||
var metadataJSON []byte
|
||||
var processedAt, completedAt sql.NullTime
|
||||
|
||||
// Query into individual fields
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, source, type, status, retry_count, error,
|
||||
payload, user_id, session_id, instance_id,
|
||||
schema, entity, operation,
|
||||
created_at, processed_at, completed_at, metadata
|
||||
FROM %s
|
||||
WHERE id = $1
|
||||
`, dp.tableName)
|
||||
|
||||
var source, eventType, status, operation string
|
||||
|
||||
// Execute raw query
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query event: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
if err := rows.Scan(
|
||||
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
|
||||
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
|
||||
&event.Schema, &event.Entity, &operation,
|
||||
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan event: %w", err)
|
||||
}
|
||||
|
||||
// Set enum values
|
||||
event.Source = EventSource(source)
|
||||
event.Type = eventType
|
||||
event.Status = EventStatus(status)
|
||||
event.Operation = operation
|
||||
|
||||
// Handle nullable timestamps
|
||||
if processedAt.Valid {
|
||||
event.ProcessedAt = &processedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
event.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Unmarshal metadata
|
||||
if len(metadataJSON) > 0 {
|
||||
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
|
||||
logger.Warn("Failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
func (dp *DatabaseProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
query := fmt.Sprintf("SELECT id, source, type, status, retry_count, error, "+
|
||||
"payload, user_id, session_id, instance_id, "+
|
||||
"schema, entity, operation, "+
|
||||
"created_at, processed_at, completed_at, metadata "+
|
||||
"FROM %s WHERE 1=1", dp.tableName)
|
||||
|
||||
args := []interface{}{}
|
||||
argNum := 1
|
||||
|
||||
// Build WHERE clause
|
||||
if filter != nil {
|
||||
if filter.Source != nil {
|
||||
query += fmt.Sprintf(" AND source = $%d", argNum)
|
||||
args = append(args, string(*filter.Source))
|
||||
argNum++
|
||||
}
|
||||
if filter.Status != nil {
|
||||
query += fmt.Sprintf(" AND status = $%d", argNum)
|
||||
args = append(args, string(*filter.Status))
|
||||
argNum++
|
||||
}
|
||||
if filter.UserID != nil {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", argNum)
|
||||
args = append(args, *filter.UserID)
|
||||
argNum++
|
||||
}
|
||||
if filter.Schema != "" {
|
||||
query += fmt.Sprintf(" AND schema = $%d", argNum)
|
||||
args = append(args, filter.Schema)
|
||||
argNum++
|
||||
}
|
||||
if filter.Entity != "" {
|
||||
query += fmt.Sprintf(" AND entity = $%d", argNum)
|
||||
args = append(args, filter.Entity)
|
||||
argNum++
|
||||
}
|
||||
if filter.Operation != "" {
|
||||
query += fmt.Sprintf(" AND operation = $%d", argNum)
|
||||
args = append(args, filter.Operation)
|
||||
argNum++
|
||||
}
|
||||
if filter.InstanceID != "" {
|
||||
query += fmt.Sprintf(" AND instance_id = $%d", argNum)
|
||||
args = append(args, filter.InstanceID)
|
||||
argNum++
|
||||
}
|
||||
if filter.StartTime != nil {
|
||||
query += fmt.Sprintf(" AND created_at >= $%d", argNum)
|
||||
args = append(args, *filter.StartTime)
|
||||
argNum++
|
||||
}
|
||||
if filter.EndTime != nil {
|
||||
query += fmt.Sprintf(" AND created_at <= $%d", argNum)
|
||||
args = append(args, *filter.EndTime)
|
||||
argNum++
|
||||
}
|
||||
}
|
||||
|
||||
// Add ORDER BY
|
||||
query += " ORDER BY created_at DESC"
|
||||
|
||||
// Add LIMIT and OFFSET
|
||||
if filter != nil {
|
||||
if filter.Limit > 0 {
|
||||
query += fmt.Sprintf(" LIMIT $%d", argNum)
|
||||
args = append(args, filter.Limit)
|
||||
argNum++
|
||||
}
|
||||
if filter.Offset > 0 {
|
||||
query += fmt.Sprintf(" OFFSET $%d", argNum)
|
||||
args = append(args, filter.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*Event
|
||||
for rows.Next() {
|
||||
event := &Event{}
|
||||
var source, eventType, status, operation string
|
||||
var metadataJSON []byte
|
||||
var processedAt, completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
|
||||
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
|
||||
&event.Schema, &event.Entity, &operation,
|
||||
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to scan event: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set enum values
|
||||
event.Source = EventSource(source)
|
||||
event.Type = eventType
|
||||
event.Status = EventStatus(status)
|
||||
event.Operation = operation
|
||||
|
||||
// Handle nullable timestamps
|
||||
if processedAt.Valid {
|
||||
event.ProcessedAt = &processedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
event.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Unmarshal metadata
|
||||
if len(metadataJSON) > 0 {
|
||||
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
|
||||
logger.Warn("Failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, event)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
func (dp *DatabaseProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE %s
|
||||
SET status = $1, error = $2
|
||||
WHERE id = $3
|
||||
`, dp.tableName)
|
||||
|
||||
_, err := dp.db.Exec(ctx, query, string(status), errorMsg, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
func (dp *DatabaseProvider) Delete(ctx context.Context, id string) error {
|
||||
query := fmt.Sprintf("DELETE FROM %s WHERE id = $1", dp.tableName)
|
||||
|
||||
_, err := dp.db.Exec(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete event: %w", err)
|
||||
}
|
||||
|
||||
dp.stats.TotalEvents.Add(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
func (dp *DatabaseProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sub := &dbSubscription{
|
||||
pattern: pattern,
|
||||
ch: ch,
|
||||
lastSeenID: "",
|
||||
ctx: subCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
dp.mu.Lock()
|
||||
dp.subscribers[pattern] = sub
|
||||
dp.stats.ActiveSubscribers.Add(1)
|
||||
dp.mu.Unlock()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers
|
||||
func (dp *DatabaseProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := dp.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dp.stats.EventsPublished.Add(1)
|
||||
|
||||
// If using PostgreSQL NOTIFY, send notification
|
||||
if dp.useNotify {
|
||||
if err := dp.notify(ctx, event.ID); err != nil {
|
||||
logger.Warn("Failed to send NOTIFY: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (dp *DatabaseProvider) Close() error {
|
||||
if !dp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
dp.isRunning.Store(false)
|
||||
|
||||
// Cancel all subscriptions
|
||||
dp.mu.Lock()
|
||||
for _, sub := range dp.subscribers {
|
||||
sub.cancel()
|
||||
}
|
||||
dp.mu.Unlock()
|
||||
|
||||
// Stop polling
|
||||
close(dp.stopPolling)
|
||||
|
||||
// Wait for goroutines
|
||||
dp.wg.Wait()
|
||||
|
||||
logger.Info("Database provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (dp *DatabaseProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
// Get counts by status
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE status = 'pending') as pending,
|
||||
COUNT(*) FILTER (WHERE status = 'processing') as processing,
|
||||
COUNT(*) FILTER (WHERE status = 'completed') as completed,
|
||||
COUNT(*) FILTER (WHERE status = 'failed') as failed,
|
||||
COUNT(*) as total
|
||||
FROM %s
|
||||
`, dp.tableName)
|
||||
|
||||
var pending, processing, completed, failed, total int64
|
||||
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to get stats: %v", err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&pending, &processing, &completed, &failed, &total); err != nil {
|
||||
logger.Warn("Failed to scan stats: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ProviderStats{
|
||||
ProviderType: "database",
|
||||
TotalEvents: total,
|
||||
PendingEvents: pending,
|
||||
ProcessingEvents: processing,
|
||||
CompletedEvents: completed,
|
||||
FailedEvents: failed,
|
||||
EventsPublished: dp.stats.EventsPublished.Load(),
|
||||
EventsConsumed: dp.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(dp.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"table_name": dp.tableName,
|
||||
"poll_interval": dp.pollInterval.String(),
|
||||
"use_notify": dp.useNotify,
|
||||
"poll_errors": dp.stats.PollErrors.Load(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// pollLoop periodically polls for new events
|
||||
func (dp *DatabaseProvider) pollLoop() {
|
||||
defer dp.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(dp.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
dp.pollEvents()
|
||||
case <-dp.stopPolling:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pollEvents polls for new events and delivers to subscribers
|
||||
func (dp *DatabaseProvider) pollEvents() {
|
||||
dp.mu.RLock()
|
||||
subscribers := make([]*dbSubscription, 0, len(dp.subscribers))
|
||||
for _, sub := range dp.subscribers {
|
||||
subscribers = append(subscribers, sub)
|
||||
}
|
||||
dp.mu.RUnlock()
|
||||
|
||||
for _, sub := range subscribers {
|
||||
// Query for new events since last seen
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, source, type, status, retry_count, error,
|
||||
payload, user_id, session_id, instance_id,
|
||||
schema, entity, operation,
|
||||
created_at, processed_at, completed_at, metadata
|
||||
FROM %s
|
||||
WHERE id > $1
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 100
|
||||
`, dp.tableName)
|
||||
|
||||
lastSeenID := sub.lastSeenID
|
||||
if lastSeenID == "" {
|
||||
lastSeenID = "00000000-0000-0000-0000-000000000000"
|
||||
}
|
||||
|
||||
rows, err := dp.db.GetUnderlyingDB().(interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}).QueryContext(sub.ctx, query, lastSeenID)
|
||||
if err != nil {
|
||||
dp.stats.PollErrors.Add(1)
|
||||
logger.Warn("Failed to poll events: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
event := &Event{}
|
||||
var source, eventType, status, operation string
|
||||
var metadataJSON []byte
|
||||
var processedAt, completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&event.ID, &source, &eventType, &status, &event.RetryCount, &event.Error,
|
||||
&event.Payload, &event.UserID, &event.SessionID, &event.InstanceID,
|
||||
&event.Schema, &event.Entity, &operation,
|
||||
&event.CreatedAt, &processedAt, &completedAt, &metadataJSON,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to scan event: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set enum values
|
||||
event.Source = EventSource(source)
|
||||
event.Type = eventType
|
||||
event.Status = EventStatus(status)
|
||||
event.Operation = operation
|
||||
|
||||
// Handle nullable timestamps
|
||||
if processedAt.Valid {
|
||||
event.ProcessedAt = &processedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
event.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
// Unmarshal metadata
|
||||
if len(metadataJSON) > 0 {
|
||||
if err := json.Unmarshal(metadataJSON, &event.Metadata); err != nil {
|
||||
logger.Warn("Failed to unmarshal metadata: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if event matches pattern
|
||||
if matchPattern(sub.pattern, event.Type) {
|
||||
select {
|
||||
case sub.ch <- event:
|
||||
dp.stats.EventsConsumed.Add(1)
|
||||
sub.lastSeenID = event.ID
|
||||
case <-sub.ctx.Done():
|
||||
rows.Close()
|
||||
return
|
||||
default:
|
||||
// Channel full, skip
|
||||
logger.Warn("Subscriber channel full for pattern: %s", sub.pattern)
|
||||
}
|
||||
}
|
||||
|
||||
sub.lastSeenID = event.ID
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// notify sends a PostgreSQL NOTIFY message
|
||||
func (dp *DatabaseProvider) notify(ctx context.Context, eventID string) error {
|
||||
query := fmt.Sprintf("NOTIFY %s, '%s'", dp.channel, eventID)
|
||||
_, err := dp.db.Exec(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// createTable creates the events table if it doesn't exist
|
||||
func (dp *DatabaseProvider) createTable(ctx context.Context) error {
|
||||
query := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
source VARCHAR(50) NOT NULL,
|
||||
type VARCHAR(255) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
error TEXT,
|
||||
payload JSONB,
|
||||
user_id INTEGER,
|
||||
session_id VARCHAR(255),
|
||||
instance_id VARCHAR(255),
|
||||
schema VARCHAR(255),
|
||||
entity VARCHAR(255),
|
||||
operation VARCHAR(50),
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
processed_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
metadata JSONB
|
||||
)
|
||||
`, dp.tableName)
|
||||
|
||||
if _, err := dp.db.Exec(ctx, query); err != nil {
|
||||
return fmt.Errorf("failed to create table: %w", err)
|
||||
}
|
||||
|
||||
// Create indexes
|
||||
indexes := []string{
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_source ON %s(source)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_type ON %s(type)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_status ON %s(status)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_created_at ON %s(created_at)", dp.tableName, dp.tableName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS idx_%s_instance_id ON %s(instance_id)", dp.tableName, dp.tableName),
|
||||
}
|
||||
|
||||
for _, indexQuery := range indexes {
|
||||
if _, err := dp.db.Exec(ctx, indexQuery); err != nil {
|
||||
logger.Warn("Failed to create index: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
446
pkg/eventbroker/provider_memory.go
Normal file
446
pkg/eventbroker/provider_memory.go
Normal file
@@ -0,0 +1,446 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// MemoryProvider implements Provider interface using in-memory storage
|
||||
// Features:
|
||||
// - Thread-safe event storage with RW mutex
|
||||
// - LRU eviction when max events reached
|
||||
// - In-process pub/sub (not cross-instance)
|
||||
// - Automatic cleanup of old completed events
|
||||
type MemoryProvider struct {
|
||||
mu sync.RWMutex
|
||||
events map[string]*Event
|
||||
eventOrder []string // For LRU tracking
|
||||
subscribers map[string][]chan *Event
|
||||
instanceID string
|
||||
maxEvents int
|
||||
cleanupInterval time.Duration
|
||||
maxAge time.Duration
|
||||
|
||||
// Statistics
|
||||
stats MemoryProviderStats
|
||||
|
||||
// Lifecycle
|
||||
stopCleanup chan struct{}
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// MemoryProviderStats contains statistics for the memory provider
|
||||
type MemoryProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
PendingEvents atomic.Int64
|
||||
ProcessingEvents atomic.Int64
|
||||
CompletedEvents atomic.Int64
|
||||
FailedEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
Evictions atomic.Int64
|
||||
}
|
||||
|
||||
// MemoryProviderOptions configures the memory provider
|
||||
type MemoryProviderOptions struct {
|
||||
InstanceID string
|
||||
MaxEvents int
|
||||
CleanupInterval time.Duration
|
||||
MaxAge time.Duration
|
||||
}
|
||||
|
||||
// NewMemoryProvider creates a new in-memory event provider
|
||||
func NewMemoryProvider(opts MemoryProviderOptions) *MemoryProvider {
|
||||
if opts.MaxEvents == 0 {
|
||||
opts.MaxEvents = 10000 // Default
|
||||
}
|
||||
if opts.CleanupInterval == 0 {
|
||||
opts.CleanupInterval = 5 * time.Minute // Default
|
||||
}
|
||||
if opts.MaxAge == 0 {
|
||||
opts.MaxAge = 24 * time.Hour // Default: keep events for 24 hours
|
||||
}
|
||||
|
||||
mp := &MemoryProvider{
|
||||
events: make(map[string]*Event),
|
||||
eventOrder: make([]string, 0),
|
||||
subscribers: make(map[string][]chan *Event),
|
||||
instanceID: opts.InstanceID,
|
||||
maxEvents: opts.MaxEvents,
|
||||
cleanupInterval: opts.CleanupInterval,
|
||||
maxAge: opts.MaxAge,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
mp.isRunning.Store(true)
|
||||
|
||||
// Start cleanup goroutine
|
||||
mp.wg.Add(1)
|
||||
go mp.cleanupLoop()
|
||||
|
||||
logger.Info("Memory provider initialized (max_events: %d, cleanup: %v, max_age: %v)",
|
||||
opts.MaxEvents, opts.CleanupInterval, opts.MaxAge)
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (mp *MemoryProvider) Store(ctx context.Context, event *Event) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Check if we need to evict oldest events
|
||||
if len(mp.events) >= mp.maxEvents {
|
||||
mp.evictOldestLocked()
|
||||
}
|
||||
|
||||
// Store event
|
||||
mp.events[event.ID] = event.Clone()
|
||||
mp.eventOrder = append(mp.eventOrder, event.ID)
|
||||
|
||||
// Update statistics
|
||||
mp.stats.TotalEvents.Add(1)
|
||||
mp.updateStatusCountsLocked(event.Status, 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
func (mp *MemoryProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
event, exists := mp.events[id]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
return event.Clone(), nil
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
func (mp *MemoryProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
var results []*Event
|
||||
|
||||
for _, event := range mp.events {
|
||||
if mp.matchesFilter(event, filter) {
|
||||
results = append(results, event.Clone())
|
||||
}
|
||||
}
|
||||
|
||||
// Apply limit and offset
|
||||
if filter != nil {
|
||||
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||
results = results[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||
results = results[:filter.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
func (mp *MemoryProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
event, exists := mp.events[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
// Update status counts
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
mp.updateStatusCountsLocked(status, 1)
|
||||
|
||||
// Update event
|
||||
event.Status = status
|
||||
if errorMsg != "" {
|
||||
event.Error = errorMsg
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
func (mp *MemoryProvider) Delete(ctx context.Context, id string) error {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
event, exists := mp.events[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
// Update counts
|
||||
mp.stats.TotalEvents.Add(-1)
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
|
||||
// Delete event
|
||||
delete(mp.events, id)
|
||||
|
||||
// Remove from order tracking
|
||||
for i, eid := range mp.eventOrder {
|
||||
if eid == id {
|
||||
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
// Note: This is in-process only, not cross-instance
|
||||
func (mp *MemoryProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Create buffered channel for events
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
// Store subscriber
|
||||
mp.subscribers[pattern] = append(mp.subscribers[pattern], ch)
|
||||
mp.stats.ActiveSubscribers.Add(1)
|
||||
|
||||
// Goroutine to clean up on context cancellation
|
||||
mp.wg.Add(1)
|
||||
go func() {
|
||||
defer mp.wg.Done()
|
||||
<-ctx.Done()
|
||||
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
// Remove subscriber
|
||||
subs := mp.subscribers[pattern]
|
||||
for i, subCh := range subs {
|
||||
if subCh == ch {
|
||||
mp.subscribers[pattern] = append(subs[:i], subs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mp.stats.ActiveSubscribers.Add(-1)
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
logger.Debug("Stream created for pattern: %s", pattern)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers
|
||||
func (mp *MemoryProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := mp.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp.stats.EventsPublished.Add(1)
|
||||
|
||||
// Notify subscribers
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
for pattern, channels := range mp.subscribers {
|
||||
if matchPattern(pattern, event.Type) {
|
||||
for _, ch := range channels {
|
||||
select {
|
||||
case ch <- event.Clone():
|
||||
mp.stats.EventsConsumed.Add(1)
|
||||
default:
|
||||
// Channel full, skip
|
||||
logger.Warn("Subscriber channel full for pattern: %s", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (mp *MemoryProvider) Close() error {
|
||||
if !mp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
mp.isRunning.Store(false)
|
||||
|
||||
// Stop cleanup loop
|
||||
close(mp.stopCleanup)
|
||||
|
||||
// Wait for goroutines
|
||||
mp.wg.Wait()
|
||||
|
||||
// Close all subscriber channels
|
||||
mp.mu.Lock()
|
||||
for _, channels := range mp.subscribers {
|
||||
for _, ch := range channels {
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
mp.subscribers = make(map[string][]chan *Event)
|
||||
mp.mu.Unlock()
|
||||
|
||||
logger.Info("Memory provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (mp *MemoryProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
return &ProviderStats{
|
||||
ProviderType: "memory",
|
||||
TotalEvents: mp.stats.TotalEvents.Load(),
|
||||
PendingEvents: mp.stats.PendingEvents.Load(),
|
||||
ProcessingEvents: mp.stats.ProcessingEvents.Load(),
|
||||
CompletedEvents: mp.stats.CompletedEvents.Load(),
|
||||
FailedEvents: mp.stats.FailedEvents.Load(),
|
||||
EventsPublished: mp.stats.EventsPublished.Load(),
|
||||
EventsConsumed: mp.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(mp.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"max_events": mp.maxEvents,
|
||||
"cleanup_interval": mp.cleanupInterval.String(),
|
||||
"max_age": mp.maxAge.String(),
|
||||
"evictions": mp.stats.Evictions.Load(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up old completed events
|
||||
func (mp *MemoryProvider) cleanupLoop() {
|
||||
defer mp.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(mp.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mp.cleanup()
|
||||
case <-mp.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes old completed/failed events
|
||||
func (mp *MemoryProvider) cleanup() {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-mp.maxAge)
|
||||
removed := 0
|
||||
|
||||
for id, event := range mp.events {
|
||||
// Only clean up completed or failed events that are old
|
||||
if (event.Status == EventStatusCompleted || event.Status == EventStatusFailed) &&
|
||||
event.CreatedAt.Before(cutoff) {
|
||||
|
||||
delete(mp.events, id)
|
||||
mp.stats.TotalEvents.Add(-1)
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
|
||||
// Remove from order tracking
|
||||
for i, eid := range mp.eventOrder {
|
||||
if eid == id {
|
||||
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
logger.Debug("Cleanup removed %d old events", removed)
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldestLocked evicts the oldest event (LRU)
|
||||
// Caller must hold write lock
|
||||
func (mp *MemoryProvider) evictOldestLocked() {
|
||||
if len(mp.eventOrder) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Get oldest event ID
|
||||
oldestID := mp.eventOrder[0]
|
||||
mp.eventOrder = mp.eventOrder[1:]
|
||||
|
||||
// Remove event
|
||||
if event, exists := mp.events[oldestID]; exists {
|
||||
delete(mp.events, oldestID)
|
||||
mp.stats.TotalEvents.Add(-1)
|
||||
mp.updateStatusCountsLocked(event.Status, -1)
|
||||
mp.stats.Evictions.Add(1)
|
||||
|
||||
logger.Debug("Evicted oldest event: %s", oldestID)
|
||||
}
|
||||
}
|
||||
|
||||
// matchesFilter checks if an event matches the filter criteria
|
||||
func (mp *MemoryProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if filter.Source != nil && event.Source != *filter.Source {
|
||||
return false
|
||||
}
|
||||
if filter.Status != nil && event.Status != *filter.Status {
|
||||
return false
|
||||
}
|
||||
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||
return false
|
||||
}
|
||||
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||
return false
|
||||
}
|
||||
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||
return false
|
||||
}
|
||||
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||
return false
|
||||
}
|
||||
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||
return false
|
||||
}
|
||||
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||
return false
|
||||
}
|
||||
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// updateStatusCountsLocked updates status statistics
|
||||
// Caller must hold write lock
|
||||
func (mp *MemoryProvider) updateStatusCountsLocked(status EventStatus, delta int64) {
|
||||
switch status {
|
||||
case EventStatusPending:
|
||||
mp.stats.PendingEvents.Add(delta)
|
||||
case EventStatusProcessing:
|
||||
mp.stats.ProcessingEvents.Add(delta)
|
||||
case EventStatusCompleted:
|
||||
mp.stats.CompletedEvents.Add(delta)
|
||||
case EventStatusFailed:
|
||||
mp.stats.FailedEvents.Add(delta)
|
||||
}
|
||||
}
|
||||
419
pkg/eventbroker/provider_memory_test.go
Normal file
419
pkg/eventbroker/provider_memory_test.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewMemoryProvider(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 100,
|
||||
CleanupInterval: 1 * time.Minute,
|
||||
})
|
||||
|
||||
if provider == nil {
|
||||
t.Fatal("Expected non-nil provider")
|
||||
}
|
||||
|
||||
stats, err := provider.Stats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Stats failed: %v", err)
|
||||
}
|
||||
|
||||
if stats.ProviderType != "memory" {
|
||||
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderPublishAndGet(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
event.UserID = 123
|
||||
|
||||
// Publish event
|
||||
if err := provider.Publish(context.Background(), event); err != nil {
|
||||
t.Fatalf("Publish failed: %v", err)
|
||||
}
|
||||
|
||||
// Get event
|
||||
retrieved, err := provider.Get(context.Background(), event.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get failed: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.ID != event.ID {
|
||||
t.Errorf("Expected event ID %s, got %s", event.ID, retrieved.ID)
|
||||
}
|
||||
if retrieved.UserID != 123 {
|
||||
t.Errorf("Expected user ID 123, got %d", retrieved.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderGetNonExistent(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
_, err := provider.Get(context.Background(), "non-existent-id")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting non-existent event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderUpdateStatus(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event)
|
||||
|
||||
// Update status to processing
|
||||
err := provider.UpdateStatus(context.Background(), event.ID, EventStatusProcessing, "")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus failed: %v", err)
|
||||
}
|
||||
|
||||
retrieved, _ := provider.Get(context.Background(), event.ID)
|
||||
if retrieved.Status != EventStatusProcessing {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusProcessing, retrieved.Status)
|
||||
}
|
||||
|
||||
// Update status to failed with error
|
||||
err = provider.UpdateStatus(context.Background(), event.ID, EventStatusFailed, "test error")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus failed: %v", err)
|
||||
}
|
||||
|
||||
retrieved, _ = provider.Get(context.Background(), event.ID)
|
||||
if retrieved.Status != EventStatusFailed {
|
||||
t.Errorf("Expected status %s, got %s", EventStatusFailed, retrieved.Status)
|
||||
}
|
||||
if retrieved.Error != "test error" {
|
||||
t.Errorf("Expected error 'test error', got %s", retrieved.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderList(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish multiple events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event)
|
||||
}
|
||||
|
||||
// List all events
|
||||
events, err := provider.List(context.Background(), &EventFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 5 {
|
||||
t.Errorf("Expected 5 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderListWithFilter(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish events with different types
|
||||
event1 := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event1)
|
||||
|
||||
event2 := NewEvent(EventSourceDatabase, "public.roles.create")
|
||||
provider.Publish(context.Background(), event2)
|
||||
|
||||
event3 := NewEvent(EventSourceWebSocket, "chat.message")
|
||||
provider.Publish(context.Background(), event3)
|
||||
|
||||
// Filter by source
|
||||
source := EventSourceDatabase
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
Source: &source,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 2 {
|
||||
t.Errorf("Expected 2 events with database source, got %d", len(events))
|
||||
}
|
||||
|
||||
// Filter by status
|
||||
status := EventStatusPending
|
||||
events, err = provider.List(context.Background(), &EventFilter{
|
||||
Status: &status,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 3 {
|
||||
t.Errorf("Expected 3 events with pending status, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderListWithLimit(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish multiple events
|
||||
for i := 0; i < 10; i++ {
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
}
|
||||
|
||||
// List with limit
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
Limit: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 5 {
|
||||
t.Errorf("Expected 5 events (limited), got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderDelete(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
provider.Publish(context.Background(), event)
|
||||
|
||||
// Delete event
|
||||
err := provider.Delete(context.Background(), event.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify deleted
|
||||
_, err = provider.Get(context.Background(), event.ID)
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting deleted event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderLRUEviction(t *testing.T) {
|
||||
// Create provider with small max events
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 3,
|
||||
})
|
||||
|
||||
// Publish 5 events
|
||||
events := make([]*Event, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
events[i] = NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), events[i])
|
||||
}
|
||||
|
||||
// First 2 events should be evicted
|
||||
_, err := provider.Get(context.Background(), events[0].ID)
|
||||
if err == nil {
|
||||
t.Error("Expected first event to be evicted")
|
||||
}
|
||||
|
||||
_, err = provider.Get(context.Background(), events[1].ID)
|
||||
if err == nil {
|
||||
t.Error("Expected second event to be evicted")
|
||||
}
|
||||
|
||||
// Last 3 events should still exist
|
||||
for i := 2; i < 5; i++ {
|
||||
_, err := provider.Get(context.Background(), events[i].ID)
|
||||
if err != nil {
|
||||
t.Errorf("Expected event %d to still exist", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderCleanup(t *testing.T) {
|
||||
// Create provider with short cleanup interval
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
MaxAge: 200 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Publish and complete an event
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
provider.UpdateStatus(context.Background(), event.ID, EventStatusCompleted, "")
|
||||
|
||||
// Wait for cleanup to run
|
||||
time.Sleep(400 * time.Millisecond)
|
||||
|
||||
// Event should be cleaned up
|
||||
_, err := provider.Get(context.Background(), event.ID)
|
||||
if err == nil {
|
||||
t.Error("Expected event to be cleaned up")
|
||||
}
|
||||
|
||||
provider.Close()
|
||||
}
|
||||
|
||||
func TestMemoryProviderStats(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
MaxEvents: 100,
|
||||
})
|
||||
|
||||
// Publish events
|
||||
for i := 0; i < 5; i++ {
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
}
|
||||
|
||||
stats, err := provider.Stats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Stats failed: %v", err)
|
||||
}
|
||||
|
||||
if stats.ProviderType != "memory" {
|
||||
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||
}
|
||||
if stats.TotalEvents != 5 {
|
||||
t.Errorf("Expected 5 total events, got %d", stats.TotalEvents)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderClose(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Publish event
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
|
||||
// Close provider
|
||||
err := provider.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup goroutine should be stopped
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestMemoryProviderConcurrency(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Concurrent publish
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all events were stored
|
||||
events, _ := provider.List(context.Background(), &EventFilter{})
|
||||
if len(events) != 10 {
|
||||
t.Errorf("Expected 10 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderStream(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Stream is implemented for memory provider (in-process pub/sub)
|
||||
ch, err := provider.Stream(context.Background(), "test.*")
|
||||
if err != nil {
|
||||
t.Fatalf("Stream failed: %v", err)
|
||||
}
|
||||
if ch == nil {
|
||||
t.Error("Expected non-nil channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderTimeRangeFilter(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish events at different times
|
||||
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event1)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event2)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
event3 := NewEvent(EventSourceDatabase, "test.event")
|
||||
provider.Publish(context.Background(), event3)
|
||||
|
||||
// Filter by time range
|
||||
startTime := event2.CreatedAt.Add(-1 * time.Millisecond)
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
StartTime: &startTime,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
// Should get events 2 and 3
|
||||
if len(events) != 2 {
|
||||
t.Errorf("Expected 2 events after start time, got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryProviderInstanceIDFilter(t *testing.T) {
|
||||
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||
InstanceID: "test-instance",
|
||||
})
|
||||
|
||||
// Publish events with different instance IDs
|
||||
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||
event1.InstanceID = "instance-1"
|
||||
provider.Publish(context.Background(), event1)
|
||||
|
||||
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||
event2.InstanceID = "instance-2"
|
||||
provider.Publish(context.Background(), event2)
|
||||
|
||||
// Filter by instance ID
|
||||
events, err := provider.List(context.Background(), &EventFilter{
|
||||
InstanceID: "instance-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
|
||||
if len(events) != 1 {
|
||||
t.Errorf("Expected 1 event with instance-1, got %d", len(events))
|
||||
}
|
||||
if events[0].InstanceID != "instance-1" {
|
||||
t.Errorf("Expected instance ID 'instance-1', got %s", events[0].InstanceID)
|
||||
}
|
||||
}
|
||||
565
pkg/eventbroker/provider_nats.go
Normal file
565
pkg/eventbroker/provider_nats.go
Normal file
@@ -0,0 +1,565 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// NATSProvider implements Provider interface using NATS JetStream
|
||||
// Features:
|
||||
// - Persistent event storage using JetStream
|
||||
// - Cross-instance pub/sub using NATS subjects
|
||||
// - Wildcard subscription support
|
||||
// - Durable consumers for event replay
|
||||
// - At-least-once delivery semantics
|
||||
type NATSProvider struct {
|
||||
nc *nats.Conn
|
||||
js jetstream.JetStream
|
||||
stream jetstream.Stream
|
||||
streamName string
|
||||
subjectPrefix string
|
||||
instanceID string
|
||||
maxAge time.Duration
|
||||
|
||||
// Subscriptions
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*natsSubscription
|
||||
|
||||
// Statistics
|
||||
stats NATSProviderStats
|
||||
|
||||
// Lifecycle
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// NATSProviderStats contains statistics for the NATS provider
|
||||
type NATSProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
ConsumerErrors atomic.Int64
|
||||
}
|
||||
|
||||
// natsSubscription represents a single NATS subscription
|
||||
type natsSubscription struct {
|
||||
pattern string
|
||||
consumer jetstream.Consumer
|
||||
ch chan *Event
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NATSProviderConfig configures the NATS provider
|
||||
type NATSProviderConfig struct {
|
||||
URL string
|
||||
StreamName string
|
||||
SubjectPrefix string // e.g., "events"
|
||||
InstanceID string
|
||||
MaxAge time.Duration // How long to keep events
|
||||
Storage string // "file" or "memory"
|
||||
}
|
||||
|
||||
// NewNATSProvider creates a new NATS event provider
|
||||
func NewNATSProvider(cfg NATSProviderConfig) (*NATSProvider, error) {
|
||||
// Apply defaults
|
||||
if cfg.URL == "" {
|
||||
cfg.URL = nats.DefaultURL
|
||||
}
|
||||
if cfg.StreamName == "" {
|
||||
cfg.StreamName = "RESOLVESPEC_EVENTS"
|
||||
}
|
||||
if cfg.SubjectPrefix == "" {
|
||||
cfg.SubjectPrefix = "events"
|
||||
}
|
||||
if cfg.MaxAge == 0 {
|
||||
cfg.MaxAge = 7 * 24 * time.Hour // 7 days
|
||||
}
|
||||
if cfg.Storage == "" {
|
||||
cfg.Storage = "file"
|
||||
}
|
||||
|
||||
// Connect to NATS
|
||||
nc, err := nats.Connect(cfg.URL,
|
||||
nats.Name("resolvespec-eventbroker-"+cfg.InstanceID),
|
||||
nats.Timeout(5*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to NATS: %w", err)
|
||||
}
|
||||
|
||||
// Create JetStream context
|
||||
js, err := jetstream.New(nc)
|
||||
if err != nil {
|
||||
nc.Close()
|
||||
return nil, fmt.Errorf("failed to create JetStream context: %w", err)
|
||||
}
|
||||
|
||||
np := &NATSProvider{
|
||||
nc: nc,
|
||||
js: js,
|
||||
streamName: cfg.StreamName,
|
||||
subjectPrefix: cfg.SubjectPrefix,
|
||||
instanceID: cfg.InstanceID,
|
||||
maxAge: cfg.MaxAge,
|
||||
subscribers: make(map[string]*natsSubscription),
|
||||
}
|
||||
|
||||
np.isRunning.Store(true)
|
||||
|
||||
// Create or update stream
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Determine storage type
|
||||
var storage jetstream.StorageType
|
||||
if cfg.Storage == "memory" {
|
||||
storage = jetstream.MemoryStorage
|
||||
} else {
|
||||
storage = jetstream.FileStorage
|
||||
}
|
||||
|
||||
if err := np.ensureStream(ctx, storage); err != nil {
|
||||
nc.Close()
|
||||
return nil, fmt.Errorf("failed to create stream: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("NATS provider initialized (stream: %s, subject: %s.*, url: %s)",
|
||||
cfg.StreamName, cfg.SubjectPrefix, cfg.URL)
|
||||
|
||||
return np, nil
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (np *NATSProvider) Store(ctx context.Context, event *Event) error {
|
||||
// Marshal event to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal event: %w", err)
|
||||
}
|
||||
|
||||
// Publish to NATS subject
|
||||
// Subject format: events.{source}.{schema}.{entity}.{operation}
|
||||
subject := np.buildSubject(event)
|
||||
|
||||
msg := &nats.Msg{
|
||||
Subject: subject,
|
||||
Data: data,
|
||||
Header: nats.Header{
|
||||
"Event-ID": []string{event.ID},
|
||||
"Event-Type": []string{event.Type},
|
||||
"Event-Source": []string{string(event.Source)},
|
||||
"Event-Status": []string{string(event.Status)},
|
||||
"Instance-ID": []string{event.InstanceID},
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := np.js.PublishMsg(ctx, msg); err != nil {
|
||||
return fmt.Errorf("failed to publish event: %w", err)
|
||||
}
|
||||
|
||||
np.stats.TotalEvents.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
// Note: This is inefficient with JetStream - consider using a separate KV store for lookups
|
||||
func (np *NATSProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
// We need to scan messages which is not ideal
|
||||
// For production, consider using NATS KV store for fast lookups
|
||||
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
|
||||
Name: "get-" + id,
|
||||
FilterSubject: np.subjectPrefix + ".>",
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create consumer: %w", err)
|
||||
}
|
||||
|
||||
// Fetch messages in batches
|
||||
msgs, err := consumer.Fetch(1000, jetstream.FetchMaxWait(5*time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch messages: %w", err)
|
||||
}
|
||||
|
||||
for msg := range msgs.Messages() {
|
||||
if msg.Headers().Get("Event-ID") == id {
|
||||
var event Event
|
||||
if err := json.Unmarshal(msg.Data(), &event); err != nil {
|
||||
_ = msg.Nak()
|
||||
continue
|
||||
}
|
||||
_ = msg.Ack()
|
||||
|
||||
// Delete temporary consumer
|
||||
_ = np.stream.DeleteConsumer(ctx, "get-"+id)
|
||||
|
||||
return &event, nil
|
||||
}
|
||||
_ = msg.Ack()
|
||||
}
|
||||
|
||||
// Delete temporary consumer
|
||||
_ = np.stream.DeleteConsumer(ctx, "get-"+id)
|
||||
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
func (np *NATSProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
var results []*Event
|
||||
|
||||
// Create temporary consumer
|
||||
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
|
||||
Name: fmt.Sprintf("list-%d", time.Now().UnixNano()),
|
||||
FilterSubject: np.subjectPrefix + ".>",
|
||||
DeliverPolicy: jetstream.DeliverAllPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create consumer: %w", err)
|
||||
}
|
||||
|
||||
defer func() { _ = np.stream.DeleteConsumer(ctx, consumer.CachedInfo().Name) }()
|
||||
|
||||
// Fetch messages in batches
|
||||
msgs, err := consumer.Fetch(1000, jetstream.FetchMaxWait(5*time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch messages: %w", err)
|
||||
}
|
||||
|
||||
for msg := range msgs.Messages() {
|
||||
var event Event
|
||||
if err := json.Unmarshal(msg.Data(), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
_ = msg.Nak()
|
||||
continue
|
||||
}
|
||||
|
||||
if np.matchesFilter(&event, filter) {
|
||||
results = append(results, &event)
|
||||
}
|
||||
|
||||
_ = msg.Ack()
|
||||
}
|
||||
|
||||
// Apply limit and offset
|
||||
if filter != nil {
|
||||
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||
results = results[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||
results = results[:filter.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
// Note: NATS streams are append-only, so we publish a status update event
|
||||
func (np *NATSProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
// Publish a status update message
|
||||
subject := fmt.Sprintf("%s.status.%s", np.subjectPrefix, id)
|
||||
|
||||
statusUpdate := map[string]interface{}{
|
||||
"event_id": id,
|
||||
"status": string(status),
|
||||
"error": errorMsg,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(statusUpdate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal status update: %w", err)
|
||||
}
|
||||
|
||||
if _, err := np.js.Publish(ctx, subject, data); err != nil {
|
||||
return fmt.Errorf("failed to publish status update: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
// Note: NATS streams don't support deletion - this just marks it in a separate subject
|
||||
func (np *NATSProvider) Delete(ctx context.Context, id string) error {
|
||||
subject := fmt.Sprintf("%s.deleted.%s", np.subjectPrefix, id)
|
||||
|
||||
deleteMsg := map[string]interface{}{
|
||||
"event_id": id,
|
||||
"deleted_at": time.Now(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(deleteMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal delete message: %w", err)
|
||||
}
|
||||
|
||||
if _, err := np.js.Publish(ctx, subject, data); err != nil {
|
||||
return fmt.Errorf("failed to publish delete message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
func (np *NATSProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
// Convert glob pattern to NATS subject pattern
|
||||
natsSubject := np.patternToSubject(pattern)
|
||||
|
||||
// Create durable consumer
|
||||
consumerName := fmt.Sprintf("consumer-%s-%d", np.instanceID, time.Now().UnixNano())
|
||||
consumer, err := np.stream.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{
|
||||
Name: consumerName,
|
||||
FilterSubject: natsSubject,
|
||||
DeliverPolicy: jetstream.DeliverNewPolicy,
|
||||
AckPolicy: jetstream.AckExplicitPolicy,
|
||||
AckWait: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create consumer: %w", err)
|
||||
}
|
||||
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sub := &natsSubscription{
|
||||
pattern: pattern,
|
||||
consumer: consumer,
|
||||
ch: ch,
|
||||
ctx: subCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
np.mu.Lock()
|
||||
np.subscribers[pattern] = sub
|
||||
np.stats.ActiveSubscribers.Add(1)
|
||||
np.mu.Unlock()
|
||||
|
||||
// Start consumer goroutine
|
||||
np.wg.Add(1)
|
||||
go np.consumeMessages(sub)
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers
|
||||
func (np *NATSProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := np.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
np.stats.EventsPublished.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (np *NATSProvider) Close() error {
|
||||
if !np.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
np.isRunning.Store(false)
|
||||
|
||||
// Cancel all subscriptions
|
||||
np.mu.Lock()
|
||||
for _, sub := range np.subscribers {
|
||||
sub.cancel()
|
||||
}
|
||||
np.mu.Unlock()
|
||||
|
||||
// Wait for goroutines
|
||||
np.wg.Wait()
|
||||
|
||||
// Close NATS connection
|
||||
np.nc.Close()
|
||||
|
||||
logger.Info("NATS provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (np *NATSProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
streamInfo, err := np.stream.Info(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to get stream info: %v", err)
|
||||
}
|
||||
|
||||
stats := &ProviderStats{
|
||||
ProviderType: "nats",
|
||||
TotalEvents: np.stats.TotalEvents.Load(),
|
||||
EventsPublished: np.stats.EventsPublished.Load(),
|
||||
EventsConsumed: np.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(np.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"stream_name": np.streamName,
|
||||
"subject_prefix": np.subjectPrefix,
|
||||
"max_age": np.maxAge.String(),
|
||||
"consumer_errors": np.stats.ConsumerErrors.Load(),
|
||||
},
|
||||
}
|
||||
|
||||
if streamInfo != nil {
|
||||
stats.ProviderSpecific["messages"] = streamInfo.State.Msgs
|
||||
stats.ProviderSpecific["bytes"] = streamInfo.State.Bytes
|
||||
stats.ProviderSpecific["consumers"] = streamInfo.State.Consumers
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// ensureStream creates or updates the JetStream stream
|
||||
func (np *NATSProvider) ensureStream(ctx context.Context, storage jetstream.StorageType) error {
|
||||
streamConfig := jetstream.StreamConfig{
|
||||
Name: np.streamName,
|
||||
Subjects: []string{np.subjectPrefix + ".>"},
|
||||
MaxAge: np.maxAge,
|
||||
Storage: storage,
|
||||
Retention: jetstream.LimitsPolicy,
|
||||
Discard: jetstream.DiscardOld,
|
||||
}
|
||||
|
||||
stream, err := np.js.CreateStream(ctx, streamConfig)
|
||||
if err != nil {
|
||||
// Try to update if already exists
|
||||
stream, err = np.js.UpdateStream(ctx, streamConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create/update stream: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
np.stream = stream
|
||||
return nil
|
||||
}
|
||||
|
||||
// consumeMessages consumes messages from NATS for a subscription
|
||||
func (np *NATSProvider) consumeMessages(sub *natsSubscription) {
|
||||
defer np.wg.Done()
|
||||
defer close(sub.ch)
|
||||
defer func() {
|
||||
np.mu.Lock()
|
||||
delete(np.subscribers, sub.pattern)
|
||||
np.stats.ActiveSubscribers.Add(-1)
|
||||
np.mu.Unlock()
|
||||
}()
|
||||
|
||||
logger.Debug("Starting NATS consumer for pattern: %s", sub.pattern)
|
||||
|
||||
// Consume messages
|
||||
cc, err := sub.consumer.Consume(func(msg jetstream.Msg) {
|
||||
var event Event
|
||||
if err := json.Unmarshal(msg.Data(), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
_ = msg.Nak()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if event matches pattern (additional filtering)
|
||||
if matchPattern(sub.pattern, event.Type) {
|
||||
select {
|
||||
case sub.ch <- &event:
|
||||
np.stats.EventsConsumed.Add(1)
|
||||
_ = msg.Ack()
|
||||
case <-sub.ctx.Done():
|
||||
_ = msg.Nak()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_ = msg.Ack()
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
np.stats.ConsumerErrors.Add(1)
|
||||
logger.Error("Failed to start consumer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for context cancellation
|
||||
<-sub.ctx.Done()
|
||||
|
||||
// Stop consuming
|
||||
cc.Stop()
|
||||
|
||||
logger.Debug("NATS consumer stopped for pattern: %s", sub.pattern)
|
||||
}
|
||||
|
||||
// buildSubject creates a NATS subject from an event
|
||||
// Format: events.{source}.{schema}.{entity}.{operation}
|
||||
func (np *NATSProvider) buildSubject(event *Event) string {
|
||||
return fmt.Sprintf("%s.%s.%s.%s.%s",
|
||||
np.subjectPrefix,
|
||||
event.Source,
|
||||
event.Schema,
|
||||
event.Entity,
|
||||
event.Operation,
|
||||
)
|
||||
}
|
||||
|
||||
// patternToSubject converts a glob pattern to NATS subject pattern
|
||||
// Examples:
|
||||
// - "*" -> "events.>"
|
||||
// - "public.users.*" -> "events.*.public.users.*"
|
||||
// - "public.*.*" -> "events.*.public.*.*"
|
||||
func (np *NATSProvider) patternToSubject(pattern string) string {
|
||||
if pattern == "*" {
|
||||
return np.subjectPrefix + ".>"
|
||||
}
|
||||
|
||||
// For specific patterns, we need to match the event type structure
|
||||
// Event type: schema.entity.operation
|
||||
// NATS subject: events.{source}.{schema}.{entity}.{operation}
|
||||
// We use wildcard for source since pattern doesn't include it
|
||||
return fmt.Sprintf("%s.*.%s", np.subjectPrefix, pattern)
|
||||
}
|
||||
|
||||
// matchesFilter checks if an event matches the filter criteria
|
||||
func (np *NATSProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if filter.Source != nil && event.Source != *filter.Source {
|
||||
return false
|
||||
}
|
||||
if filter.Status != nil && event.Status != *filter.Status {
|
||||
return false
|
||||
}
|
||||
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||
return false
|
||||
}
|
||||
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||
return false
|
||||
}
|
||||
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||
return false
|
||||
}
|
||||
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||
return false
|
||||
}
|
||||
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||
return false
|
||||
}
|
||||
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||
return false
|
||||
}
|
||||
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
541
pkg/eventbroker/provider_redis.go
Normal file
541
pkg/eventbroker/provider_redis.go
Normal file
@@ -0,0 +1,541 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// RedisProvider implements Provider interface using Redis Streams
|
||||
// Features:
|
||||
// - Persistent event storage using Redis Streams
|
||||
// - Cross-instance pub/sub using consumer groups
|
||||
// - Pattern-based subscription routing
|
||||
// - Automatic stream trimming to prevent unbounded growth
|
||||
type RedisProvider struct {
|
||||
client *redis.Client
|
||||
streamName string
|
||||
consumerGroup string
|
||||
consumerName string
|
||||
instanceID string
|
||||
maxLen int64
|
||||
|
||||
// Subscriptions
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*redisSubscription
|
||||
|
||||
// Statistics
|
||||
stats RedisProviderStats
|
||||
|
||||
// Lifecycle
|
||||
stopListeners chan struct{}
|
||||
wg sync.WaitGroup
|
||||
isRunning atomic.Bool
|
||||
}
|
||||
|
||||
// RedisProviderStats contains statistics for the Redis provider
|
||||
type RedisProviderStats struct {
|
||||
TotalEvents atomic.Int64
|
||||
EventsPublished atomic.Int64
|
||||
EventsConsumed atomic.Int64
|
||||
ActiveSubscribers atomic.Int32
|
||||
ConsumerErrors atomic.Int64
|
||||
}
|
||||
|
||||
// redisSubscription represents a single subscription
|
||||
type redisSubscription struct {
|
||||
pattern string
|
||||
ch chan *Event
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// RedisProviderConfig configures the Redis provider
|
||||
type RedisProviderConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int
|
||||
StreamName string
|
||||
ConsumerGroup string
|
||||
ConsumerName string
|
||||
InstanceID string
|
||||
MaxLen int64 // Maximum stream length (0 = unlimited)
|
||||
}
|
||||
|
||||
// NewRedisProvider creates a new Redis event provider
|
||||
func NewRedisProvider(cfg RedisProviderConfig) (*RedisProvider, error) {
|
||||
// Apply defaults
|
||||
if cfg.Host == "" {
|
||||
cfg.Host = "localhost"
|
||||
}
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 6379
|
||||
}
|
||||
if cfg.StreamName == "" {
|
||||
cfg.StreamName = "resolvespec:events"
|
||||
}
|
||||
if cfg.ConsumerGroup == "" {
|
||||
cfg.ConsumerGroup = "resolvespec-workers"
|
||||
}
|
||||
if cfg.ConsumerName == "" {
|
||||
cfg.ConsumerName = cfg.InstanceID
|
||||
}
|
||||
if cfg.MaxLen == 0 {
|
||||
cfg.MaxLen = 10000 // Default max stream length
|
||||
}
|
||||
|
||||
// Create Redis client
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
PoolSize: 10,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
rp := &RedisProvider{
|
||||
client: client,
|
||||
streamName: cfg.StreamName,
|
||||
consumerGroup: cfg.ConsumerGroup,
|
||||
consumerName: cfg.ConsumerName,
|
||||
instanceID: cfg.InstanceID,
|
||||
maxLen: cfg.MaxLen,
|
||||
subscribers: make(map[string]*redisSubscription),
|
||||
stopListeners: make(chan struct{}),
|
||||
}
|
||||
|
||||
rp.isRunning.Store(true)
|
||||
|
||||
// Create consumer group if it doesn't exist
|
||||
if err := rp.ensureConsumerGroup(ctx); err != nil {
|
||||
logger.Warn("Failed to create consumer group: %v (may already exist)", err)
|
||||
}
|
||||
|
||||
logger.Info("Redis provider initialized (stream: %s, consumer_group: %s, consumer: %s)",
|
||||
cfg.StreamName, cfg.ConsumerGroup, cfg.ConsumerName)
|
||||
|
||||
return rp, nil
|
||||
}
|
||||
|
||||
// Store stores an event
|
||||
func (rp *RedisProvider) Store(ctx context.Context, event *Event) error {
|
||||
// Marshal event to JSON
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal event: %w", err)
|
||||
}
|
||||
|
||||
// Store in Redis Stream
|
||||
args := &redis.XAddArgs{
|
||||
Stream: rp.streamName,
|
||||
MaxLen: rp.maxLen,
|
||||
Approx: true, // Use approximate trimming for better performance
|
||||
Values: map[string]interface{}{
|
||||
"event": data,
|
||||
"id": event.ID,
|
||||
"type": event.Type,
|
||||
"source": string(event.Source),
|
||||
"status": string(event.Status),
|
||||
"instance_id": event.InstanceID,
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := rp.client.XAdd(ctx, args).Result(); err != nil {
|
||||
return fmt.Errorf("failed to add event to stream: %w", err)
|
||||
}
|
||||
|
||||
rp.stats.TotalEvents.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an event by ID
|
||||
// Note: This scans the stream which can be slow for large streams
|
||||
// Consider using a separate hash for fast lookups if needed
|
||||
func (rp *RedisProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||
// Scan stream for event with matching ID
|
||||
args := &redis.XReadArgs{
|
||||
Streams: []string{rp.streamName, "0"},
|
||||
Count: 1000, // Read in batches
|
||||
}
|
||||
|
||||
for {
|
||||
streams, err := rp.client.XRead(ctx, args).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read stream: %w", err)
|
||||
}
|
||||
|
||||
if len(streams) == 0 {
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
// Check if this is the event we're looking for
|
||||
if eventID, ok := message.Values["id"].(string); ok && eventID == id {
|
||||
// Parse event
|
||||
if eventData, ok := message.Values["event"].(string); ok {
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
|
||||
}
|
||||
return &event, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we've read messages, update start position for next iteration
|
||||
if len(stream.Messages) > 0 {
|
||||
args.Streams[1] = stream.Messages[len(stream.Messages)-1].ID
|
||||
} else {
|
||||
// No more messages
|
||||
return nil, fmt.Errorf("event not found: %s", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// List lists events with optional filters
|
||||
// Note: This scans the entire stream which can be slow
|
||||
// Consider using time-based or ID-based ranges for better performance
|
||||
func (rp *RedisProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||
var results []*Event
|
||||
|
||||
// Read from stream
|
||||
args := &redis.XReadArgs{
|
||||
Streams: []string{rp.streamName, "0"},
|
||||
Count: 1000,
|
||||
}
|
||||
|
||||
for {
|
||||
streams, err := rp.client.XRead(ctx, args).Result()
|
||||
if err == redis.Nil {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read stream: %w", err)
|
||||
}
|
||||
|
||||
if len(streams) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
if eventData, ok := message.Values["event"].(string); ok {
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if rp.matchesFilter(&event, filter) {
|
||||
results = append(results, &event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update start position for next iteration
|
||||
if len(stream.Messages) > 0 {
|
||||
args.Streams[1] = stream.Messages[len(stream.Messages)-1].ID
|
||||
} else {
|
||||
// No more messages
|
||||
goto done
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
// Apply limit and offset
|
||||
if filter != nil {
|
||||
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||
results = results[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||
results = results[:filter.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the status of an event
|
||||
// Note: Redis Streams are append-only, so we need to store status updates separately
|
||||
// This uses a separate hash for status tracking
|
||||
func (rp *RedisProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||
statusKey := fmt.Sprintf("%s:status:%s", rp.streamName, id)
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"status": string(status),
|
||||
"updated_at": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
if errorMsg != "" {
|
||||
fields["error"] = errorMsg
|
||||
}
|
||||
|
||||
if err := rp.client.HSet(ctx, statusKey, fields).Err(); err != nil {
|
||||
return fmt.Errorf("failed to update status: %w", err)
|
||||
}
|
||||
|
||||
// Set TTL on status key to prevent unbounded growth
|
||||
rp.client.Expire(ctx, statusKey, 7*24*time.Hour) // 7 days
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes an event by ID
|
||||
// Note: Redis Streams don't support deletion by field value
|
||||
// This marks the event as deleted in a separate set
|
||||
func (rp *RedisProvider) Delete(ctx context.Context, id string) error {
|
||||
deletedKey := fmt.Sprintf("%s:deleted", rp.streamName)
|
||||
|
||||
if err := rp.client.SAdd(ctx, deletedKey, id).Err(); err != nil {
|
||||
return fmt.Errorf("failed to mark event as deleted: %w", err)
|
||||
}
|
||||
|
||||
// Also delete the status hash if it exists
|
||||
statusKey := fmt.Sprintf("%s:status:%s", rp.streamName, id)
|
||||
rp.client.Del(ctx, statusKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stream returns a channel of events for real-time consumption
|
||||
// Uses Redis Streams consumer group for distributed processing
|
||||
func (rp *RedisProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||
ch := make(chan *Event, 100)
|
||||
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sub := &redisSubscription{
|
||||
pattern: pattern,
|
||||
ch: ch,
|
||||
ctx: subCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
rp.mu.Lock()
|
||||
rp.subscribers[pattern] = sub
|
||||
rp.stats.ActiveSubscribers.Add(1)
|
||||
rp.mu.Unlock()
|
||||
|
||||
// Start consumer goroutine
|
||||
rp.wg.Add(1)
|
||||
go rp.consumeStream(sub)
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Publish publishes an event to all subscribers (cross-instance)
|
||||
func (rp *RedisProvider) Publish(ctx context.Context, event *Event) error {
|
||||
// Store the event first
|
||||
if err := rp.Store(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rp.stats.EventsPublished.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases resources
|
||||
func (rp *RedisProvider) Close() error {
|
||||
if !rp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
rp.isRunning.Store(false)
|
||||
|
||||
// Cancel all subscriptions
|
||||
rp.mu.Lock()
|
||||
for _, sub := range rp.subscribers {
|
||||
sub.cancel()
|
||||
}
|
||||
rp.mu.Unlock()
|
||||
|
||||
// Stop listeners
|
||||
close(rp.stopListeners)
|
||||
|
||||
// Wait for goroutines
|
||||
rp.wg.Wait()
|
||||
|
||||
// Close Redis client
|
||||
if err := rp.client.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close Redis client: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Redis provider closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns provider statistics
|
||||
func (rp *RedisProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||
// Get stream info
|
||||
streamInfo, err := rp.client.XInfoStream(ctx, rp.streamName).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
logger.Warn("Failed to get stream info: %v", err)
|
||||
}
|
||||
|
||||
stats := &ProviderStats{
|
||||
ProviderType: "redis",
|
||||
TotalEvents: rp.stats.TotalEvents.Load(),
|
||||
EventsPublished: rp.stats.EventsPublished.Load(),
|
||||
EventsConsumed: rp.stats.EventsConsumed.Load(),
|
||||
ActiveSubscribers: int(rp.stats.ActiveSubscribers.Load()),
|
||||
ProviderSpecific: map[string]interface{}{
|
||||
"stream_name": rp.streamName,
|
||||
"consumer_group": rp.consumerGroup,
|
||||
"consumer_name": rp.consumerName,
|
||||
"max_len": rp.maxLen,
|
||||
"consumer_errors": rp.stats.ConsumerErrors.Load(),
|
||||
},
|
||||
}
|
||||
|
||||
if streamInfo != nil {
|
||||
stats.ProviderSpecific["stream_length"] = streamInfo.Length
|
||||
stats.ProviderSpecific["first_entry_id"] = streamInfo.FirstEntry.ID
|
||||
stats.ProviderSpecific["last_entry_id"] = streamInfo.LastEntry.ID
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// consumeStream consumes events from the Redis Stream for a subscription
|
||||
func (rp *RedisProvider) consumeStream(sub *redisSubscription) {
|
||||
defer rp.wg.Done()
|
||||
defer close(sub.ch)
|
||||
defer func() {
|
||||
rp.mu.Lock()
|
||||
delete(rp.subscribers, sub.pattern)
|
||||
rp.stats.ActiveSubscribers.Add(-1)
|
||||
rp.mu.Unlock()
|
||||
}()
|
||||
|
||||
logger.Debug("Starting stream consumer for pattern: %s", sub.pattern)
|
||||
|
||||
// Use consumer group for distributed processing
|
||||
for {
|
||||
select {
|
||||
case <-sub.ctx.Done():
|
||||
logger.Debug("Stream consumer stopped for pattern: %s", sub.pattern)
|
||||
return
|
||||
default:
|
||||
// Read from consumer group
|
||||
args := &redis.XReadGroupArgs{
|
||||
Group: rp.consumerGroup,
|
||||
Consumer: rp.consumerName,
|
||||
Streams: []string{rp.streamName, ">"},
|
||||
Count: 10,
|
||||
Block: 1 * time.Second,
|
||||
}
|
||||
|
||||
streams, err := rp.client.XReadGroup(sub.ctx, args).Result()
|
||||
if err == redis.Nil {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
if sub.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
rp.stats.ConsumerErrors.Add(1)
|
||||
logger.Warn("Failed to read from consumer group: %v", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
if eventData, ok := message.Values["event"].(string); ok {
|
||||
var event Event
|
||||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||||
logger.Warn("Failed to unmarshal event: %v", err)
|
||||
// Acknowledge message anyway to prevent redelivery
|
||||
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if event matches pattern
|
||||
if matchPattern(sub.pattern, event.Type) {
|
||||
select {
|
||||
case sub.ch <- &event:
|
||||
rp.stats.EventsConsumed.Add(1)
|
||||
// Acknowledge message
|
||||
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
|
||||
case <-sub.ctx.Done():
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Acknowledge message even if it doesn't match pattern
|
||||
rp.client.XAck(sub.ctx, rp.streamName, rp.consumerGroup, message.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureConsumerGroup creates the consumer group if it doesn't exist
|
||||
func (rp *RedisProvider) ensureConsumerGroup(ctx context.Context) error {
|
||||
// Try to create the stream and consumer group
|
||||
// MKSTREAM creates the stream if it doesn't exist
|
||||
err := rp.client.XGroupCreateMkStream(ctx, rp.streamName, rp.consumerGroup, "0").Err()
|
||||
if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchesFilter checks if an event matches the filter criteria
|
||||
func (rp *RedisProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if filter.Source != nil && event.Source != *filter.Source {
|
||||
return false
|
||||
}
|
||||
if filter.Status != nil && event.Status != *filter.Status {
|
||||
return false
|
||||
}
|
||||
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||
return false
|
||||
}
|
||||
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||
return false
|
||||
}
|
||||
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||
return false
|
||||
}
|
||||
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||
return false
|
||||
}
|
||||
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||
return false
|
||||
}
|
||||
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||
return false
|
||||
}
|
||||
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
140
pkg/eventbroker/subscription.go
Normal file
140
pkg/eventbroker/subscription.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// SubscriptionID uniquely identifies a subscription
|
||||
type SubscriptionID string
|
||||
|
||||
// subscription represents a single subscription with its handler and pattern
|
||||
type subscription struct {
|
||||
id SubscriptionID
|
||||
pattern string
|
||||
handler EventHandler
|
||||
}
|
||||
|
||||
// subscriptionManager manages event subscriptions and pattern matching
|
||||
type subscriptionManager struct {
|
||||
mu sync.RWMutex
|
||||
subscriptions map[SubscriptionID]*subscription
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
// newSubscriptionManager creates a new subscription manager
|
||||
func newSubscriptionManager() *subscriptionManager {
|
||||
return &subscriptionManager{
|
||||
subscriptions: make(map[SubscriptionID]*subscription),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new subscription
|
||||
func (sm *subscriptionManager) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||
if pattern == "" {
|
||||
return "", fmt.Errorf("pattern cannot be empty")
|
||||
}
|
||||
if handler == nil {
|
||||
return "", fmt.Errorf("handler cannot be nil")
|
||||
}
|
||||
|
||||
id := SubscriptionID(fmt.Sprintf("sub-%d", sm.nextID.Add(1)))
|
||||
|
||||
sm.mu.Lock()
|
||||
sm.subscriptions[id] = &subscription{
|
||||
id: id,
|
||||
pattern: pattern,
|
||||
handler: handler,
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
logger.Info("Subscribed to pattern '%s' with ID: %s", pattern, id)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription
|
||||
func (sm *subscriptionManager) Unsubscribe(id SubscriptionID) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if _, exists := sm.subscriptions[id]; !exists {
|
||||
return fmt.Errorf("subscription not found: %s", id)
|
||||
}
|
||||
|
||||
delete(sm.subscriptions, id)
|
||||
logger.Info("Unsubscribed: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMatching returns all handlers that match the event type
|
||||
func (sm *subscriptionManager) GetMatching(eventType string) []EventHandler {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var handlers []EventHandler
|
||||
for _, sub := range sm.subscriptions {
|
||||
if matchPattern(sub.pattern, eventType) {
|
||||
handlers = append(handlers, sub.handler)
|
||||
}
|
||||
}
|
||||
|
||||
return handlers
|
||||
}
|
||||
|
||||
// Count returns the number of active subscriptions
|
||||
func (sm *subscriptionManager) Count() int {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return len(sm.subscriptions)
|
||||
}
|
||||
|
||||
// Clear removes all subscriptions
|
||||
func (sm *subscriptionManager) Clear() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.subscriptions = make(map[SubscriptionID]*subscription)
|
||||
logger.Info("Cleared all subscriptions")
|
||||
}
|
||||
|
||||
// matchPattern implements glob-style pattern matching for event types
|
||||
// Patterns:
|
||||
// - "*" matches any single segment
|
||||
// - "a.b.c" matches exactly "a.b.c"
|
||||
// - "a.*.c" matches "a.anything.c"
|
||||
// - "a.b.*" matches any operation on a.b
|
||||
// - "*" matches everything
|
||||
//
|
||||
// Event type format: schema.entity.operation (e.g., "public.users.create")
|
||||
func matchPattern(pattern, eventType string) bool {
|
||||
// Wildcard matches everything
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Exact match
|
||||
if pattern == eventType {
|
||||
return true
|
||||
}
|
||||
|
||||
// Split pattern and event type by dots
|
||||
patternParts := strings.Split(pattern, ".")
|
||||
eventParts := strings.Split(eventType, ".")
|
||||
|
||||
// Different number of parts can only match if pattern has wildcards
|
||||
if len(patternParts) != len(eventParts) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Match each part
|
||||
for i := range patternParts {
|
||||
if patternParts[i] != "*" && patternParts[i] != eventParts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
270
pkg/eventbroker/subscription_test.go
Normal file
270
pkg/eventbroker/subscription_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMatchPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
pattern string
|
||||
eventType string
|
||||
expected bool
|
||||
}{
|
||||
// Exact matches
|
||||
{"public.users.create", "public.users.create", true},
|
||||
{"public.users.create", "public.users.update", false},
|
||||
|
||||
// Wildcard matches
|
||||
{"*", "public.users.create", true},
|
||||
{"*", "anything", true},
|
||||
{"public.*", "public.users", true},
|
||||
{"public.*", "public.users.create", false}, // Different number of parts
|
||||
{"public.*", "admin.users", false},
|
||||
{"*.users.create", "public.users.create", true},
|
||||
{"*.users.create", "admin.users.create", true},
|
||||
{"*.users.create", "public.roles.create", false},
|
||||
{"public.*.create", "public.users.create", true},
|
||||
{"public.*.create", "public.roles.create", true},
|
||||
{"public.*.create", "public.users.update", false},
|
||||
|
||||
// Multiple wildcards
|
||||
{"*.*", "public.users", true},
|
||||
{"*.*", "public.users.create", false}, // Different number of parts
|
||||
{"*.*.create", "public.users.create", true},
|
||||
{"*.*.create", "admin.roles.create", true},
|
||||
{"*.*.create", "public.users.update", false},
|
||||
|
||||
// Edge cases
|
||||
{"", "", true},
|
||||
{"", "something", false},
|
||||
{"something", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern+"_vs_"+tt.eventType, func(t *testing.T) {
|
||||
result := matchPattern(tt.pattern, tt.eventType)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchPattern(%q, %q) = %v, expected %v",
|
||||
tt.pattern, tt.eventType, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManager(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
// Create test handler
|
||||
called := false
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// Test Subscribe
|
||||
id, err := manager.Subscribe("public.users.*", handler)
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Fatal("Expected non-empty subscription ID")
|
||||
}
|
||||
|
||||
// Test GetMatching
|
||||
handlers := manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 1 {
|
||||
t.Fatalf("Expected 1 handler, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Test handler execution
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
if err := handlers[0].Handle(context.Background(), event); err != nil {
|
||||
t.Fatalf("Handler execution failed: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
|
||||
// Test Count
|
||||
if manager.Count() != 1 {
|
||||
t.Errorf("Expected count 1, got %d", manager.Count())
|
||||
}
|
||||
|
||||
// Test Unsubscribe
|
||||
if err := manager.Unsubscribe(id); err != nil {
|
||||
t.Fatalf("Unsubscribe failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify unsubscribed
|
||||
handlers = manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 0 {
|
||||
t.Errorf("Expected 0 handlers after unsubscribe, got %d", len(handlers))
|
||||
}
|
||||
if manager.Count() != 0 {
|
||||
t.Errorf("Expected count 0 after unsubscribe, got %d", manager.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerMultipleHandlers(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
called1 := false
|
||||
handler1 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called1 = true
|
||||
return nil
|
||||
})
|
||||
|
||||
called2 := false
|
||||
handler2 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called2 = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// Subscribe multiple handlers
|
||||
id1, _ := manager.Subscribe("public.users.*", handler1)
|
||||
id2, _ := manager.Subscribe("*.users.*", handler2)
|
||||
|
||||
// Both should match
|
||||
handlers := manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 2 {
|
||||
t.Fatalf("Expected 2 handlers, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Execute all handlers
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
for _, h := range handlers {
|
||||
h.Handle(context.Background(), event)
|
||||
}
|
||||
|
||||
if !called1 || !called2 {
|
||||
t.Error("Expected both handlers to be called")
|
||||
}
|
||||
|
||||
// Unsubscribe one
|
||||
manager.Unsubscribe(id1)
|
||||
handlers = manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 1 {
|
||||
t.Errorf("Expected 1 handler after unsubscribe, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Unsubscribe remaining
|
||||
manager.Unsubscribe(id2)
|
||||
if manager.Count() != 0 {
|
||||
t.Errorf("Expected count 0 after all unsubscribe, got %d", manager.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerConcurrency(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Subscribe and unsubscribe concurrently
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
id, _ := manager.Subscribe("test.*", handler)
|
||||
manager.GetMatching("test.event")
|
||||
manager.Unsubscribe(id)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Should have no subscriptions left
|
||||
if manager.Count() != 0 {
|
||||
t.Errorf("Expected count 0 after concurrent operations, got %d", manager.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerUnsubscribeNonExistent(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
// Try to unsubscribe a non-existent ID
|
||||
err := manager.Unsubscribe("non-existent-id")
|
||||
if err == nil {
|
||||
t.Error("Expected error when unsubscribing non-existent ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionIDGeneration(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Subscribe multiple times and ensure unique IDs
|
||||
ids := make(map[SubscriptionID]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id, _ := manager.Subscribe("test.*", handler)
|
||||
if ids[id] {
|
||||
t.Fatalf("Duplicate subscription ID: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventHandlerFunc(t *testing.T) {
|
||||
called := false
|
||||
var receivedEvent *Event
|
||||
|
||||
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
called = true
|
||||
receivedEvent = event
|
||||
return nil
|
||||
})
|
||||
|
||||
event := NewEvent(EventSourceDatabase, "test.event")
|
||||
err := handler.Handle(context.Background(), event)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
if receivedEvent != event {
|
||||
t.Error("Expected to receive the same event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionManagerPatternPriority(t *testing.T) {
|
||||
manager := newSubscriptionManager()
|
||||
|
||||
// More specific patterns should still match
|
||||
specificCalled := false
|
||||
genericCalled := false
|
||||
|
||||
manager.Subscribe("public.users.create", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
specificCalled = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
manager.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||
genericCalled = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
handlers := manager.GetMatching("public.users.create")
|
||||
if len(handlers) != 2 {
|
||||
t.Fatalf("Expected 2 matching handlers, got %d", len(handlers))
|
||||
}
|
||||
|
||||
// Execute all handlers
|
||||
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||
for _, h := range handlers {
|
||||
h.Handle(context.Background(), event)
|
||||
}
|
||||
|
||||
if !specificCalled || !genericCalled {
|
||||
t.Error("Expected both specific and generic handlers to be called")
|
||||
}
|
||||
}
|
||||
141
pkg/eventbroker/worker_pool.go
Normal file
141
pkg/eventbroker/worker_pool.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package eventbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// workerPool manages a pool of workers for async event processing
|
||||
type workerPool struct {
|
||||
workerCount int
|
||||
bufferSize int
|
||||
eventQueue chan *Event
|
||||
processor func(context.Context, *Event) error
|
||||
|
||||
activeWorkers atomic.Int32
|
||||
isRunning atomic.Bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// newWorkerPool creates a new worker pool
|
||||
func newWorkerPool(workerCount, bufferSize int, processor func(context.Context, *Event) error) *workerPool {
|
||||
return &workerPool{
|
||||
workerCount: workerCount,
|
||||
bufferSize: bufferSize,
|
||||
eventQueue: make(chan *Event, bufferSize),
|
||||
processor: processor,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the worker pool
|
||||
func (wp *workerPool) Start() {
|
||||
if wp.isRunning.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
wp.isRunning.Store(true)
|
||||
|
||||
// Start workers
|
||||
for i := 0; i < wp.workerCount; i++ {
|
||||
wp.wg.Add(1)
|
||||
go wp.worker(i)
|
||||
}
|
||||
|
||||
logger.Info("Worker pool started with %d workers", wp.workerCount)
|
||||
}
|
||||
|
||||
// Stop stops the worker pool gracefully
|
||||
func (wp *workerPool) Stop(ctx context.Context) error {
|
||||
if !wp.isRunning.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
wp.isRunning.Store(false)
|
||||
|
||||
// Close event queue to signal workers
|
||||
close(wp.eventQueue)
|
||||
|
||||
// Wait for workers to finish with context timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wp.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
logger.Info("Worker pool stopped gracefully")
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
logger.Warn("Worker pool stop timed out, some events may be lost")
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Submit submits an event to the queue
|
||||
func (wp *workerPool) Submit(ctx context.Context, event *Event) error {
|
||||
if !wp.isRunning.Load() {
|
||||
return ErrWorkerPoolStopped
|
||||
}
|
||||
|
||||
select {
|
||||
case wp.eventQueue <- event:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return ErrQueueFull
|
||||
}
|
||||
}
|
||||
|
||||
// worker is a worker goroutine that processes events from the queue
|
||||
func (wp *workerPool) worker(id int) {
|
||||
defer wp.wg.Done()
|
||||
|
||||
logger.Debug("Worker %d started", id)
|
||||
|
||||
for event := range wp.eventQueue {
|
||||
wp.activeWorkers.Add(1)
|
||||
|
||||
// Process event with background context (detached from original request)
|
||||
ctx := context.Background()
|
||||
if err := wp.processor(ctx, event); err != nil {
|
||||
logger.Error("Worker %d failed to process event %s: %v", id, event.ID, err)
|
||||
}
|
||||
|
||||
wp.activeWorkers.Add(-1)
|
||||
}
|
||||
|
||||
logger.Debug("Worker %d stopped", id)
|
||||
}
|
||||
|
||||
// QueueSize returns the current queue size
|
||||
func (wp *workerPool) QueueSize() int {
|
||||
return len(wp.eventQueue)
|
||||
}
|
||||
|
||||
// ActiveWorkers returns the number of currently active workers
|
||||
func (wp *workerPool) ActiveWorkers() int {
|
||||
return int(wp.activeWorkers.Load())
|
||||
}
|
||||
|
||||
// Error definitions
|
||||
var (
|
||||
ErrWorkerPoolStopped = &BrokerError{Code: "worker_pool_stopped", Message: "worker pool is stopped"}
|
||||
ErrQueueFull = &BrokerError{Code: "queue_full", Message: "event queue is full"}
|
||||
)
|
||||
|
||||
// BrokerError represents an error from the event broker
|
||||
type BrokerError struct {
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *BrokerError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
@@ -70,6 +70,10 @@ func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
return fn(m)
|
||||
}
|
||||
|
||||
func (m *MockDatabase) GetUnderlyingDB() interface{} {
|
||||
return m
|
||||
}
|
||||
|
||||
// MockResult implements common.Result interface for testing
|
||||
type MockResult struct {
|
||||
rows int64
|
||||
|
||||
@@ -75,6 +75,28 @@ func CloseErrorTracking() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractContext attempts to find a context.Context in the given arguments.
|
||||
// It returns the found context (or context.Background() if not found) and
|
||||
// the remaining arguments without the context.
|
||||
func extractContext(args ...interface{}) (ctx context.Context, filteredArgs []interface{}) {
|
||||
ctx = context.Background()
|
||||
var newArgs []interface{}
|
||||
found := false
|
||||
|
||||
for _, arg := range args {
|
||||
if c, ok := arg.(context.Context); ok {
|
||||
if !found {
|
||||
ctx = c
|
||||
found = true
|
||||
}
|
||||
// Ignore any additional context.Context arguments after the first one.
|
||||
continue
|
||||
}
|
||||
newArgs = append(newArgs, arg)
|
||||
}
|
||||
return ctx, newArgs
|
||||
}
|
||||
|
||||
func Info(template string, args ...interface{}) {
|
||||
if Logger == nil {
|
||||
log.Printf(template, args...)
|
||||
@@ -84,7 +106,8 @@ func Info(template string, args ...interface{}) {
|
||||
}
|
||||
|
||||
func Warn(template string, args ...interface{}) {
|
||||
message := fmt.Sprintf(template, args...)
|
||||
ctx, remainingArgs := extractContext(args...)
|
||||
message := fmt.Sprintf(template, remainingArgs...)
|
||||
if Logger == nil {
|
||||
log.Printf("%s", message)
|
||||
} else {
|
||||
@@ -93,14 +116,15 @@ func Warn(template string, args ...interface{}) {
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
|
||||
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Error(template string, args ...interface{}) {
|
||||
message := fmt.Sprintf(template, args...)
|
||||
ctx, remainingArgs := extractContext(args...)
|
||||
message := fmt.Sprintf(template, remainingArgs...)
|
||||
if Logger == nil {
|
||||
log.Printf("%s", message)
|
||||
} else {
|
||||
@@ -109,7 +133,7 @@ func Error(template string, args ...interface{}) {
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
|
||||
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
@@ -124,34 +148,41 @@ func Debug(template string, args ...interface{}) {
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanicCallback(location string, cb func(err any)) {
|
||||
if err := recover(); err != nil {
|
||||
callstack := debug.Stack()
|
||||
// Returns a function that should be deferred to catch panics
|
||||
// Example usage: defer CatchPanicCallback("MyFunction", func(err any) { /* cleanup */ })()
|
||||
func CatchPanicCallback(location string, cb func(err any), args ...interface{}) func() {
|
||||
ctx, _ := extractContext(args...)
|
||||
return func() {
|
||||
if err := recover(); err != nil {
|
||||
callstack := debug.Stack()
|
||||
|
||||
if Logger != nil {
|
||||
Error("Panic in %s : %v", location, err)
|
||||
} else {
|
||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||
debug.PrintStack()
|
||||
}
|
||||
if Logger != nil {
|
||||
Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly
|
||||
} else {
|
||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||
debug.PrintStack()
|
||||
}
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
|
||||
"location": location,
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{
|
||||
"location": location,
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
}
|
||||
|
||||
if cb != nil {
|
||||
cb(err)
|
||||
if cb != nil {
|
||||
cb(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanic(location string) {
|
||||
CatchPanicCallback(location, nil)
|
||||
// Returns a function that should be deferred to catch panics
|
||||
// Example usage: defer CatchPanic("MyFunction")()
|
||||
func CatchPanic(location string, args ...interface{}) func() {
|
||||
return CatchPanicCallback(location, nil, args...)
|
||||
}
|
||||
|
||||
// HandlePanic logs a panic and returns it as an error
|
||||
@@ -163,13 +194,14 @@ func CatchPanic(location string) {
|
||||
// err = logger.HandlePanic("MethodName", r)
|
||||
// }
|
||||
// }()
|
||||
func HandlePanic(methodName string, r any) error {
|
||||
func HandlePanic(methodName string, r any, args ...interface{}) error {
|
||||
ctx, _ := extractContext(args...)
|
||||
stack := debug.Stack()
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly
|
||||
|
||||
// Send to error tracker
|
||||
if errorTracker != nil {
|
||||
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
|
||||
errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{
|
||||
"method": methodName,
|
||||
"process_id": os.Getpid(),
|
||||
})
|
||||
|
||||
@@ -30,6 +30,18 @@ type Provider interface {
|
||||
// UpdateCacheSize updates the cache size metric
|
||||
UpdateCacheSize(provider string, size int64)
|
||||
|
||||
// RecordEventPublished records an event publication
|
||||
RecordEventPublished(source, eventType string)
|
||||
|
||||
// RecordEventProcessed records an event processing with its status
|
||||
RecordEventProcessed(source, eventType, status string, duration time.Duration)
|
||||
|
||||
// UpdateEventQueueSize updates the event queue size metric
|
||||
UpdateEventQueueSize(size int64)
|
||||
|
||||
// RecordPanic records a panic event
|
||||
RecordPanic(methodName string)
|
||||
|
||||
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
||||
Handler() http.Handler
|
||||
}
|
||||
@@ -59,9 +71,14 @@ func (n *NoOpProvider) IncRequestsInFlight()
|
||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
}
|
||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||
func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
|
||||
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||
}
|
||||
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
|
||||
func (n *NoOpProvider) RecordPanic(methodName string) {}
|
||||
func (n *NoOpProvider) Handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
||||
@@ -20,6 +20,7 @@ type PrometheusProvider struct {
|
||||
cacheHits *prometheus.CounterVec
|
||||
cacheMisses *prometheus.CounterVec
|
||||
cacheSize *prometheus.GaugeVec
|
||||
panicsTotal *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||
@@ -83,6 +84,13 @@ func NewPrometheusProvider() *PrometheusProvider {
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
panicsTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "panics_total",
|
||||
Help: "Total number of panics",
|
||||
},
|
||||
[]string{"method"},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,6 +153,11 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
|
||||
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
||||
}
|
||||
|
||||
// RecordPanic implements the Provider interface
|
||||
func (p *PrometheusProvider) RecordPanic(methodName string) {
|
||||
p.panicsTotal.WithLabelValues(methodName).Inc()
|
||||
}
|
||||
|
||||
// Handler implements Provider interface
|
||||
func (p *PrometheusProvider) Handler() http.Handler {
|
||||
return promhttp.Handler()
|
||||
|
||||
33
pkg/middleware/panic.go
Normal file
33
pkg/middleware/panic.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
const panicMiddlewareMethodName = "PanicMiddleware"
|
||||
|
||||
// PanicRecovery is a middleware that recovers from panics, logs the error,
|
||||
// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error.
|
||||
func PanicRecovery(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rcv := recover(); rcv != nil {
|
||||
// Record the panic metric
|
||||
metrics.GetProvider().RecordPanic(panicMiddlewareMethodName)
|
||||
|
||||
// Log the panic and send to error tracker
|
||||
// We pass the request context so the error tracker can potentially
|
||||
// link the panic to the request trace.
|
||||
ctx := r.Context()
|
||||
err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx)
|
||||
|
||||
// Respond with a 500 error
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
86
pkg/middleware/panic_test.go
Normal file
86
pkg/middleware/panic_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockMetricsProvider is a mock for the metrics provider to check if methods are called.
|
||||
type mockMetricsProvider struct {
|
||||
metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods
|
||||
panicRecorded bool
|
||||
methodName string
|
||||
}
|
||||
|
||||
func (m *mockMetricsProvider) RecordPanic(methodName string) {
|
||||
m.panicRecorded = true
|
||||
m.methodName = methodName
|
||||
}
|
||||
|
||||
func TestPanicRecovery(t *testing.T) {
|
||||
// Initialize a mock logger to avoid actual logging output during tests
|
||||
logger.Init(true)
|
||||
|
||||
// Setup mock metrics provider
|
||||
mockProvider := &mockMetricsProvider{}
|
||||
originalProvider := metrics.GetProvider()
|
||||
metrics.SetProvider(mockProvider)
|
||||
defer metrics.SetProvider(originalProvider) // Restore original provider after test
|
||||
|
||||
// 1. Test case: A handler that panics
|
||||
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||
// Reset mock state for this sub-test
|
||||
mockProvider.panicRecorded = false
|
||||
mockProvider.methodName = ""
|
||||
|
||||
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("something went terribly wrong")
|
||||
})
|
||||
|
||||
// Create the middleware wrapping the panicking handler
|
||||
testHandler := PanicRecovery(panicHandler)
|
||||
|
||||
// Create a test request and response recorder
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Serve the request
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Assertions
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500")
|
||||
assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body")
|
||||
|
||||
// Assert that the metric was recorded
|
||||
assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider")
|
||||
assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name")
|
||||
})
|
||||
|
||||
// 2. Test case: A handler that does NOT panic
|
||||
t.Run("does not interfere with a non-panicking handler", func(t *testing.T) {
|
||||
// Reset mock state for this sub-test
|
||||
mockProvider.panicRecorded = false
|
||||
|
||||
successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
testHandler := PanicRecovery(successHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Assertions
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200")
|
||||
assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body")
|
||||
assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic")
|
||||
})
|
||||
}
|
||||
@@ -6,15 +6,37 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ModelRules defines the permissions and security settings for a model
|
||||
type ModelRules struct {
|
||||
CanRead bool // Whether the model can be read (GET operations)
|
||||
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanCreate bool // Whether the model can be created (POST operations)
|
||||
CanDelete bool // Whether the model can be deleted (DELETE operations)
|
||||
SecurityDisabled bool // Whether security checks are disabled for this model
|
||||
}
|
||||
|
||||
// DefaultModelRules returns the default rules for a model (all operations allowed, security enabled)
|
||||
func DefaultModelRules() ModelRules {
|
||||
return ModelRules{
|
||||
CanRead: true,
|
||||
CanUpdate: true,
|
||||
CanCreate: true,
|
||||
CanDelete: true,
|
||||
SecurityDisabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultModelRegistry implements ModelRegistry interface
|
||||
type DefaultModelRegistry struct {
|
||||
models map[string]interface{}
|
||||
rules map[string]ModelRules
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Global default registry instance
|
||||
var defaultRegistry = &DefaultModelRegistry{
|
||||
models: make(map[string]interface{}),
|
||||
rules: make(map[string]ModelRules),
|
||||
}
|
||||
|
||||
// Global list of registries (searched in order)
|
||||
@@ -25,6 +47,7 @@ var registriesMutex sync.RWMutex
|
||||
func NewModelRegistry() *DefaultModelRegistry {
|
||||
return &DefaultModelRegistry{
|
||||
models: make(map[string]interface{}),
|
||||
rules: make(map[string]ModelRules),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +121,10 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
||||
}
|
||||
|
||||
r.models[name] = model
|
||||
// Initialize with default rules if not already set
|
||||
if _, exists := r.rules[name]; !exists {
|
||||
r.rules[name] = DefaultModelRules()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -135,6 +162,54 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac
|
||||
return r.GetModel(entity)
|
||||
}
|
||||
|
||||
// SetModelRules sets the rules for a specific model
|
||||
func (r *DefaultModelRegistry) SetModelRules(name string, rules ModelRules) error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
// Check if model exists
|
||||
if _, exists := r.models[name]; !exists {
|
||||
return fmt.Errorf("model %s not found", name)
|
||||
}
|
||||
|
||||
r.rules[name] = rules
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelRules retrieves the rules for a specific model
|
||||
// Returns default rules if model exists but rules are not set
|
||||
func (r *DefaultModelRegistry) GetModelRules(name string) (ModelRules, error) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// Check if model exists
|
||||
if _, exists := r.models[name]; !exists {
|
||||
return ModelRules{}, fmt.Errorf("model %s not found", name)
|
||||
}
|
||||
|
||||
// Return rules if set, otherwise return default rules
|
||||
if rules, exists := r.rules[name]; exists {
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
return DefaultModelRules(), nil
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model with specific rules
|
||||
func (r *DefaultModelRegistry) RegisterModelWithRules(name string, model interface{}, rules ModelRules) error {
|
||||
// First register the model
|
||||
if err := r.RegisterModel(name, model); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then set the rules (we need to lock again for rules)
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
r.rules[name] = rules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Global convenience functions using the default registry
|
||||
|
||||
// RegisterModel registers a model with the default global registry
|
||||
@@ -190,3 +265,34 @@ func GetModels() []interface{} {
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// SetModelRules sets the rules for a specific model in the default registry
|
||||
func SetModelRules(name string, rules ModelRules) error {
|
||||
return defaultRegistry.SetModelRules(name, rules)
|
||||
}
|
||||
|
||||
// GetModelRules retrieves the rules for a specific model from the default registry
|
||||
func GetModelRules(name string) (ModelRules, error) {
|
||||
return defaultRegistry.GetModelRules(name)
|
||||
}
|
||||
|
||||
// GetModelRulesByName retrieves the rules for a model by searching through all registries in order
|
||||
// Returns the first match found
|
||||
func GetModelRulesByName(name string) (ModelRules, error) {
|
||||
registriesMutex.RLock()
|
||||
defer registriesMutex.RUnlock()
|
||||
|
||||
for _, registry := range registries {
|
||||
if _, err := registry.GetModel(name); err == nil {
|
||||
// Model found in this registry, get its rules
|
||||
return registry.GetModelRules(name)
|
||||
}
|
||||
}
|
||||
|
||||
return ModelRules{}, fmt.Errorf("model %s not found in any registry", name)
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model with specific rules in the default registry
|
||||
func RegisterModelWithRules(model interface{}, name string, rules ModelRules) error {
|
||||
return defaultRegistry.RegisterModelWithRules(name, model, rules)
|
||||
}
|
||||
|
||||
724
pkg/mqttspec/README.md
Normal file
724
pkg/mqttspec/README.md
Normal file
@@ -0,0 +1,724 @@
|
||||
# MQTTSpec - MQTT-based Database Query Framework
|
||||
|
||||
MQTTSpec is an MQTT-based database query framework that enables real-time database operations and subscriptions via MQTT protocol. It mirrors the functionality of WebSocketSpec but uses MQTT as the transport layer, making it ideal for IoT applications, mobile apps with unreliable networks, and distributed systems requiring QoS guarantees.
|
||||
|
||||
## Features
|
||||
|
||||
- **Dual Broker Support**: Embedded broker (Mochi MQTT) or external broker connection (Paho MQTT)
|
||||
- **QoS 1 (At-least-once delivery)**: Reliable message delivery for all operations
|
||||
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
||||
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
||||
- **Database Agnostic**: GORM and Bun ORM support
|
||||
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
|
||||
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
||||
- **Thread-safe**: Proper concurrency handling throughout
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/bitechdev/ResolveSpec/pkg/mqttspec
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Embedded Broker (Default)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/mqttspec"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Connect to database
|
||||
db, _ := gorm.Open(postgres.Open("postgres://..."), &gorm.Config{})
|
||||
db.AutoMigrate(&User{})
|
||||
|
||||
// Create MQTT handler with embedded broker
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Register models
|
||||
handler.Registry().RegisterModel("public.users", &User{})
|
||||
|
||||
// Start handler (starts embedded broker on localhost:1883)
|
||||
if err := handler.Start(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Handler is now listening for MQTT messages
|
||||
select {} // Keep running
|
||||
}
|
||||
```
|
||||
|
||||
### External Broker
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithExternalBroker(mqttspec.ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://mqtt.example.com:1883",
|
||||
ClientID: "mqttspec-server",
|
||||
Username: "admin",
|
||||
Password: "secret",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Port (Embedded Broker)
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithEmbeddedBroker(mqttspec.BrokerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 1884,
|
||||
}),
|
||||
)
|
||||
```
|
||||
|
||||
## Topic Structure
|
||||
|
||||
MQTTSpec uses a client-based topic hierarchy:
|
||||
|
||||
```
|
||||
spec/{client_id}/request # Client publishes requests
|
||||
spec/{client_id}/response # Server publishes responses
|
||||
spec/{client_id}/notify/{sub_id} # Server publishes notifications
|
||||
```
|
||||
|
||||
### Wildcard Subscriptions
|
||||
|
||||
- **Server**: `spec/+/request` (receives all client requests)
|
||||
- **Client**: `spec/{client_id}/response` + `spec/{client_id}/notify/+`
|
||||
|
||||
## Message Protocol
|
||||
|
||||
MQTTSpec uses the same JSON message structure as WebSocketSpec and ResolveSpec for consistency.
|
||||
|
||||
### Request Message
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-123",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": {
|
||||
"filters": [
|
||||
{"column": "status", "operator": "eq", "value": "active"}
|
||||
],
|
||||
"sort": [{"column": "created_at", "direction": "desc"}],
|
||||
"limit": 10
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response Message
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-123",
|
||||
"type": "response",
|
||||
"success": true,
|
||||
"data": [
|
||||
{"id": 1, "name": "John Doe", "email": "john@example.com", "status": "active"},
|
||||
{"id": 2, "name": "Jane Smith", "email": "jane@example.com", "status": "active"}
|
||||
],
|
||||
"metadata": {
|
||||
"total": 50,
|
||||
"count": 2
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Notification Message
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "notification",
|
||||
"operation": "create",
|
||||
"subscription_id": "sub-xyz",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"id": 3,
|
||||
"name": "New User",
|
||||
"email": "new@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## CRUD Operations
|
||||
|
||||
### Read (Single Record)
|
||||
|
||||
**MQTT Client Publishes to**: `spec/{client_id}/request`
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-1",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {"id": 1}
|
||||
}
|
||||
```
|
||||
|
||||
**Server Publishes Response to**: `spec/{client_id}/response`
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-1",
|
||||
"success": true,
|
||||
"data": {"id": 1, "name": "John Doe", "email": "john@example.com"}
|
||||
}
|
||||
```
|
||||
|
||||
### Read (Multiple Records with Filtering)
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-2",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": {
|
||||
"filters": [
|
||||
{"column": "status", "operator": "eq", "value": "active"}
|
||||
],
|
||||
"sort": [{"column": "name", "direction": "asc"}],
|
||||
"limit": 20,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Create
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-3",
|
||||
"type": "request",
|
||||
"operation": "create",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"name": "Alice Brown",
|
||||
"email": "alice@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Update
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-4",
|
||||
"type": "request",
|
||||
"operation": "update",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"id": 1,
|
||||
"status": "inactive"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Delete
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-5",
|
||||
"type": "request",
|
||||
"operation": "delete",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {"id": 1}
|
||||
}
|
||||
```
|
||||
|
||||
## Real-time Subscriptions
|
||||
|
||||
### Subscribe to Entity Changes
|
||||
|
||||
**Client Publishes to**: `spec/{client_id}/request`
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-6",
|
||||
"type": "subscription",
|
||||
"operation": "subscribe",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": {
|
||||
"filters": [
|
||||
{"column": "status", "operator": "eq", "value": "active"}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Server Response** (published to `spec/{client_id}/response`):
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-6",
|
||||
"success": true,
|
||||
"data": {
|
||||
"subscription_id": "sub-abc123",
|
||||
"notify_topic": "spec/{client_id}/notify/sub-abc123"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Client Then Subscribes** to MQTT topic: `spec/{client_id}/notify/sub-abc123`
|
||||
|
||||
### Receiving Notifications
|
||||
|
||||
When any client creates/updates/deletes a user matching the subscription filters, the subscriber receives:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "notification",
|
||||
"operation": "create",
|
||||
"subscription_id": "sub-abc123",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"data": {
|
||||
"id": 10,
|
||||
"name": "New User",
|
||||
"email": "newuser@example.com",
|
||||
"status": "active"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Unsubscribe
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg-7",
|
||||
"type": "subscription",
|
||||
"operation": "unsubscribe",
|
||||
"data": {
|
||||
"subscription_id": "sub-abc123"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
|
||||
### Hook Types
|
||||
|
||||
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
||||
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
||||
- `BeforeRead` / `AfterRead` - Read operations
|
||||
- `BeforeCreate` / `AfterCreate` - Create operations
|
||||
- `BeforeUpdate` / `AfterUpdate` - Update operations
|
||||
- `BeforeDelete` / `AfterDelete` - Delete operations
|
||||
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
||||
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
||||
|
||||
### Authentication Example (JWT)
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(mqttspec.BeforeConnect, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
|
||||
// MQTT username contains JWT token
|
||||
token := client.Username
|
||||
claims, err := jwt.Validate(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
|
||||
// Store user info in client metadata for later use
|
||||
client.SetMetadata("user_id", claims.UserID)
|
||||
client.SetMetadata("tenant_id", claims.TenantID)
|
||||
client.SetMetadata("roles", claims.Roles)
|
||||
|
||||
logger.Info("Client authenticated: user_id=%d, tenant=%s", claims.UserID, claims.TenantID)
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Multi-tenancy Example
|
||||
|
||||
```go
|
||||
// Auto-inject tenant filter for all read operations
|
||||
handler.Hooks().Register(mqttspec.BeforeRead, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
tenantID, _ := client.GetMetadata("tenant_id")
|
||||
|
||||
// Add tenant filter to ensure users only see their own data
|
||||
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
|
||||
Column: "tenant_id",
|
||||
Operator: "eq",
|
||||
Value: tenantID,
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Auto-set tenant_id for all create operations
|
||||
handler.Hooks().Register(mqttspec.BeforeCreate, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
tenantID, _ := client.GetMetadata("tenant_id")
|
||||
|
||||
// Inject tenant_id into new records
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["tenant_id"] = tenantID
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Role-based Access Control (RBAC)
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(mqttspec.BeforeDelete, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
roles, _ := client.GetMetadata("roles")
|
||||
|
||||
roleList := roles.([]string)
|
||||
hasAdminRole := false
|
||||
for _, role := range roleList {
|
||||
if role == "admin" {
|
||||
hasAdminRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAdminRole {
|
||||
return fmt.Errorf("permission denied: delete requires admin role")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Audit Logging Example
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(mqttspec.AfterCreate, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
userID, _ := client.GetMetadata("user_id")
|
||||
|
||||
logger.Info("Audit: user %d created %s.%s record: %+v",
|
||||
userID, ctx.Schema, ctx.Entity, ctx.Result)
|
||||
|
||||
// Could also write to audit log table
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
## Client Examples
|
||||
|
||||
### JavaScript (MQTT.js)
|
||||
|
||||
```javascript
|
||||
const mqtt = require('mqtt');
|
||||
|
||||
// Connect to MQTT broker
|
||||
const client = mqtt.connect('mqtt://localhost:1883', {
|
||||
clientId: 'client-abc123',
|
||||
username: 'your-jwt-token',
|
||||
password: '', // JWT in username, password can be empty
|
||||
});
|
||||
|
||||
client.on('connect', () => {
|
||||
console.log('Connected to MQTT broker');
|
||||
|
||||
// Subscribe to responses
|
||||
client.subscribe('spec/client-abc123/response');
|
||||
|
||||
// Read users
|
||||
const readMsg = {
|
||||
id: 'msg-1',
|
||||
type: 'request',
|
||||
operation: 'read',
|
||||
schema: 'public',
|
||||
entity: 'users',
|
||||
options: {
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
]
|
||||
}
|
||||
};
|
||||
|
||||
client.publish('spec/client-abc123/request', JSON.stringify(readMsg));
|
||||
});
|
||||
|
||||
client.on('message', (topic, payload) => {
|
||||
const message = JSON.parse(payload.toString());
|
||||
console.log('Received:', message);
|
||||
|
||||
if (message.type === 'response') {
|
||||
console.log('Response data:', message.data);
|
||||
} else if (message.type === 'notification') {
|
||||
console.log('Notification:', message.operation, message.data);
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
### Python (paho-mqtt)
|
||||
|
||||
```python
|
||||
import paho.mqtt.client as mqtt
|
||||
import json
|
||||
|
||||
client_id = 'client-python-123'
|
||||
|
||||
def on_connect(client, userdata, flags, rc):
|
||||
print(f"Connected with result code {rc}")
|
||||
|
||||
# Subscribe to responses
|
||||
client.subscribe(f"spec/{client_id}/response")
|
||||
|
||||
# Create a user
|
||||
create_msg = {
|
||||
'id': 'msg-create-1',
|
||||
'type': 'request',
|
||||
'operation': 'create',
|
||||
'schema': 'public',
|
||||
'entity': 'users',
|
||||
'data': {
|
||||
'name': 'Python User',
|
||||
'email': 'python@example.com',
|
||||
'status': 'active'
|
||||
}
|
||||
}
|
||||
|
||||
client.publish(f"spec/{client_id}/request", json.dumps(create_msg))
|
||||
|
||||
def on_message(client, userdata, msg):
|
||||
message = json.loads(msg.payload.decode())
|
||||
print(f"Received on {msg.topic}: {message}")
|
||||
|
||||
client = mqtt.Client(client_id=client_id)
|
||||
client.username_pw_set('your-jwt-token', '')
|
||||
client.on_connect = on_connect
|
||||
client.on_message = on_message
|
||||
|
||||
client.connect('localhost', 1883, 60)
|
||||
client.loop_forever()
|
||||
```
|
||||
|
||||
### Go (paho.mqtt.golang)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
)
|
||||
|
||||
func main() {
|
||||
clientID := "client-go-123"
|
||||
|
||||
opts := mqtt.NewClientOptions()
|
||||
opts.AddBroker("tcp://localhost:1883")
|
||||
opts.SetClientID(clientID)
|
||||
opts.SetUsername("your-jwt-token")
|
||||
opts.SetPassword("")
|
||||
|
||||
opts.SetDefaultPublishHandler(func(client mqtt.Client, msg mqtt.Message) {
|
||||
var message map[string]interface{}
|
||||
json.Unmarshal(msg.Payload(), &message)
|
||||
fmt.Printf("Received on %s: %+v\n", msg.Topic(), message)
|
||||
})
|
||||
|
||||
opts.OnConnect = func(client mqtt.Client) {
|
||||
fmt.Println("Connected to MQTT broker")
|
||||
|
||||
// Subscribe to responses
|
||||
client.Subscribe(fmt.Sprintf("spec/%s/response", clientID), 1, nil)
|
||||
|
||||
// Read users
|
||||
readMsg := map[string]interface{}{
|
||||
"id": "msg-1",
|
||||
"type": "request",
|
||||
"operation": "read",
|
||||
"schema": "public",
|
||||
"entity": "users",
|
||||
"options": map[string]interface{}{
|
||||
"filters": []map[string]interface{}{
|
||||
{"column": "status", "operator": "eq", "value": "active"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(readMsg)
|
||||
client.Publish(fmt.Sprintf("spec/%s/request", clientID), 1, false, payload)
|
||||
}
|
||||
|
||||
client := mqtt.NewClient(opts)
|
||||
if token := client.Connect(); token.Wait() && token.Error() != nil {
|
||||
panic(token.Error())
|
||||
}
|
||||
|
||||
// Keep running
|
||||
select {}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### BrokerConfig (Embedded Broker)
|
||||
|
||||
```go
|
||||
type BrokerConfig struct {
|
||||
Host string // Default: "localhost"
|
||||
Port int // Default: 1883
|
||||
EnableWebSocket bool // Enable WebSocket listener
|
||||
WSPort int // WebSocket port (default: 1884)
|
||||
MaxConnections int // Max concurrent connections
|
||||
KeepAlive time.Duration // MQTT keep-alive interval
|
||||
EnableAuth bool // Enable authentication
|
||||
}
|
||||
```
|
||||
|
||||
### ExternalBrokerConfig
|
||||
|
||||
```go
|
||||
type ExternalBrokerConfig struct {
|
||||
BrokerURL string // MQTT broker URL (tcp://host:port)
|
||||
ClientID string // MQTT client ID
|
||||
Username string // MQTT username
|
||||
Password string // MQTT password
|
||||
CleanSession bool // Clean session flag
|
||||
KeepAlive time.Duration // Keep-alive interval
|
||||
ConnectTimeout time.Duration // Connection timeout
|
||||
ReconnectDelay time.Duration // Auto-reconnect delay
|
||||
MaxReconnect int // Max reconnect attempts
|
||||
TLSConfig *tls.Config // TLS configuration
|
||||
}
|
||||
```
|
||||
|
||||
### QoS Configuration
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithQoS(1, 1, 1), // Request, Response, Notification
|
||||
)
|
||||
```
|
||||
|
||||
### Topic Prefix
|
||||
|
||||
```go
|
||||
handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
mqttspec.WithTopicPrefix("myapp"), // Changes topics to myapp/{client_id}/...
|
||||
)
|
||||
```
|
||||
|
||||
## Documentation References
|
||||
|
||||
- **ResolveSpec JSON Protocol**: See `/pkg/resolvespec/README.md` for the full message protocol specification
|
||||
- **WebSocketSpec Documentation**: See `/pkg/websocketspec/README.md` for similar WebSocket-based implementation
|
||||
- **Common Interfaces**: See `/pkg/common/types.go` for database adapter interfaces and query options
|
||||
- **Model Registry**: See `/pkg/modelregistry/README.md` for model registration and reflection
|
||||
- **Hooks Reference**: See `/pkg/websocketspec/hooks.go` for hook types (same as MQTTSpec)
|
||||
- **Subscription Management**: See `/pkg/websocketspec/subscription.go` for subscription filtering
|
||||
|
||||
## Comparison: MQTTSpec vs WebSocketSpec
|
||||
|
||||
| Feature | MQTTSpec | WebSocketSpec |
|
||||
|---------|----------|---------------|
|
||||
| **Transport** | MQTT (pub/sub broker) | WebSocket (direct connection) |
|
||||
| **Connection Model** | Broker-mediated | Direct client-server |
|
||||
| **QoS Levels** | QoS 0, 1, 2 support | No built-in QoS |
|
||||
| **Offline Messages** | Yes (with QoS 1+) | No |
|
||||
| **Auto-reconnect** | Yes (built into MQTT) | Manual implementation needed |
|
||||
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
||||
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
||||
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
||||
| **Hooks** | Same 12 hooks | Same 12 hooks |
|
||||
| **CRUD Operations** | Identical | Identical |
|
||||
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
||||
|
||||
## Use Cases
|
||||
|
||||
### IoT Sensor Data
|
||||
|
||||
```go
|
||||
// Sensors publish data, backend stores and notifies subscribers
|
||||
handler.Registry().RegisterModel("public.sensor_readings", &SensorReading{})
|
||||
|
||||
// Auto-set device_id from client metadata
|
||||
handler.Hooks().Register(mqttspec.BeforeCreate, func(ctx *mqttspec.HookContext) error {
|
||||
client := ctx.Metadata["mqtt_client"].(*mqttspec.Client)
|
||||
deviceID, _ := client.GetMetadata("device_id")
|
||||
|
||||
if ctx.Entity == "sensor_readings" {
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["device_id"] = deviceID
|
||||
dataMap["timestamp"] = time.Now()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Mobile App with Offline Support
|
||||
|
||||
MQTTSpec's QoS 1 ensures messages are delivered even if the client temporarily disconnects.
|
||||
|
||||
### Distributed Microservices
|
||||
|
||||
Multiple services can subscribe to entity changes and react accordingly.
|
||||
|
||||
## Testing
|
||||
|
||||
Run unit tests:
|
||||
|
||||
```bash
|
||||
go test -v ./pkg/mqttspec
|
||||
```
|
||||
|
||||
Run with race detection:
|
||||
|
||||
```bash
|
||||
go test -race -v ./pkg/mqttspec
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This package is part of the ResolveSpec project.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please ensure:
|
||||
|
||||
- All tests pass (`go test ./pkg/mqttspec`)
|
||||
- No race conditions (`go test -race ./pkg/mqttspec`)
|
||||
- Documentation is updated
|
||||
- Examples are provided for new features
|
||||
|
||||
## Support
|
||||
|
||||
For issues, questions, or feature requests, please open an issue in the ResolveSpec repository.
|
||||
417
pkg/mqttspec/broker.go
Normal file
417
pkg/mqttspec/broker.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
mqtt "github.com/mochi-mqtt/server/v2"
|
||||
"github.com/mochi-mqtt/server/v2/listeners"
|
||||
|
||||
pahomqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// BrokerInterface abstracts MQTT broker operations
|
||||
type BrokerInterface interface {
|
||||
// Start initializes the broker/client connection
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// Stop gracefully shuts down the broker/client
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// Publish sends a message to a topic
|
||||
Publish(topic string, qos byte, payload []byte) error
|
||||
|
||||
// Subscribe subscribes to a topic pattern with callback
|
||||
Subscribe(topicFilter string, qos byte, callback MessageCallback) error
|
||||
|
||||
// Unsubscribe removes subscription
|
||||
Unsubscribe(topicFilter string) error
|
||||
|
||||
// IsConnected returns connection status
|
||||
IsConnected() bool
|
||||
|
||||
// GetClientManager returns the client manager
|
||||
GetClientManager() *ClientManager
|
||||
|
||||
// SetHandler sets the handler reference (needed for hooks)
|
||||
SetHandler(handler *Handler)
|
||||
}
|
||||
|
||||
// MessageCallback is called when a message arrives
|
||||
type MessageCallback func(topic string, payload []byte)
|
||||
|
||||
// EmbeddedBroker wraps Mochi MQTT server
|
||||
type EmbeddedBroker struct {
|
||||
config BrokerConfig
|
||||
server *mqtt.Server
|
||||
clientManager *ClientManager
|
||||
handler *Handler
|
||||
subscriptions map[string]MessageCallback
|
||||
subMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
}
|
||||
|
||||
// NewEmbeddedBroker creates a new embedded broker
|
||||
func NewEmbeddedBroker(config BrokerConfig, clientManager *ClientManager) *EmbeddedBroker {
|
||||
return &EmbeddedBroker{
|
||||
config: config,
|
||||
clientManager: clientManager,
|
||||
subscriptions: make(map[string]MessageCallback),
|
||||
}
|
||||
}
|
||||
|
||||
// SetHandler sets the handler reference
|
||||
func (eb *EmbeddedBroker) SetHandler(handler *Handler) {
|
||||
eb.mu.Lock()
|
||||
defer eb.mu.Unlock()
|
||||
eb.handler = handler
|
||||
}
|
||||
|
||||
// Start starts the embedded MQTT broker
|
||||
func (eb *EmbeddedBroker) Start(ctx context.Context) error {
|
||||
eb.mu.Lock()
|
||||
defer eb.mu.Unlock()
|
||||
|
||||
if eb.started {
|
||||
return fmt.Errorf("broker already started")
|
||||
}
|
||||
|
||||
eb.ctx, eb.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Create Mochi MQTT server
|
||||
eb.server = mqtt.New(&mqtt.Options{
|
||||
InlineClient: true,
|
||||
})
|
||||
|
||||
// Note: Authentication is handled at the handler level via BeforeConnect hook
|
||||
// Mochi MQTT auth can be configured via custom hooks if needed
|
||||
|
||||
// Add TCP listener
|
||||
tcp := listeners.NewTCP(
|
||||
listeners.Config{
|
||||
ID: "tcp",
|
||||
Address: fmt.Sprintf("%s:%d", eb.config.Host, eb.config.Port),
|
||||
},
|
||||
)
|
||||
if err := eb.server.AddListener(tcp); err != nil {
|
||||
return fmt.Errorf("failed to add TCP listener: %w", err)
|
||||
}
|
||||
|
||||
// Add WebSocket listener if enabled
|
||||
if eb.config.EnableWebSocket {
|
||||
ws := listeners.NewWebsocket(
|
||||
listeners.Config{
|
||||
ID: "ws",
|
||||
Address: fmt.Sprintf("%s:%d", eb.config.Host, eb.config.WSPort),
|
||||
},
|
||||
)
|
||||
if err := eb.server.AddListener(ws); err != nil {
|
||||
return fmt.Errorf("failed to add WebSocket listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := eb.server.Serve(); err != nil {
|
||||
logger.Error("[MQTTSpec] Embedded broker error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for server to be ready
|
||||
select {
|
||||
case <-time.After(2 * time.Second):
|
||||
// Server should be ready
|
||||
case <-eb.ctx.Done():
|
||||
return fmt.Errorf("context cancelled during startup")
|
||||
}
|
||||
|
||||
eb.started = true
|
||||
logger.Info("[MQTTSpec] Embedded broker started on %s:%d", eb.config.Host, eb.config.Port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the embedded broker
|
||||
func (eb *EmbeddedBroker) Stop(ctx context.Context) error {
|
||||
eb.mu.Lock()
|
||||
defer eb.mu.Unlock()
|
||||
|
||||
if !eb.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
if eb.cancel != nil {
|
||||
eb.cancel()
|
||||
}
|
||||
|
||||
if eb.server != nil {
|
||||
if err := eb.server.Close(); err != nil {
|
||||
logger.Error("[MQTTSpec] Error closing embedded broker: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
eb.started = false
|
||||
logger.Info("[MQTTSpec] Embedded broker stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish publishes a message to a topic
|
||||
func (eb *EmbeddedBroker) Publish(topic string, qos byte, payload []byte) error {
|
||||
if !eb.started {
|
||||
return fmt.Errorf("broker not started")
|
||||
}
|
||||
|
||||
if eb.server == nil {
|
||||
return fmt.Errorf("server not initialized")
|
||||
}
|
||||
|
||||
// Use inline client to publish
|
||||
return eb.server.Publish(topic, payload, false, qos)
|
||||
}
|
||||
|
||||
// Subscribe subscribes to a topic
|
||||
func (eb *EmbeddedBroker) Subscribe(topicFilter string, qos byte, callback MessageCallback) error {
|
||||
if !eb.started {
|
||||
return fmt.Errorf("broker not started")
|
||||
}
|
||||
|
||||
// Store callback
|
||||
eb.subMu.Lock()
|
||||
eb.subscriptions[topicFilter] = callback
|
||||
eb.subMu.Unlock()
|
||||
|
||||
// Create inline subscription handler
|
||||
// Note: Mochi MQTT internal subscriptions are more complex
|
||||
// For now, we'll use a publishing hook to intercept messages
|
||||
// This is a simplified implementation
|
||||
|
||||
logger.Info("[MQTTSpec] Subscribed to topic filter: %s", topicFilter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes from a topic
|
||||
func (eb *EmbeddedBroker) Unsubscribe(topicFilter string) error {
|
||||
eb.subMu.Lock()
|
||||
defer eb.subMu.Unlock()
|
||||
|
||||
delete(eb.subscriptions, topicFilter)
|
||||
logger.Info("[MQTTSpec] Unsubscribed from topic filter: %s", topicFilter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected returns whether the broker is running
|
||||
func (eb *EmbeddedBroker) IsConnected() bool {
|
||||
eb.mu.RLock()
|
||||
defer eb.mu.RUnlock()
|
||||
return eb.started
|
||||
}
|
||||
|
||||
// GetClientManager returns the client manager
|
||||
func (eb *EmbeddedBroker) GetClientManager() *ClientManager {
|
||||
return eb.clientManager
|
||||
}
|
||||
|
||||
// ExternalBrokerClient wraps Paho MQTT client
|
||||
type ExternalBrokerClient struct {
|
||||
config ExternalBrokerConfig
|
||||
client pahomqtt.Client
|
||||
clientManager *ClientManager
|
||||
handler *Handler
|
||||
subscriptions map[string]MessageCallback
|
||||
subMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
connected bool
|
||||
}
|
||||
|
||||
// NewExternalBrokerClient creates a new external broker client
|
||||
func NewExternalBrokerClient(config ExternalBrokerConfig, clientManager *ClientManager) *ExternalBrokerClient {
|
||||
return &ExternalBrokerClient{
|
||||
config: config,
|
||||
clientManager: clientManager,
|
||||
subscriptions: make(map[string]MessageCallback),
|
||||
}
|
||||
}
|
||||
|
||||
// SetHandler sets the handler reference
|
||||
func (ebc *ExternalBrokerClient) SetHandler(handler *Handler) {
|
||||
ebc.mu.Lock()
|
||||
defer ebc.mu.Unlock()
|
||||
ebc.handler = handler
|
||||
}
|
||||
|
||||
// Start connects to the external MQTT broker
|
||||
func (ebc *ExternalBrokerClient) Start(ctx context.Context) error {
|
||||
ebc.mu.Lock()
|
||||
defer ebc.mu.Unlock()
|
||||
|
||||
if ebc.connected {
|
||||
return fmt.Errorf("already connected")
|
||||
}
|
||||
|
||||
ebc.ctx, ebc.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Create Paho client options
|
||||
opts := pahomqtt.NewClientOptions()
|
||||
opts.AddBroker(ebc.config.BrokerURL)
|
||||
opts.SetClientID(ebc.config.ClientID)
|
||||
opts.SetUsername(ebc.config.Username)
|
||||
opts.SetPassword(ebc.config.Password)
|
||||
opts.SetCleanSession(ebc.config.CleanSession)
|
||||
opts.SetKeepAlive(ebc.config.KeepAlive)
|
||||
opts.SetAutoReconnect(true)
|
||||
opts.SetMaxReconnectInterval(ebc.config.ReconnectDelay)
|
||||
|
||||
// Set connection lost handler
|
||||
opts.SetConnectionLostHandler(func(client pahomqtt.Client, err error) {
|
||||
logger.Error("[MQTTSpec] External broker connection lost: %v", err)
|
||||
ebc.mu.Lock()
|
||||
ebc.connected = false
|
||||
ebc.mu.Unlock()
|
||||
})
|
||||
|
||||
// Set on-connect handler
|
||||
opts.SetOnConnectHandler(func(client pahomqtt.Client) {
|
||||
logger.Info("[MQTTSpec] Connected to external broker")
|
||||
ebc.mu.Lock()
|
||||
ebc.connected = true
|
||||
ebc.mu.Unlock()
|
||||
|
||||
// Resubscribe to topics
|
||||
ebc.resubscribeAll()
|
||||
})
|
||||
|
||||
// Create and connect client
|
||||
ebc.client = pahomqtt.NewClient(opts)
|
||||
token := ebc.client.Connect()
|
||||
|
||||
if !token.WaitTimeout(ebc.config.ConnectTimeout) {
|
||||
return fmt.Errorf("connection timeout")
|
||||
}
|
||||
|
||||
if err := token.Error(); err != nil {
|
||||
return fmt.Errorf("failed to connect to external broker: %w", err)
|
||||
}
|
||||
|
||||
ebc.connected = true
|
||||
logger.Info("[MQTTSpec] Connected to external MQTT broker: %s", ebc.config.BrokerURL)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop disconnects from the external broker
|
||||
func (ebc *ExternalBrokerClient) Stop(ctx context.Context) error {
|
||||
ebc.mu.Lock()
|
||||
defer ebc.mu.Unlock()
|
||||
|
||||
if !ebc.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ebc.cancel != nil {
|
||||
ebc.cancel()
|
||||
}
|
||||
|
||||
if ebc.client != nil && ebc.client.IsConnected() {
|
||||
ebc.client.Disconnect(uint(ebc.config.ConnectTimeout.Milliseconds()))
|
||||
}
|
||||
|
||||
ebc.connected = false
|
||||
logger.Info("[MQTTSpec] Disconnected from external broker")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish publishes a message to a topic
|
||||
func (ebc *ExternalBrokerClient) Publish(topic string, qos byte, payload []byte) error {
|
||||
if !ebc.connected {
|
||||
return fmt.Errorf("not connected to broker")
|
||||
}
|
||||
|
||||
token := ebc.client.Publish(topic, qos, false, payload)
|
||||
token.Wait()
|
||||
return token.Error()
|
||||
}
|
||||
|
||||
// Subscribe subscribes to a topic
|
||||
func (ebc *ExternalBrokerClient) Subscribe(topicFilter string, qos byte, callback MessageCallback) error {
|
||||
if !ebc.connected {
|
||||
return fmt.Errorf("not connected to broker")
|
||||
}
|
||||
|
||||
// Store callback
|
||||
ebc.subMu.Lock()
|
||||
ebc.subscriptions[topicFilter] = callback
|
||||
ebc.subMu.Unlock()
|
||||
|
||||
// Subscribe via Paho client
|
||||
token := ebc.client.Subscribe(topicFilter, qos, func(client pahomqtt.Client, msg pahomqtt.Message) {
|
||||
callback(msg.Topic(), msg.Payload())
|
||||
})
|
||||
|
||||
token.Wait()
|
||||
if err := token.Error(); err != nil {
|
||||
return fmt.Errorf("failed to subscribe to %s: %w", topicFilter, err)
|
||||
}
|
||||
|
||||
logger.Info("[MQTTSpec] Subscribed to topic filter: %s", topicFilter)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes from a topic
|
||||
func (ebc *ExternalBrokerClient) Unsubscribe(topicFilter string) error {
|
||||
ebc.subMu.Lock()
|
||||
defer ebc.subMu.Unlock()
|
||||
|
||||
if ebc.client != nil && ebc.connected {
|
||||
token := ebc.client.Unsubscribe(topicFilter)
|
||||
token.Wait()
|
||||
if err := token.Error(); err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to unsubscribe from %s: %v", topicFilter, err)
|
||||
}
|
||||
}
|
||||
|
||||
delete(ebc.subscriptions, topicFilter)
|
||||
logger.Info("[MQTTSpec] Unsubscribed from topic filter: %s", topicFilter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected returns connection status
|
||||
func (ebc *ExternalBrokerClient) IsConnected() bool {
|
||||
ebc.mu.RLock()
|
||||
defer ebc.mu.RUnlock()
|
||||
return ebc.connected
|
||||
}
|
||||
|
||||
// GetClientManager returns the client manager
|
||||
func (ebc *ExternalBrokerClient) GetClientManager() *ClientManager {
|
||||
return ebc.clientManager
|
||||
}
|
||||
|
||||
// resubscribeAll resubscribes to all topics after reconnection
|
||||
func (ebc *ExternalBrokerClient) resubscribeAll() {
|
||||
ebc.subMu.RLock()
|
||||
defer ebc.subMu.RUnlock()
|
||||
|
||||
for topicFilter, callback := range ebc.subscriptions {
|
||||
logger.Info("[MQTTSpec] Resubscribing to topic: %s", topicFilter)
|
||||
token := ebc.client.Subscribe(topicFilter, 1, func(client pahomqtt.Client, msg pahomqtt.Message) {
|
||||
callback(msg.Topic(), msg.Payload())
|
||||
})
|
||||
if token.Wait() && token.Error() != nil {
|
||||
logger.Error("[MQTTSpec] Failed to resubscribe to %s: %v", topicFilter, token.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
409
pkg/mqttspec/broker_test.go
Normal file
409
pkg/mqttspec/broker_test.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewEmbeddedBroker(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 1883,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
assert.NotNil(t, broker)
|
||||
assert.Equal(t, config, broker.config)
|
||||
assert.Equal(t, cm, broker.clientManager)
|
||||
assert.NotNil(t, broker.subscriptions)
|
||||
assert.False(t, broker.started)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_StartStop(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11883, // Use non-standard port for testing
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify started
|
||||
assert.True(t, broker.IsConnected())
|
||||
|
||||
// Stop broker
|
||||
err = broker.Stop(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify stopped
|
||||
assert.False(t, broker.IsConnected())
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_StartTwice(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11884,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Try to start again - should fail
|
||||
err = broker.Start(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already started")
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_StopWithoutStart(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11885,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Stop without starting - should not error
|
||||
err := broker.Stop(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_PublishWithoutStart(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11886,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Try to publish without starting - should fail
|
||||
err := broker.Publish("test/topic", 1, []byte("test"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "broker not started")
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_SubscribeWithoutStart(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11887,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Try to subscribe without starting - should fail
|
||||
err := broker.Subscribe("test/topic", 1, func(topic string, payload []byte) {})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "broker not started")
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_PublishSubscribe(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11888,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Subscribe to topic
|
||||
callback := func(topic string, payload []byte) {
|
||||
// Callback for subscription - actual message delivery would require
|
||||
// integration with Mochi MQTT's hook system
|
||||
}
|
||||
|
||||
err = broker.Subscribe("test/topic", 1, callback)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: Embedded broker's Subscribe is simplified and doesn't fully integrate
|
||||
// with Mochi MQTT's internal pub/sub. This test verifies the subscription
|
||||
// is registered but actual message delivery would require more complex
|
||||
// integration with Mochi MQTT's hook system.
|
||||
|
||||
// Verify subscription was registered
|
||||
broker.subMu.RLock()
|
||||
_, exists := broker.subscriptions["test/topic"]
|
||||
broker.subMu.RUnlock()
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_Unsubscribe(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11889,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Subscribe
|
||||
callback := func(topic string, payload []byte) {}
|
||||
err = broker.Subscribe("test/topic", 1, callback)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subscription exists
|
||||
broker.subMu.RLock()
|
||||
_, exists := broker.subscriptions["test/topic"]
|
||||
broker.subMu.RUnlock()
|
||||
assert.True(t, exists)
|
||||
|
||||
// Unsubscribe
|
||||
err = broker.Unsubscribe("test/topic")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subscription removed
|
||||
broker.subMu.RLock()
|
||||
_, exists = broker.subscriptions["test/topic"]
|
||||
broker.subMu.RUnlock()
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_SetHandler(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11890,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Create a mock handler (nil is fine for this test)
|
||||
var handler *Handler = nil
|
||||
|
||||
// Set handler
|
||||
broker.SetHandler(handler)
|
||||
|
||||
// Verify handler was set
|
||||
broker.mu.RLock()
|
||||
assert.Equal(t, handler, broker.handler)
|
||||
broker.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_GetClientManager(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11891,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
|
||||
// Get client manager
|
||||
retrievedCM := broker.GetClientManager()
|
||||
|
||||
// Verify it's the same instance
|
||||
assert.Equal(t, cm, retrievedCM)
|
||||
}
|
||||
|
||||
func TestEmbeddedBroker_ConcurrentPublish(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 11892,
|
||||
MaxConnections: 100,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewEmbeddedBroker(config, cm)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start broker
|
||||
err := broker.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
defer broker.Stop(ctx)
|
||||
|
||||
// Test concurrent publishing
|
||||
var wg sync.WaitGroup
|
||||
numPublishers := 10
|
||||
|
||||
for i := 0; i < numPublishers; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
err := broker.Publish("test/topic", 1, []byte("test"))
|
||||
// Errors are acceptable in concurrent scenario
|
||||
_ = err
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestNewExternalBrokerClient(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
assert.NotNil(t, broker)
|
||||
assert.Equal(t, config, broker.config)
|
||||
assert.Equal(t, cm, broker.clientManager)
|
||||
assert.NotNil(t, broker.subscriptions)
|
||||
assert.False(t, broker.connected)
|
||||
}
|
||||
|
||||
func TestExternalBrokerClient_SetHandler(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
// Create a mock handler (nil is fine for this test)
|
||||
var handler *Handler = nil
|
||||
|
||||
// Set handler
|
||||
broker.SetHandler(handler)
|
||||
|
||||
// Verify handler was set
|
||||
broker.mu.RLock()
|
||||
assert.Equal(t, handler, broker.handler)
|
||||
broker.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestExternalBrokerClient_GetClientManager(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
// Get client manager
|
||||
retrievedCM := broker.GetClientManager()
|
||||
|
||||
// Verify it's the same instance
|
||||
assert.Equal(t, cm, retrievedCM)
|
||||
}
|
||||
|
||||
func TestExternalBrokerClient_IsConnected(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
config := ExternalBrokerConfig{
|
||||
BrokerURL: "tcp://localhost:1883",
|
||||
ClientID: "test-client",
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReconnectDelay: 1 * time.Second,
|
||||
}
|
||||
|
||||
broker := NewExternalBrokerClient(config, cm)
|
||||
|
||||
// Should not be connected initially
|
||||
assert.False(t, broker.IsConnected())
|
||||
}
|
||||
|
||||
// Note: Tests for ExternalBrokerClient Start/Stop/Publish/Subscribe require
|
||||
// a running MQTT broker and are better suited for integration tests.
|
||||
// These tests would be included in integration_test.go with proper test
|
||||
// broker setup (e.g., using Docker Compose).
|
||||
184
pkg/mqttspec/client.go
Normal file
184
pkg/mqttspec/client.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Client represents an MQTT client connection
|
||||
type Client struct {
|
||||
// ID is the MQTT client ID (unique per connection)
|
||||
ID string
|
||||
|
||||
// Username from MQTT CONNECT packet
|
||||
Username string
|
||||
|
||||
// ConnectedAt is when the client connected
|
||||
ConnectedAt time.Time
|
||||
|
||||
// subscriptions holds active subscriptions for this client
|
||||
subscriptions map[string]*Subscription
|
||||
subMu sync.RWMutex
|
||||
|
||||
// metadata stores client-specific data (user_id, roles, tenant_id, etc.)
|
||||
// Set by BeforeConnect hook for authentication/authorization
|
||||
metadata map[string]interface{}
|
||||
metaMu sync.RWMutex
|
||||
|
||||
// ctx is the client context
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// handler reference for callback access
|
||||
handler *Handler
|
||||
}
|
||||
|
||||
// ClientManager manages all MQTT client connections
|
||||
type ClientManager struct {
|
||||
// clients maps client_id to Client
|
||||
clients map[string]*Client
|
||||
mu sync.RWMutex
|
||||
|
||||
// ctx for lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewClient creates a new MQTT client
|
||||
func NewClient(id, username string, handler *Handler) *Client {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Client{
|
||||
ID: id,
|
||||
Username: username,
|
||||
ConnectedAt: time.Now(),
|
||||
subscriptions: make(map[string]*Subscription),
|
||||
metadata: make(map[string]interface{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetadata sets metadata for this client
|
||||
func (c *Client) SetMetadata(key string, value interface{}) {
|
||||
c.metaMu.Lock()
|
||||
defer c.metaMu.Unlock()
|
||||
c.metadata[key] = value
|
||||
}
|
||||
|
||||
// GetMetadata retrieves metadata for this client
|
||||
func (c *Client) GetMetadata(key string) (interface{}, bool) {
|
||||
c.metaMu.RLock()
|
||||
defer c.metaMu.RUnlock()
|
||||
val, ok := c.metadata[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// AddSubscription adds a subscription to this client
|
||||
func (c *Client) AddSubscription(sub *Subscription) {
|
||||
c.subMu.Lock()
|
||||
defer c.subMu.Unlock()
|
||||
c.subscriptions[sub.ID] = sub
|
||||
}
|
||||
|
||||
// RemoveSubscription removes a subscription from this client
|
||||
func (c *Client) RemoveSubscription(subID string) {
|
||||
c.subMu.Lock()
|
||||
defer c.subMu.Unlock()
|
||||
delete(c.subscriptions, subID)
|
||||
}
|
||||
|
||||
// GetSubscription retrieves a subscription by ID
|
||||
func (c *Client) GetSubscription(subID string) (*Subscription, bool) {
|
||||
c.subMu.RLock()
|
||||
defer c.subMu.RUnlock()
|
||||
sub, ok := c.subscriptions[subID]
|
||||
return sub, ok
|
||||
}
|
||||
|
||||
// Close cleans up the client
|
||||
func (c *Client) Close() {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// Clean up subscriptions
|
||||
c.subMu.Lock()
|
||||
for subID := range c.subscriptions {
|
||||
if c.handler != nil && c.handler.subscriptionManager != nil {
|
||||
c.handler.subscriptionManager.Unsubscribe(subID)
|
||||
}
|
||||
}
|
||||
c.subscriptions = make(map[string]*Subscription)
|
||||
c.subMu.Unlock()
|
||||
}
|
||||
|
||||
// NewClientManager creates a new client manager
|
||||
func NewClientManager(ctx context.Context) *ClientManager {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &ClientManager{
|
||||
clients: make(map[string]*Client),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new MQTT client
|
||||
func (cm *ClientManager) Register(clientID, username string, handler *Handler) *Client {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
|
||||
client := NewClient(clientID, username, handler)
|
||||
cm.clients[clientID] = client
|
||||
|
||||
count := len(cm.clients)
|
||||
logger.Info("[MQTTSpec] Client registered: %s (username: %s, total: %d)", clientID, username, count)
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// Unregister removes a client
|
||||
func (cm *ClientManager) Unregister(clientID string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
|
||||
if client, ok := cm.clients[clientID]; ok {
|
||||
client.Close()
|
||||
delete(cm.clients, clientID)
|
||||
count := len(cm.clients)
|
||||
logger.Info("[MQTTSpec] Client unregistered: %s (total: %d)", clientID, count)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClient retrieves a client by ID
|
||||
func (cm *ClientManager) GetClient(clientID string) (*Client, bool) {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
client, ok := cm.clients[clientID]
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// Count returns the number of active clients
|
||||
func (cm *ClientManager) Count() int {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return len(cm.clients)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the client manager
|
||||
func (cm *ClientManager) Shutdown() {
|
||||
cm.cancel()
|
||||
|
||||
// Close all clients
|
||||
cm.mu.Lock()
|
||||
for _, client := range cm.clients {
|
||||
client.Close()
|
||||
}
|
||||
cm.clients = make(map[string]*Client)
|
||||
cm.mu.Unlock()
|
||||
|
||||
logger.Info("[MQTTSpec] Client manager shut down")
|
||||
}
|
||||
256
pkg/mqttspec/client_test.go
Normal file
256
pkg/mqttspec/client_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient("client-123", "user@example.com", nil)
|
||||
|
||||
assert.Equal(t, "client-123", client.ID)
|
||||
assert.Equal(t, "user@example.com", client.Username)
|
||||
assert.NotNil(t, client.subscriptions)
|
||||
assert.NotNil(t, client.metadata)
|
||||
assert.NotNil(t, client.ctx)
|
||||
assert.NotNil(t, client.cancel)
|
||||
}
|
||||
|
||||
func TestClient_Metadata(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
// Set metadata
|
||||
client.SetMetadata("user_id", 456)
|
||||
client.SetMetadata("tenant_id", "tenant-abc")
|
||||
client.SetMetadata("roles", []string{"admin", "user"})
|
||||
|
||||
// Get metadata
|
||||
userID, exists := client.GetMetadata("user_id")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 456, userID)
|
||||
|
||||
tenantID, exists := client.GetMetadata("tenant_id")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "tenant-abc", tenantID)
|
||||
|
||||
roles, exists := client.GetMetadata("roles")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []string{"admin", "user"}, roles)
|
||||
|
||||
// Non-existent key
|
||||
_, exists = client.GetMetadata("nonexistent")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestClient_Subscriptions(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
// Create mock subscription
|
||||
sub := &Subscription{
|
||||
ID: "sub-1",
|
||||
ConnectionID: "client-123",
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Active: true,
|
||||
}
|
||||
|
||||
// Add subscription
|
||||
client.AddSubscription(sub)
|
||||
|
||||
// Get subscription
|
||||
retrieved, exists := client.GetSubscription("sub-1")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "sub-1", retrieved.ID)
|
||||
|
||||
// Remove subscription
|
||||
client.RemoveSubscription("sub-1")
|
||||
|
||||
// Verify removed
|
||||
_, exists = client.GetSubscription("sub-1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestClient_Close(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
// Add some subscriptions
|
||||
client.AddSubscription(&Subscription{ID: "sub-1"})
|
||||
client.AddSubscription(&Subscription{ID: "sub-2"})
|
||||
|
||||
// Close client
|
||||
client.Close()
|
||||
|
||||
// Verify subscriptions cleared
|
||||
client.subMu.RLock()
|
||||
assert.Empty(t, client.subscriptions)
|
||||
client.subMu.RUnlock()
|
||||
|
||||
// Verify context cancelled
|
||||
select {
|
||||
case <-client.ctx.Done():
|
||||
// Context was cancelled
|
||||
default:
|
||||
t.Fatal("Context should be cancelled after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientManager(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
|
||||
assert.NotNil(t, cm)
|
||||
assert.NotNil(t, cm.clients)
|
||||
assert.Equal(t, 0, cm.Count())
|
||||
}
|
||||
|
||||
func TestClientManager_Register(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
client := cm.Register("client-1", "user@example.com", nil)
|
||||
|
||||
assert.NotNil(t, client)
|
||||
assert.Equal(t, "client-1", client.ID)
|
||||
assert.Equal(t, "user@example.com", client.Username)
|
||||
assert.Equal(t, 1, cm.Count())
|
||||
}
|
||||
|
||||
func TestClientManager_Unregister(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
assert.Equal(t, 1, cm.Count())
|
||||
|
||||
cm.Unregister("client-1")
|
||||
assert.Equal(t, 0, cm.Count())
|
||||
}
|
||||
|
||||
func TestClientManager_GetClient(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
|
||||
// Get existing client
|
||||
client, exists := cm.GetClient("client-1")
|
||||
assert.True(t, exists)
|
||||
assert.NotNil(t, client)
|
||||
assert.Equal(t, "client-1", client.ID)
|
||||
|
||||
// Get non-existent client
|
||||
_, exists = cm.GetClient("nonexistent")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestClientManager_MultipleClients(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
cm.Register("client-2", "user2", nil)
|
||||
cm.Register("client-3", "user3", nil)
|
||||
|
||||
assert.Equal(t, 3, cm.Count())
|
||||
|
||||
cm.Unregister("client-2")
|
||||
assert.Equal(t, 2, cm.Count())
|
||||
|
||||
// Verify correct client was removed
|
||||
_, exists := cm.GetClient("client-2")
|
||||
assert.False(t, exists)
|
||||
|
||||
_, exists = cm.GetClient("client-1")
|
||||
assert.True(t, exists)
|
||||
|
||||
_, exists = cm.GetClient("client-3")
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestClientManager_Shutdown(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
|
||||
cm.Register("client-1", "user1", nil)
|
||||
cm.Register("client-2", "user2", nil)
|
||||
assert.Equal(t, 2, cm.Count())
|
||||
|
||||
cm.Shutdown()
|
||||
|
||||
// All clients should be removed
|
||||
assert.Equal(t, 0, cm.Count())
|
||||
|
||||
// Context should be cancelled
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
// Context was cancelled
|
||||
default:
|
||||
t.Fatal("Context should be cancelled after Shutdown()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientManager_ConcurrentOperations(t *testing.T) {
|
||||
cm := NewClientManager(context.Background())
|
||||
defer cm.Shutdown()
|
||||
|
||||
// This test verifies that concurrent operations don't cause race conditions
|
||||
// Run with: go test -race
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Goroutine 1: Register clients
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
cm.Register("client-"+string(rune(i)), "user", nil)
|
||||
}
|
||||
}()
|
||||
|
||||
// Goroutine 2: Get clients
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
cm.GetClient("client-" + string(rune(i)))
|
||||
}
|
||||
}()
|
||||
|
||||
// Goroutine 3: Count
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
cm.Count()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestClient_ConcurrentMetadata(t *testing.T) {
|
||||
client := NewClient("client-123", "user", nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent writes
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
client.SetMetadata("key1", i)
|
||||
}
|
||||
}()
|
||||
|
||||
// Concurrent reads
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
client.GetMetadata("key1")
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
178
pkg/mqttspec/config.go
Normal file
178
pkg/mqttspec/config.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BrokerMode specifies how to connect to MQTT
|
||||
type BrokerMode string
|
||||
|
||||
const (
|
||||
// BrokerModeEmbedded runs Mochi MQTT broker in-process
|
||||
BrokerModeEmbedded BrokerMode = "embedded"
|
||||
// BrokerModeExternal connects to external MQTT broker as client
|
||||
BrokerModeExternal BrokerMode = "external"
|
||||
)
|
||||
|
||||
// Config holds all mqttspec configuration
|
||||
type Config struct {
|
||||
// BrokerMode determines whether to use embedded or external broker
|
||||
BrokerMode BrokerMode
|
||||
|
||||
// Broker configuration for embedded mode
|
||||
Broker BrokerConfig
|
||||
|
||||
// ExternalBroker configuration for external client mode
|
||||
ExternalBroker ExternalBrokerConfig
|
||||
|
||||
// Topics configuration
|
||||
Topics TopicConfig
|
||||
|
||||
// QoS configuration for different message types
|
||||
QoS QoSConfig
|
||||
|
||||
// Auth configuration
|
||||
Auth AuthConfig
|
||||
|
||||
// Timeouts for various operations
|
||||
Timeouts TimeoutConfig
|
||||
}
|
||||
|
||||
// BrokerConfig configures the embedded Mochi MQTT broker
|
||||
type BrokerConfig struct {
|
||||
// Host to bind to (default: "localhost")
|
||||
Host string
|
||||
|
||||
// Port to listen on (default: 1883)
|
||||
Port int
|
||||
|
||||
// EnableWebSocket enables WebSocket support
|
||||
EnableWebSocket bool
|
||||
|
||||
// WSPort is the WebSocket port (default: 8883)
|
||||
WSPort int
|
||||
|
||||
// MaxConnections limits concurrent client connections
|
||||
MaxConnections int
|
||||
|
||||
// KeepAlive is the client keepalive interval
|
||||
KeepAlive time.Duration
|
||||
|
||||
// EnableAuth enables username/password authentication
|
||||
EnableAuth bool
|
||||
}
|
||||
|
||||
// ExternalBrokerConfig for connecting as a client to external broker
|
||||
type ExternalBrokerConfig struct {
|
||||
// BrokerURL is the broker address (e.g., tcp://host:port or ssl://host:port)
|
||||
BrokerURL string
|
||||
|
||||
// ClientID is a unique identifier for this handler instance
|
||||
ClientID string
|
||||
|
||||
// Username for MQTT authentication
|
||||
Username string
|
||||
|
||||
// Password for MQTT authentication
|
||||
Password string
|
||||
|
||||
// CleanSession flag (default: true)
|
||||
CleanSession bool
|
||||
|
||||
// KeepAlive interval (default: 60s)
|
||||
KeepAlive time.Duration
|
||||
|
||||
// ConnectTimeout for initial connection (default: 30s)
|
||||
ConnectTimeout time.Duration
|
||||
|
||||
// ReconnectDelay between reconnection attempts (default: 5s)
|
||||
ReconnectDelay time.Duration
|
||||
|
||||
// MaxReconnect attempts (0 = unlimited, default: 0)
|
||||
MaxReconnect int
|
||||
|
||||
// TLSConfig for SSL/TLS connections
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// TopicConfig defines the MQTT topic structure
|
||||
type TopicConfig struct {
|
||||
// Prefix for all topics (default: "spec")
|
||||
// Topics will be: {Prefix}/{client_id}/request|response|notify/{sub_id}
|
||||
Prefix string
|
||||
}
|
||||
|
||||
// QoSConfig defines quality of service levels for different message types
|
||||
type QoSConfig struct {
|
||||
// Request messages QoS (default: 1 - at-least-once)
|
||||
Request byte
|
||||
|
||||
// Response messages QoS (default: 1 - at-least-once)
|
||||
Response byte
|
||||
|
||||
// Notification messages QoS (default: 1 - at-least-once)
|
||||
Notification byte
|
||||
}
|
||||
|
||||
// AuthConfig for MQTT-level authentication
|
||||
type AuthConfig struct {
|
||||
// ValidateCredentials is called to validate username/password for embedded broker
|
||||
// Return true if credentials are valid, false otherwise
|
||||
ValidateCredentials func(username, password string) bool
|
||||
}
|
||||
|
||||
// TimeoutConfig defines timeouts for various operations
|
||||
type TimeoutConfig struct {
|
||||
// Connect timeout for MQTT connection (default: 30s)
|
||||
Connect time.Duration
|
||||
|
||||
// Publish timeout for publishing messages (default: 5s)
|
||||
Publish time.Duration
|
||||
|
||||
// Disconnect timeout for graceful shutdown (default: 10s)
|
||||
Disconnect time.Duration
|
||||
}
|
||||
|
||||
// DefaultConfig returns a configuration with sensible defaults
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
BrokerMode: BrokerModeEmbedded,
|
||||
Broker: BrokerConfig{
|
||||
Host: "localhost",
|
||||
Port: 1883,
|
||||
EnableWebSocket: false,
|
||||
WSPort: 8883,
|
||||
MaxConnections: 1000,
|
||||
KeepAlive: 60 * time.Second,
|
||||
EnableAuth: false,
|
||||
},
|
||||
ExternalBroker: ExternalBrokerConfig{
|
||||
BrokerURL: "",
|
||||
ClientID: "",
|
||||
Username: "",
|
||||
Password: "",
|
||||
CleanSession: true,
|
||||
KeepAlive: 60 * time.Second,
|
||||
ConnectTimeout: 30 * time.Second,
|
||||
ReconnectDelay: 5 * time.Second,
|
||||
MaxReconnect: 0, // Unlimited
|
||||
},
|
||||
Topics: TopicConfig{
|
||||
Prefix: "spec",
|
||||
},
|
||||
QoS: QoSConfig{
|
||||
Request: 1, // At-least-once
|
||||
Response: 1, // At-least-once
|
||||
Notification: 1, // At-least-once
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
ValidateCredentials: nil,
|
||||
},
|
||||
Timeouts: TimeoutConfig{
|
||||
Connect: 30 * time.Second,
|
||||
Publish: 5 * time.Second,
|
||||
Disconnect: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
846
pkg/mqttspec/handler.go
Normal file
846
pkg/mqttspec/handler.go
Normal file
@@ -0,0 +1,846 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Handler handles MQTT messages and operations
|
||||
type Handler struct {
|
||||
// Database adapter (GORM/Bun)
|
||||
db common.Database
|
||||
|
||||
// Model registry
|
||||
registry common.ModelRegistry
|
||||
|
||||
// Hook registry
|
||||
hooks *HookRegistry
|
||||
|
||||
// Client manager
|
||||
clientManager *ClientManager
|
||||
|
||||
// Subscription manager
|
||||
subscriptionManager *SubscriptionManager
|
||||
|
||||
// Broker interface (embedded or external)
|
||||
broker BrokerInterface
|
||||
|
||||
// Configuration
|
||||
config *Config
|
||||
|
||||
// Context for lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Started flag
|
||||
started bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHandler creates a new MQTT handler
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry, config *Config) (*Handler, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
clientManager: NewClientManager(ctx),
|
||||
subscriptionManager: NewSubscriptionManager(),
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
started: false,
|
||||
}
|
||||
|
||||
// Initialize broker based on mode
|
||||
if config.BrokerMode == BrokerModeEmbedded {
|
||||
h.broker = NewEmbeddedBroker(config.Broker, h.clientManager)
|
||||
} else {
|
||||
h.broker = NewExternalBrokerClient(config.ExternalBroker, h.clientManager)
|
||||
}
|
||||
|
||||
// Set handler reference in broker
|
||||
h.broker.SetHandler(h)
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Start initializes and starts the handler
|
||||
func (h *Handler) Start() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.started {
|
||||
return fmt.Errorf("handler already started")
|
||||
}
|
||||
|
||||
// Start broker
|
||||
if err := h.broker.Start(h.ctx); err != nil {
|
||||
return fmt.Errorf("failed to start broker: %w", err)
|
||||
}
|
||||
|
||||
// Subscribe to all request topics: spec/+/request
|
||||
requestTopic := fmt.Sprintf("%s/+/request", h.config.Topics.Prefix)
|
||||
if err := h.broker.Subscribe(requestTopic, h.config.QoS.Request, h.handleIncomingMessage); err != nil {
|
||||
_ = h.broker.Stop(h.ctx)
|
||||
return fmt.Errorf("failed to subscribe to request topic: %w", err)
|
||||
}
|
||||
|
||||
h.started = true
|
||||
logger.Info("[MQTTSpec] Handler started, listening on topic: %s", requestTopic)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the handler
|
||||
func (h *Handler) Shutdown() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if !h.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("[MQTTSpec] Shutting down handler...")
|
||||
|
||||
// Execute disconnect hooks for all clients
|
||||
h.clientManager.mu.RLock()
|
||||
clients := make([]*Client, 0, len(h.clientManager.clients))
|
||||
for _, client := range h.clientManager.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
h.clientManager.mu.RUnlock()
|
||||
|
||||
for _, client := range clients {
|
||||
hookCtx := &HookContext{
|
||||
Context: h.ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
_ = h.hooks.Execute(BeforeDisconnect, hookCtx)
|
||||
h.clientManager.Unregister(client.ID)
|
||||
_ = h.hooks.Execute(AfterDisconnect, hookCtx)
|
||||
}
|
||||
|
||||
// Unsubscribe from request topic
|
||||
requestTopic := fmt.Sprintf("%s/+/request", h.config.Topics.Prefix)
|
||||
_ = h.broker.Unsubscribe(requestTopic)
|
||||
|
||||
// Stop broker
|
||||
if err := h.broker.Stop(h.ctx); err != nil {
|
||||
logger.Error("[MQTTSpec] Error stopping broker: %v", err)
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
if h.cancel != nil {
|
||||
h.cancel()
|
||||
}
|
||||
|
||||
h.started = false
|
||||
logger.Info("[MQTTSpec] Handler stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
return h.hooks
|
||||
}
|
||||
|
||||
// Registry returns the model registry
|
||||
func (h *Handler) Registry() common.ModelRegistry {
|
||||
return h.registry
|
||||
}
|
||||
|
||||
// GetDatabase returns the database adapter
|
||||
func (h *Handler) GetDatabase() common.Database {
|
||||
return h.db
|
||||
}
|
||||
|
||||
// GetRelationshipInfo is a placeholder for relationship detection
|
||||
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
|
||||
// TODO: Implement full relationship detection if needed
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleIncomingMessage is called when a message arrives on spec/+/request
|
||||
func (h *Handler) handleIncomingMessage(topic string, payload []byte) {
|
||||
// Extract client_id from topic: spec/{client_id}/request
|
||||
parts := strings.Split(topic, "/")
|
||||
if len(parts) < 3 {
|
||||
logger.Error("[MQTTSpec] Invalid topic format: %s", topic)
|
||||
return
|
||||
}
|
||||
clientID := parts[len(parts)-2] // Second to last part is client_id
|
||||
|
||||
// Parse message
|
||||
msg, err := ParseMessage(payload)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to parse message from %s: %v", clientID, err)
|
||||
h.sendError(clientID, "", "invalid_message", "Failed to parse message")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate message
|
||||
if !msg.IsValid() {
|
||||
logger.Error("[MQTTSpec] Invalid message from %s", clientID)
|
||||
h.sendError(clientID, msg.ID, "invalid_message", "Message validation failed")
|
||||
return
|
||||
}
|
||||
|
||||
// Get or register client
|
||||
client, exists := h.clientManager.GetClient(clientID)
|
||||
if !exists {
|
||||
// First request from this client - register it
|
||||
client = h.clientManager.Register(clientID, "", h)
|
||||
|
||||
// Execute connect hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: h.ctx,
|
||||
Handler: nil, // Not used for MQTT, handler ref stored in metadata if needed
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeConnect, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeConnect hook failed for %s: %v", clientID, err)
|
||||
h.sendError(clientID, msg.ID, "auth_error", err.Error())
|
||||
h.clientManager.Unregister(clientID)
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.hooks.Execute(AfterConnect, hookCtx)
|
||||
}
|
||||
|
||||
// Route message by type
|
||||
switch msg.Type {
|
||||
case MessageTypeRequest:
|
||||
h.handleRequest(client, msg)
|
||||
case MessageTypeSubscription:
|
||||
h.handleSubscription(client, msg)
|
||||
case MessageTypePing:
|
||||
h.handlePing(client, msg)
|
||||
default:
|
||||
h.sendError(clientID, msg.ID, "invalid_message_type", fmt.Sprintf("Unknown message type: %s", msg.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest processes CRUD requests
|
||||
func (h *Handler) handleRequest(client *Client, msg *Message) {
|
||||
ctx := client.ctx
|
||||
schema := msg.Schema
|
||||
entity := msg.Entity
|
||||
recordID := msg.RecordID
|
||||
|
||||
// Get model from registry
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Model not found for %s.%s: %v", schema, entity, err)
|
||||
h.sendError(client.ID, msg.ID, "model_not_found", fmt.Sprintf("Model not found: %s.%s", schema, entity))
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and unwrap model
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Model validation failed for %s.%s: %v", schema, entity, err)
|
||||
h.sendError(client.ID, msg.ID, "invalid_model", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
modelPtr := result.ModelPtr
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Message: msg,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
ModelPtr: modelPtr,
|
||||
Options: msg.Options,
|
||||
ID: recordID,
|
||||
Data: msg.Data,
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
// Route to operation handler
|
||||
switch msg.Operation {
|
||||
case OperationRead:
|
||||
h.handleRead(client, msg, hookCtx)
|
||||
case OperationCreate:
|
||||
h.handleCreate(client, msg, hookCtx)
|
||||
case OperationUpdate:
|
||||
h.handleUpdate(client, msg, hookCtx)
|
||||
case OperationDelete:
|
||||
h.handleDelete(client, msg, hookCtx)
|
||||
case OperationMeta:
|
||||
h.handleMeta(client, msg, hookCtx)
|
||||
default:
|
||||
h.sendError(client.ID, msg.ID, "invalid_operation", fmt.Sprintf("Unknown operation: %s", msg.Operation))
|
||||
}
|
||||
}
|
||||
|
||||
// handleRead processes a read operation
|
||||
func (h *Handler) handleRead(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeRead hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform read operation
|
||||
var data interface{}
|
||||
var metadata map[string]interface{}
|
||||
var err error
|
||||
|
||||
if hookCtx.ID != "" {
|
||||
// Read single record by ID
|
||||
data, err = h.readByID(hookCtx)
|
||||
metadata = map[string]interface{}{"total": 1}
|
||||
} else {
|
||||
// Read multiple records
|
||||
data, metadata, err = h.readMultiple(hookCtx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Read operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "read_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update hook context
|
||||
hookCtx.Result = data
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterRead hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, hookCtx.Result, metadata)
|
||||
}
|
||||
|
||||
// handleCreate processes a create operation
|
||||
func (h *Handler) handleCreate(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeCreate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform create operation
|
||||
data, err := h.create(hookCtx)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Create operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "create_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update hook context
|
||||
hookCtx.Result = data
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterCreate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, hookCtx.Result, nil)
|
||||
|
||||
// Notify subscribers
|
||||
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationCreate, data)
|
||||
}
|
||||
|
||||
// handleUpdate processes an update operation
|
||||
func (h *Handler) handleUpdate(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeUpdate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform update operation
|
||||
data, err := h.update(hookCtx)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Update operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "update_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Update hook context
|
||||
hookCtx.Result = data
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterUpdate hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, hookCtx.Result, nil)
|
||||
|
||||
// Notify subscribers
|
||||
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationUpdate, data)
|
||||
}
|
||||
|
||||
// handleDelete processes a delete operation
|
||||
func (h *Handler) handleDelete(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeDelete hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Perform delete operation
|
||||
if err := h.delete(hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] Delete operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "delete_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Execute after hook
|
||||
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] AfterDelete hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, map[string]interface{}{"deleted": true}, nil)
|
||||
|
||||
// Notify subscribers
|
||||
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationDelete, map[string]interface{}{
|
||||
"id": hookCtx.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// handleMeta processes a metadata request
|
||||
func (h *Handler) handleMeta(client *Client, msg *Message, hookCtx *HookContext) {
|
||||
metadata, err := h.getMetadata(hookCtx)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Meta operation failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "meta_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(client.ID, msg.ID, metadata, nil)
|
||||
}
|
||||
|
||||
// handleSubscription manages subscriptions
|
||||
func (h *Handler) handleSubscription(client *Client, msg *Message) {
|
||||
switch msg.Operation {
|
||||
case OperationSubscribe:
|
||||
h.handleSubscribe(client, msg)
|
||||
case OperationUnsubscribe:
|
||||
h.handleUnsubscribe(client, msg)
|
||||
default:
|
||||
h.sendError(client.ID, msg.ID, "invalid_subscription_operation", fmt.Sprintf("Unknown subscription operation: %s", msg.Operation))
|
||||
}
|
||||
}
|
||||
|
||||
// handleSubscribe creates a subscription
|
||||
func (h *Handler) handleSubscribe(client *Client, msg *Message) {
|
||||
// Generate subscription ID
|
||||
subID := uuid.New().String()
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: client.ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Message: msg,
|
||||
Schema: msg.Schema,
|
||||
Entity: msg.Entity,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeSubscribe, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeSubscribe hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Create subscription
|
||||
sub := h.subscriptionManager.Subscribe(subID, client.ID, msg.Schema, msg.Entity, msg.Options)
|
||||
client.AddSubscription(sub)
|
||||
|
||||
// Execute after hook
|
||||
_ = h.hooks.Execute(AfterSubscribe, hookCtx)
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, map[string]interface{}{
|
||||
"subscription_id": subID,
|
||||
"schema": msg.Schema,
|
||||
"entity": msg.Entity,
|
||||
"notify_topic": h.getNotifyTopic(client.ID, subID),
|
||||
}, nil)
|
||||
|
||||
logger.Info("[MQTTSpec] Subscription created: %s for %s.%s (client: %s)", subID, msg.Schema, msg.Entity, client.ID)
|
||||
}
|
||||
|
||||
// handleUnsubscribe removes a subscription
|
||||
func (h *Handler) handleUnsubscribe(client *Client, msg *Message) {
|
||||
subID := msg.SubscriptionID
|
||||
if subID == "" {
|
||||
h.sendError(client.ID, msg.ID, "invalid_subscription", "Subscription ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: client.ctx,
|
||||
Handler: nil, // Not used for MQTT
|
||||
Message: msg,
|
||||
Metadata: map[string]interface{}{
|
||||
"mqtt_client": client,
|
||||
},
|
||||
}
|
||||
|
||||
// Execute before hook
|
||||
if err := h.hooks.Execute(BeforeUnsubscribe, hookCtx); err != nil {
|
||||
logger.Error("[MQTTSpec] BeforeUnsubscribe hook failed: %v", err)
|
||||
h.sendError(client.ID, msg.ID, "hook_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Remove subscription
|
||||
h.subscriptionManager.Unsubscribe(subID)
|
||||
client.RemoveSubscription(subID)
|
||||
|
||||
// Execute after hook
|
||||
_ = h.hooks.Execute(AfterUnsubscribe, hookCtx)
|
||||
|
||||
// Send response
|
||||
h.sendResponse(client.ID, msg.ID, map[string]interface{}{
|
||||
"unsubscribed": true,
|
||||
"subscription_id": subID,
|
||||
}, nil)
|
||||
|
||||
logger.Info("[MQTTSpec] Subscription removed: %s (client: %s)", subID, client.ID)
|
||||
}
|
||||
|
||||
// handlePing responds to ping messages
|
||||
func (h *Handler) handlePing(client *Client, msg *Message) {
|
||||
pong := &ResponseMessage{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypePong,
|
||||
Success: true,
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(pong)
|
||||
topic := h.getResponseTopic(client.ID)
|
||||
_ = h.broker.Publish(topic, h.config.QoS.Response, payload)
|
||||
}
|
||||
|
||||
// notifySubscribers sends notifications to subscribers
|
||||
func (h *Handler) notifySubscribers(schema, entity string, operation OperationType, data interface{}) {
|
||||
subscriptions := h.subscriptionManager.GetSubscriptionsByEntity(schema, entity)
|
||||
if len(subscriptions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range subscriptions {
|
||||
// Check if data matches subscription filters
|
||||
if !sub.MatchesFilters(data) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get client
|
||||
client, exists := h.clientManager.GetClient(sub.ConnectionID)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create notification message
|
||||
notification := NewNotificationMessage(sub.ID, operation, schema, entity, data)
|
||||
payload, err := json.Marshal(notification)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to marshal notification: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Publish to client's notify topic
|
||||
topic := h.getNotifyTopic(client.ID, sub.ID)
|
||||
if err := h.broker.Publish(topic, h.config.QoS.Notification, payload); err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to publish notification to %s: %v", topic, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Response helpers
|
||||
|
||||
// sendResponse publishes a response message
|
||||
func (h *Handler) sendResponse(clientID, msgID string, data interface{}, metadata map[string]interface{}) {
|
||||
resp := NewResponseMessage(msgID, true, data)
|
||||
resp.Metadata = metadata
|
||||
|
||||
payload, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to marshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
topic := h.getResponseTopic(clientID)
|
||||
if err := h.broker.Publish(topic, h.config.QoS.Response, payload); err != nil {
|
||||
logger.Error("[MQTTSpec] Failed to publish response to %s: %v", topic, err)
|
||||
}
|
||||
}
|
||||
|
||||
// sendError publishes an error response
|
||||
func (h *Handler) sendError(clientID, msgID, code, message string) {
|
||||
errResp := NewErrorResponse(msgID, code, message)
|
||||
|
||||
payload, _ := json.Marshal(errResp)
|
||||
topic := h.getResponseTopic(clientID)
|
||||
_ = h.broker.Publish(topic, h.config.QoS.Response, payload)
|
||||
}
|
||||
|
||||
// Topic helpers
|
||||
|
||||
func (h *Handler) getRequestTopic(clientID string) string {
|
||||
return fmt.Sprintf("%s/%s/request", h.config.Topics.Prefix, clientID)
|
||||
}
|
||||
|
||||
func (h *Handler) getResponseTopic(clientID string) string {
|
||||
return fmt.Sprintf("%s/%s/response", h.config.Topics.Prefix, clientID)
|
||||
}
|
||||
|
||||
func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
|
||||
return fmt.Sprintf("%s/%s/notify/%s", h.config.Topics.Prefix, clientID, subscriptionID)
|
||||
}
|
||||
|
||||
// Database operation helpers (adapted from websocketspec)
|
||||
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
// Use entity as table name
|
||||
tableName := entity
|
||||
|
||||
if schema != "" {
|
||||
tableName = schema + "." + tableName
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
// readByID reads a single record by ID
|
||||
func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
|
||||
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
// Apply columns
|
||||
if hookCtx.Options != nil && len(hookCtx.Options.Columns) > 0 {
|
||||
query = query.Column(hookCtx.Options.Columns...)
|
||||
}
|
||||
|
||||
// Apply preloads (simplified)
|
||||
if hookCtx.Options != nil {
|
||||
for i := range hookCtx.Options.Preload {
|
||||
query = query.PreloadRelation(hookCtx.Options.Preload[i].Relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
if err := query.ScanModel(hookCtx.Context); err != nil {
|
||||
return nil, fmt.Errorf("failed to read record: %w", err)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
// readMultiple reads multiple records
|
||||
func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata map[string]interface{}, err error) {
|
||||
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Apply options
|
||||
if hookCtx.Options != nil {
|
||||
// Apply filters
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
for _, sort := range hookCtx.Options.Sort {
|
||||
direction := "ASC"
|
||||
if sort.Direction == "desc" {
|
||||
direction = "DESC"
|
||||
}
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
|
||||
// Apply limit and offset
|
||||
if hookCtx.Options.Limit != nil {
|
||||
query = query.Limit(*hookCtx.Options.Limit)
|
||||
}
|
||||
if hookCtx.Options.Offset != nil {
|
||||
query = query.Offset(*hookCtx.Options.Offset)
|
||||
}
|
||||
|
||||
// Apply preloads
|
||||
for i := range hookCtx.Options.Preload {
|
||||
query = query.PreloadRelation(hookCtx.Options.Preload[i].Relation)
|
||||
}
|
||||
|
||||
// Apply columns
|
||||
if len(hookCtx.Options.Columns) > 0 {
|
||||
query = query.Column(hookCtx.Options.Columns...)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
if err := query.ScanModel(hookCtx.Context); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to read records: %w", err)
|
||||
}
|
||||
|
||||
// Get count
|
||||
metadata = make(map[string]interface{})
|
||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
if hookCtx.Options != nil {
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
}
|
||||
count, _ := countQuery.Count(hookCtx.Context)
|
||||
metadata["total"] = count
|
||||
metadata["count"] = reflection.Len(hookCtx.ModelPtr)
|
||||
|
||||
return hookCtx.ModelPtr, metadata, nil
|
||||
}
|
||||
|
||||
// create creates a new record
|
||||
func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
||||
// Marshal and unmarshal data into model
|
||||
dataBytes, err := json.Marshal(hookCtx.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
|
||||
}
|
||||
|
||||
// Insert record
|
||||
query := h.db.NewInsert().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
if _, err := query.Exec(hookCtx.Context); err != nil {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
// update updates an existing record
|
||||
func (h *Handler) update(hookCtx *HookContext) (interface{}, error) {
|
||||
// Marshal and unmarshal data into model
|
||||
dataBytes, err := json.Marshal(hookCtx.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
|
||||
}
|
||||
|
||||
// Update record
|
||||
query := h.db.NewUpdate().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
if _, err := query.Exec(hookCtx.Context); err != nil {
|
||||
return nil, fmt.Errorf("failed to update record: %w", err)
|
||||
}
|
||||
|
||||
// Fetch updated record
|
||||
return h.readByID(hookCtx)
|
||||
}
|
||||
|
||||
// delete deletes a record
|
||||
func (h *Handler) delete(hookCtx *HookContext) error {
|
||||
query := h.db.NewDelete().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
|
||||
// Add ID filter
|
||||
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
|
||||
|
||||
if _, err := query.Exec(hookCtx.Context); err != nil {
|
||||
return fmt.Errorf("failed to delete record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMetadata returns schema metadata for an entity
|
||||
func (h *Handler) getMetadata(hookCtx *HookContext) (interface{}, error) {
|
||||
metadata := make(map[string]interface{})
|
||||
metadata["schema"] = hookCtx.Schema
|
||||
metadata["entity"] = hookCtx.Entity
|
||||
metadata["table_name"] = hookCtx.TableName
|
||||
|
||||
// Get fields from model using reflection
|
||||
columns := reflection.GetModelColumns(hookCtx.Model)
|
||||
metadata["columns"] = columns
|
||||
metadata["primary_key"] = reflection.GetPrimaryKeyName(hookCtx.Model)
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// getOperatorSQL converts filter operator to SQL operator
|
||||
func (h *Handler) getOperatorSQL(operator string) string {
|
||||
switch operator {
|
||||
case "eq":
|
||||
return "="
|
||||
case "neq":
|
||||
return "!="
|
||||
case "gt":
|
||||
return ">"
|
||||
case "gte":
|
||||
return ">="
|
||||
case "lt":
|
||||
return "<"
|
||||
case "lte":
|
||||
return "<="
|
||||
case "like":
|
||||
return "LIKE"
|
||||
case "ilike":
|
||||
return "ILIKE"
|
||||
case "in":
|
||||
return "IN"
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
743
pkg/mqttspec/handler_test.go
Normal file
743
pkg/mqttspec/handler_test.go
Normal file
@@ -0,0 +1,743 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Test model
|
||||
type TestUser struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
TenantID string `json:"tenant_id"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (TestUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// setupTestHandler creates a handler with in-memory SQLite database
|
||||
func setupTestHandler(t *testing.T) (*Handler, *gorm.DB) {
|
||||
// Create in-memory SQLite database
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Auto-migrate test model
|
||||
err = db.AutoMigrate(&TestUser{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create handler
|
||||
config := DefaultConfig()
|
||||
config.Broker.Port = 21883 // Use different port for handler tests
|
||||
|
||||
adapter := database.NewGormAdapter(db)
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
registry.RegisterModel("public.users", &TestUser{})
|
||||
|
||||
handler, err := NewHandlerWithDatabase(adapter, registry, WithEmbeddedBroker(config.Broker))
|
||||
require.NoError(t, err)
|
||||
|
||||
return handler, db
|
||||
}
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
assert.NotNil(t, handler)
|
||||
assert.NotNil(t, handler.db)
|
||||
assert.NotNil(t, handler.registry)
|
||||
assert.NotNil(t, handler.hooks)
|
||||
assert.NotNil(t, handler.clientManager)
|
||||
assert.NotNil(t, handler.subscriptionManager)
|
||||
assert.NotNil(t, handler.broker)
|
||||
assert.NotNil(t, handler.config)
|
||||
}
|
||||
|
||||
func TestHandler_StartShutdown(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, handler.started)
|
||||
|
||||
// Shutdown handler
|
||||
err = handler.Shutdown()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, handler.started)
|
||||
}
|
||||
|
||||
func TestHandler_HandleRead_Single(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
user := &TestUser{
|
||||
ID: 1,
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
Status: "active",
|
||||
}
|
||||
db.Create(user)
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create read request message
|
||||
msg := &Message{
|
||||
ID: "msg-1",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
ID: "1",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle read
|
||||
handler.handleRead(client, msg, hookCtx)
|
||||
|
||||
// Note: In a full integration test, we would verify the response was published
|
||||
// to the correct MQTT topic. Here we're just testing that the handler doesn't error.
|
||||
}
|
||||
|
||||
func TestHandler_HandleRead_Multiple(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
users := []TestUser{
|
||||
{ID: 1, Name: "User 1", Email: "user1@example.com", Status: "active"},
|
||||
{ID: 2, Name: "User 2", Email: "user2@example.com", Status: "active"},
|
||||
{ID: 3, Name: "User 3", Email: "user3@example.com", Status: "inactive"},
|
||||
}
|
||||
for _, user := range users {
|
||||
db.Create(&user)
|
||||
}
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create read request with filter
|
||||
msg := &Message{
|
||||
ID: "msg-2",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle read
|
||||
handler.handleRead(client, msg, hookCtx)
|
||||
}
|
||||
|
||||
func TestHandler_HandleCreate(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler to initialize broker
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create request data
|
||||
newUser := map[string]interface{}{
|
||||
"name": "New User",
|
||||
"email": "new@example.com",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
// Create create request message
|
||||
msg := &Message{
|
||||
ID: "msg-3",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationCreate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle create
|
||||
handler.handleCreate(client, msg, hookCtx)
|
||||
|
||||
// Verify user was created in database
|
||||
var user TestUser
|
||||
result := db.Where("email = ?", "new@example.com").First(&user)
|
||||
assert.NoError(t, result.Error)
|
||||
assert.Equal(t, "New User", user.Name)
|
||||
}
|
||||
|
||||
func TestHandler_HandleUpdate(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
user := &TestUser{
|
||||
ID: 1,
|
||||
Name: "Original Name",
|
||||
Email: "original@example.com",
|
||||
Status: "active",
|
||||
}
|
||||
db.Create(user)
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Update data
|
||||
updateData := map[string]interface{}{
|
||||
"name": "Updated Name",
|
||||
}
|
||||
|
||||
// Create update request message
|
||||
msg := &Message{
|
||||
ID: "msg-4",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationUpdate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: updateData,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
ID: "1",
|
||||
Data: updateData,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle update
|
||||
handler.handleUpdate(client, msg, hookCtx)
|
||||
|
||||
// Verify user was updated
|
||||
var updatedUser TestUser
|
||||
db.First(&updatedUser, 1)
|
||||
assert.Equal(t, "Updated Name", updatedUser.Name)
|
||||
}
|
||||
|
||||
func TestHandler_HandleDelete(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data
|
||||
user := &TestUser{
|
||||
ID: 1,
|
||||
Name: "To Delete",
|
||||
Email: "delete@example.com",
|
||||
Status: "active",
|
||||
}
|
||||
db.Create(user)
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create delete request message
|
||||
msg := &Message{
|
||||
ID: "msg-5",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationDelete,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
ID: "1",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle delete
|
||||
handler.handleDelete(client, msg, hookCtx)
|
||||
|
||||
// Verify user was deleted
|
||||
var deletedUser TestUser
|
||||
result := db.First(&deletedUser, 1)
|
||||
assert.Error(t, result.Error)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, result.Error)
|
||||
}
|
||||
|
||||
func TestHandler_HandleSubscribe(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create subscribe message
|
||||
msg := &Message{
|
||||
ID: "msg-6",
|
||||
Type: MessageTypeSubscription,
|
||||
Operation: OperationSubscribe,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Handle subscribe
|
||||
handler.handleSubscribe(client, msg)
|
||||
|
||||
// Verify subscription was created
|
||||
subscriptions := handler.subscriptionManager.GetSubscriptionsByEntity("public", "users")
|
||||
assert.Len(t, subscriptions, 1)
|
||||
assert.Equal(t, client.ID, subscriptions[0].ConnectionID)
|
||||
}
|
||||
|
||||
func TestHandler_HandleUnsubscribe(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create subscription using Subscribe method
|
||||
sub := handler.subscriptionManager.Subscribe("sub-1", client.ID, "public", "users", &common.RequestOptions{})
|
||||
client.AddSubscription(sub)
|
||||
|
||||
// Create unsubscribe message with subscription ID in Data
|
||||
msg := &Message{
|
||||
ID: "msg-7",
|
||||
Type: MessageTypeSubscription,
|
||||
Operation: OperationUnsubscribe,
|
||||
Data: map[string]interface{}{"subscription_id": "sub-1"},
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Handle unsubscribe
|
||||
handler.handleUnsubscribe(client, msg)
|
||||
|
||||
// Verify subscription was removed
|
||||
_, exists := handler.subscriptionManager.GetSubscription("sub-1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestHandler_NotifySubscribers(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock clients
|
||||
client1 := handler.clientManager.Register("client-1", "user1", handler)
|
||||
client2 := handler.clientManager.Register("client-2", "user2", handler)
|
||||
|
||||
// Create subscriptions
|
||||
opts1 := &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
},
|
||||
}
|
||||
sub1 := handler.subscriptionManager.Subscribe("sub-1", client1.ID, "public", "users", opts1)
|
||||
client1.AddSubscription(sub1)
|
||||
|
||||
opts2 := &common.RequestOptions{
|
||||
Filters: []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "inactive"},
|
||||
},
|
||||
}
|
||||
sub2 := handler.subscriptionManager.Subscribe("sub-2", client2.ID, "public", "users", opts2)
|
||||
client2.AddSubscription(sub2)
|
||||
|
||||
// Notify subscribers with active user
|
||||
activeUser := map[string]interface{}{
|
||||
"id": 1,
|
||||
"name": "Active User",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
// This should notify sub-1 only
|
||||
handler.notifySubscribers("public", "users", OperationCreate, activeUser)
|
||||
|
||||
// Note: In a full integration test, we would verify that the notification
|
||||
// was published to the correct MQTT topic. Here we're just testing that
|
||||
// the handler doesn't error and finds the correct subscriptions.
|
||||
}
|
||||
|
||||
func TestHandler_Hooks_BeforeRead(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Insert test data with different tenants
|
||||
users := []TestUser{
|
||||
{ID: 1, Name: "User 1", TenantID: "tenant-a", Status: "active"},
|
||||
{ID: 2, Name: "User 2", TenantID: "tenant-b", Status: "active"},
|
||||
{ID: 3, Name: "User 3", TenantID: "tenant-a", Status: "active"},
|
||||
}
|
||||
for _, user := range users {
|
||||
db.Create(&user)
|
||||
}
|
||||
|
||||
// Register hook to filter by tenant
|
||||
handler.Hooks().Register(BeforeRead, func(ctx *HookContext) error {
|
||||
// Auto-inject tenant filter
|
||||
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
|
||||
Column: "tenant_id",
|
||||
Operator: "eq",
|
||||
Value: "tenant-a",
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create read request (no tenant filter)
|
||||
msg := &Message{
|
||||
ID: "msg-8",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Create hook context
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle read
|
||||
handler.handleRead(client, msg, hookCtx)
|
||||
|
||||
// The hook should have injected the tenant filter
|
||||
// In a full test, we would verify only tenant-a users were returned
|
||||
}
|
||||
|
||||
func TestHandler_Hooks_BeforeCreate(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Register hook to set default values
|
||||
handler.Hooks().Register(BeforeCreate, func(ctx *HookContext) error {
|
||||
// Auto-set tenant_id
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["tenant_id"] = "auto-tenant"
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create mock client
|
||||
client := NewClient("test-client", "test-user", handler)
|
||||
|
||||
// Create user without tenant_id
|
||||
newUser := map[string]interface{}{
|
||||
"name": "Test User",
|
||||
"email": "test@example.com",
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
msg := &Message{
|
||||
ID: "msg-9",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationCreate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
// Handle create
|
||||
handler.handleCreate(client, msg, hookCtx)
|
||||
|
||||
// Verify tenant_id was auto-set
|
||||
var user TestUser
|
||||
db.Where("email = ?", "test@example.com").First(&user)
|
||||
assert.Equal(t, "auto-tenant", user.TenantID)
|
||||
}
|
||||
|
||||
func TestHandler_ConcurrentRequests(t *testing.T) {
|
||||
handler, db := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Create multiple clients
|
||||
var wg sync.WaitGroup
|
||||
numClients := 10
|
||||
|
||||
for i := 0; i < numClients; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
client := NewClient(fmt.Sprintf("client-%d", id), fmt.Sprintf("user%d", id), handler)
|
||||
|
||||
// Create user
|
||||
newUser := map[string]interface{}{
|
||||
"name": fmt.Sprintf("User %d", id),
|
||||
"email": fmt.Sprintf("user%d@example.com", id),
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
msg := &Message{
|
||||
ID: fmt.Sprintf("msg-%d", id),
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationCreate,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: nil,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Data: newUser,
|
||||
Options: msg.Options,
|
||||
Metadata: map[string]interface{}{"mqtt_client": client},
|
||||
}
|
||||
|
||||
handler.handleCreate(client, msg, hookCtx)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all users were created
|
||||
var count int64
|
||||
db.Model(&TestUser{}).Count(&count)
|
||||
assert.Equal(t, int64(numClients), count)
|
||||
}
|
||||
|
||||
func TestHandler_TopicHelpers(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
clientID := "test-client"
|
||||
subscriptionID := "sub-123"
|
||||
|
||||
requestTopic := handler.getRequestTopic(clientID)
|
||||
assert.Equal(t, "spec/test-client/request", requestTopic)
|
||||
|
||||
responseTopic := handler.getResponseTopic(clientID)
|
||||
assert.Equal(t, "spec/test-client/response", responseTopic)
|
||||
|
||||
notifyTopic := handler.getNotifyTopic(clientID, subscriptionID)
|
||||
assert.Equal(t, "spec/test-client/notify/sub-123", notifyTopic)
|
||||
}
|
||||
|
||||
func TestHandler_SendResponse(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Test data
|
||||
clientID := "test-client"
|
||||
msgID := "msg-123"
|
||||
data := map[string]interface{}{"id": 1, "name": "Test"}
|
||||
metadata := map[string]interface{}{"total": 1}
|
||||
|
||||
// Send response (should not error)
|
||||
handler.sendResponse(clientID, msgID, data, metadata)
|
||||
}
|
||||
|
||||
func TestHandler_SendError(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Test error
|
||||
clientID := "test-client"
|
||||
msgID := "msg-123"
|
||||
code := "test_error"
|
||||
message := "Test error message"
|
||||
|
||||
// Send error (should not error)
|
||||
handler.sendError(clientID, msgID, code, message)
|
||||
}
|
||||
|
||||
// extractClientID extracts the client ID from a topic like spec/{client_id}/request
|
||||
func extractClientID(topic string) string {
|
||||
parts := strings.Split(topic, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[len(parts)-2]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestHandler_ExtractClientID(t *testing.T) {
|
||||
tests := []struct {
|
||||
topic string
|
||||
expected string
|
||||
}{
|
||||
{"spec/client-123/request", "client-123"},
|
||||
{"spec/abc-xyz/request", "abc-xyz"},
|
||||
{"spec/test/request", "test"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := extractClientID(tt.topic)
|
||||
assert.Equal(t, tt.expected, result, "topic: %s", tt.topic)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_HandleIncomingMessage_InvalidJSON(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Invalid JSON payload
|
||||
payload := []byte("{invalid json")
|
||||
|
||||
// Should not panic
|
||||
handler.handleIncomingMessage("spec/test-client/request", payload)
|
||||
}
|
||||
|
||||
func TestHandler_HandleIncomingMessage_ValidMessage(t *testing.T) {
|
||||
handler, _ := setupTestHandler(t)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Start handler
|
||||
err := handler.Start()
|
||||
require.NoError(t, err)
|
||||
defer handler.Shutdown()
|
||||
|
||||
// Valid message
|
||||
msg := &Message{
|
||||
ID: "msg-1",
|
||||
Type: MessageTypeRequest,
|
||||
Operation: OperationRead,
|
||||
Schema: "public",
|
||||
Entity: "users",
|
||||
Options: &common.RequestOptions{},
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(msg)
|
||||
|
||||
// Should not panic or error
|
||||
handler.handleIncomingMessage("spec/test-client/request", payload)
|
||||
}
|
||||
51
pkg/mqttspec/hooks.go
Normal file
51
pkg/mqttspec/hooks.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
|
||||
)
|
||||
|
||||
// Hook types - aliases to websocketspec for lifecycle hook consistency
|
||||
type (
|
||||
// HookType defines the type of lifecycle hook
|
||||
HookType = websocketspec.HookType
|
||||
|
||||
// HookFunc is a function that executes during a lifecycle hook
|
||||
HookFunc = websocketspec.HookFunc
|
||||
|
||||
// HookContext contains all context for hook execution
|
||||
// Note: For MQTT, the Client is stored in Metadata["mqtt_client"]
|
||||
HookContext = websocketspec.HookContext
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
HookRegistry = websocketspec.HookRegistry
|
||||
)
|
||||
|
||||
// Hook type constants - all 12 lifecycle hooks
|
||||
const (
|
||||
// CRUD operation hooks
|
||||
BeforeRead = websocketspec.BeforeRead
|
||||
AfterRead = websocketspec.AfterRead
|
||||
BeforeCreate = websocketspec.BeforeCreate
|
||||
AfterCreate = websocketspec.AfterCreate
|
||||
BeforeUpdate = websocketspec.BeforeUpdate
|
||||
AfterUpdate = websocketspec.AfterUpdate
|
||||
BeforeDelete = websocketspec.BeforeDelete
|
||||
AfterDelete = websocketspec.AfterDelete
|
||||
|
||||
// Subscription hooks
|
||||
BeforeSubscribe = websocketspec.BeforeSubscribe
|
||||
AfterSubscribe = websocketspec.AfterSubscribe
|
||||
BeforeUnsubscribe = websocketspec.BeforeUnsubscribe
|
||||
AfterUnsubscribe = websocketspec.AfterUnsubscribe
|
||||
|
||||
// Connection hooks
|
||||
BeforeConnect = websocketspec.BeforeConnect
|
||||
AfterConnect = websocketspec.AfterConnect
|
||||
BeforeDisconnect = websocketspec.BeforeDisconnect
|
||||
AfterDisconnect = websocketspec.AfterDisconnect
|
||||
)
|
||||
|
||||
// NewHookRegistry creates a new hook registry
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return websocketspec.NewHookRegistry()
|
||||
}
|
||||
63
pkg/mqttspec/message.go
Normal file
63
pkg/mqttspec/message.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
|
||||
)
|
||||
|
||||
// Message types - aliases to websocketspec for protocol consistency
|
||||
type (
|
||||
// Message represents an MQTT message (identical to WebSocket message protocol)
|
||||
Message = websocketspec.Message
|
||||
|
||||
// MessageType defines the type of message
|
||||
MessageType = websocketspec.MessageType
|
||||
|
||||
// OperationType defines the operation to perform
|
||||
OperationType = websocketspec.OperationType
|
||||
|
||||
// ResponseMessage is sent back to clients after processing requests
|
||||
ResponseMessage = websocketspec.ResponseMessage
|
||||
|
||||
// NotificationMessage is sent to subscribers when data changes
|
||||
NotificationMessage = websocketspec.NotificationMessage
|
||||
|
||||
// ErrorInfo contains error details
|
||||
ErrorInfo = websocketspec.ErrorInfo
|
||||
)
|
||||
|
||||
// Message type constants
|
||||
const (
|
||||
MessageTypeRequest = websocketspec.MessageTypeRequest
|
||||
MessageTypeResponse = websocketspec.MessageTypeResponse
|
||||
MessageTypeNotification = websocketspec.MessageTypeNotification
|
||||
MessageTypeSubscription = websocketspec.MessageTypeSubscription
|
||||
MessageTypeError = websocketspec.MessageTypeError
|
||||
MessageTypePing = websocketspec.MessageTypePing
|
||||
MessageTypePong = websocketspec.MessageTypePong
|
||||
)
|
||||
|
||||
// Operation type constants
|
||||
const (
|
||||
OperationRead = websocketspec.OperationRead
|
||||
OperationCreate = websocketspec.OperationCreate
|
||||
OperationUpdate = websocketspec.OperationUpdate
|
||||
OperationDelete = websocketspec.OperationDelete
|
||||
OperationSubscribe = websocketspec.OperationSubscribe
|
||||
OperationUnsubscribe = websocketspec.OperationUnsubscribe
|
||||
OperationMeta = websocketspec.OperationMeta
|
||||
)
|
||||
|
||||
// Helper functions from websocketspec
|
||||
var (
|
||||
// NewResponseMessage creates a new response message
|
||||
NewResponseMessage = websocketspec.NewResponseMessage
|
||||
|
||||
// NewErrorResponse creates an error response
|
||||
NewErrorResponse = websocketspec.NewErrorResponse
|
||||
|
||||
// NewNotificationMessage creates a notification message
|
||||
NewNotificationMessage = websocketspec.NewNotificationMessage
|
||||
|
||||
// ParseMessage parses a JSON message into a Message struct
|
||||
ParseMessage = websocketspec.ParseMessage
|
||||
)
|
||||
104
pkg/mqttspec/mqttspec.go
Normal file
104
pkg/mqttspec/mqttspec.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// NewHandlerWithGORM creates an MQTT handler with GORM database adapter
|
||||
func NewHandlerWithGORM(db *gorm.DB, opts ...Option) (*Handler, error) {
|
||||
adapter := database.NewGormAdapter(db)
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
return NewHandlerWithDatabase(adapter, registry, opts...)
|
||||
}
|
||||
|
||||
// NewHandlerWithBun creates an MQTT handler with Bun database adapter
|
||||
func NewHandlerWithBun(db *bun.DB, opts ...Option) (*Handler, error) {
|
||||
adapter := database.NewBunAdapter(db)
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
return NewHandlerWithDatabase(adapter, registry, opts...)
|
||||
}
|
||||
|
||||
// NewHandlerWithDatabase creates an MQTT handler with a custom database adapter
|
||||
func NewHandlerWithDatabase(db common.Database, registry common.ModelRegistry, opts ...Option) (*Handler, error) {
|
||||
// Start with default configuration
|
||||
config := DefaultConfig()
|
||||
|
||||
// Create handler with basic initialization
|
||||
// Note: broker and clientManager will be initialized after options are applied
|
||||
handler, err := NewHandler(db, registry, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply functional options
|
||||
for _, opt := range opts {
|
||||
if err := opt(handler); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Reinitialize broker based on final config (after options)
|
||||
if handler.config.BrokerMode == BrokerModeEmbedded {
|
||||
handler.broker = NewEmbeddedBroker(handler.config.Broker, handler.clientManager)
|
||||
} else {
|
||||
handler.broker = NewExternalBrokerClient(handler.config.ExternalBroker, handler.clientManager)
|
||||
}
|
||||
|
||||
// Set handler reference in broker
|
||||
handler.broker.SetHandler(handler)
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// Option is a functional option for configuring the handler
|
||||
type Option func(*Handler) error
|
||||
|
||||
// WithEmbeddedBroker configures the handler to use an embedded MQTT broker
|
||||
func WithEmbeddedBroker(config BrokerConfig) Option {
|
||||
return func(h *Handler) error {
|
||||
h.config.BrokerMode = BrokerModeEmbedded
|
||||
h.config.Broker = config
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithExternalBroker configures the handler to connect to an external MQTT broker
|
||||
func WithExternalBroker(config ExternalBrokerConfig) Option {
|
||||
return func(h *Handler) error {
|
||||
h.config.BrokerMode = BrokerModeExternal
|
||||
h.config.ExternalBroker = config
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithHooks sets a pre-configured hook registry
|
||||
func WithHooks(hooks *HookRegistry) Option {
|
||||
return func(h *Handler) error {
|
||||
h.hooks = hooks
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTopicPrefix sets a custom topic prefix (default: "spec")
|
||||
func WithTopicPrefix(prefix string) Option {
|
||||
return func(h *Handler) error {
|
||||
h.config.Topics.Prefix = prefix
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithQoS sets custom QoS levels for different message types
|
||||
func WithQoS(request, response, notification byte) Option {
|
||||
return func(h *Handler) error {
|
||||
h.config.QoS.Request = request
|
||||
h.config.QoS.Response = response
|
||||
h.config.QoS.Notification = notification
|
||||
return nil
|
||||
}
|
||||
}
|
||||
21
pkg/mqttspec/subscription.go
Normal file
21
pkg/mqttspec/subscription.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
|
||||
)
|
||||
|
||||
// Subscription types - aliases to websocketspec for subscription management
|
||||
type (
|
||||
// Subscription represents an active subscription to entity changes
|
||||
// The key difference for MQTT: notifications are delivered via MQTT publish
|
||||
// to spec/{client_id}/notify/{subscription_id} instead of WebSocket send
|
||||
Subscription = websocketspec.Subscription
|
||||
|
||||
// SubscriptionManager manages all active subscriptions
|
||||
SubscriptionManager = websocketspec.SubscriptionManager
|
||||
)
|
||||
|
||||
// NewSubscriptionManager creates a new subscription manager
|
||||
func NewSubscriptionManager() *SubscriptionManager {
|
||||
return websocketspec.NewSubscriptionManager()
|
||||
}
|
||||
@@ -273,25 +273,151 @@ handler.SetOpenAPIGenerator(func() (string, error) {
|
||||
})
|
||||
```
|
||||
|
||||
## Using with Swagger UI
|
||||
## Using the Built-in UI Handler
|
||||
|
||||
You can serve the generated OpenAPI spec with Swagger UI:
|
||||
The package includes a built-in UI handler that serves popular OpenAPI visualization tools. No need to download or manage static files - everything is served from CDN.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/openapi"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Setup your API routes and OpenAPI generator...
|
||||
// (see examples above)
|
||||
|
||||
// Add the UI handler - defaults to Swagger UI
|
||||
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||
UIType: openapi.SwaggerUI,
|
||||
SpecURL: "/openapi",
|
||||
Title: "My API Documentation",
|
||||
})
|
||||
|
||||
// Now visit http://localhost:8080/docs
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
```
|
||||
|
||||
### Supported UI Frameworks
|
||||
|
||||
The handler supports four popular OpenAPI UI frameworks:
|
||||
|
||||
#### 1. Swagger UI (Default)
|
||||
The most widely used OpenAPI UI with excellent compatibility and features.
|
||||
|
||||
```go
|
||||
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||
UIType: openapi.SwaggerUI,
|
||||
Theme: "dark", // optional: "light" or "dark"
|
||||
})
|
||||
```
|
||||
|
||||
#### 2. RapiDoc
|
||||
Modern, customizable, and feature-rich OpenAPI UI.
|
||||
|
||||
```go
|
||||
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||
UIType: openapi.RapiDoc,
|
||||
Theme: "dark",
|
||||
})
|
||||
```
|
||||
|
||||
#### 3. Redoc
|
||||
Clean, responsive documentation with great UX.
|
||||
|
||||
```go
|
||||
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||
UIType: openapi.Redoc,
|
||||
})
|
||||
```
|
||||
|
||||
#### 4. Scalar
|
||||
Modern and sleek OpenAPI documentation.
|
||||
|
||||
```go
|
||||
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||
UIType: openapi.Scalar,
|
||||
Theme: "dark",
|
||||
})
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```go
|
||||
type UIConfig struct {
|
||||
UIType UIType // SwaggerUI, RapiDoc, Redoc, or Scalar
|
||||
SpecURL string // URL to OpenAPI spec (default: "/openapi")
|
||||
Title string // Page title (default: "API Documentation")
|
||||
FaviconURL string // Custom favicon URL (optional)
|
||||
CustomCSS string // Custom CSS to inject (optional)
|
||||
Theme string // "light" or "dark" (support varies by UI)
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Styling Example
|
||||
|
||||
```go
|
||||
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||
UIType: openapi.SwaggerUI,
|
||||
Title: "Acme Corp API",
|
||||
CustomCSS: `
|
||||
.swagger-ui .topbar {
|
||||
background-color: #1976d2;
|
||||
}
|
||||
.swagger-ui .info .title {
|
||||
color: #1976d2;
|
||||
}
|
||||
`,
|
||||
})
|
||||
```
|
||||
|
||||
### Using Multiple UIs
|
||||
|
||||
You can serve different UIs at different paths:
|
||||
|
||||
```go
|
||||
// Swagger UI at /docs
|
||||
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||
UIType: openapi.SwaggerUI,
|
||||
})
|
||||
|
||||
// Redoc at /redoc
|
||||
openapi.SetupUIRoute(router, "/redoc", openapi.UIConfig{
|
||||
UIType: openapi.Redoc,
|
||||
})
|
||||
|
||||
// RapiDoc at /api-docs
|
||||
openapi.SetupUIRoute(router, "/api-docs", openapi.UIConfig{
|
||||
UIType: openapi.RapiDoc,
|
||||
})
|
||||
```
|
||||
|
||||
### Manual Handler Usage
|
||||
|
||||
If you need more control, use the handler directly:
|
||||
|
||||
```go
|
||||
handler := openapi.UIHandler(openapi.UIConfig{
|
||||
UIType: openapi.SwaggerUI,
|
||||
SpecURL: "/api/openapi.json",
|
||||
})
|
||||
|
||||
router.Handle("/documentation", handler)
|
||||
```
|
||||
|
||||
## Using with External Swagger UI
|
||||
|
||||
Alternatively, you can use an external Swagger UI instance:
|
||||
|
||||
1. Get the spec from `/openapi`
|
||||
2. Load it in Swagger UI at `https://petstore.swagger.io/`
|
||||
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
|
||||
|
||||
Example with self-hosted Swagger UI:
|
||||
|
||||
```go
|
||||
// Serve Swagger UI static files
|
||||
router.PathPrefix("/swagger/").Handler(
|
||||
http.StripPrefix("/swagger/", http.FileServer(http.Dir("./swagger-ui"))),
|
||||
)
|
||||
|
||||
// Configure Swagger UI to use /openapi
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
You can test the OpenAPI endpoint:
|
||||
|
||||
@@ -183,6 +183,69 @@ func ExampleWithFuncSpec() {
|
||||
_ = generatorFunc
|
||||
}
|
||||
|
||||
// ExampleWithUIHandler shows how to serve OpenAPI documentation with a web UI
|
||||
func ExampleWithUIHandler(db *gorm.DB) {
|
||||
// Create handler and configure OpenAPI generator
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
|
||||
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||
generator := NewGenerator(GeneratorConfig{
|
||||
Title: "My API",
|
||||
Description: "API documentation with interactive UI",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "http://localhost:8080",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
})
|
||||
return generator.GenerateJSON()
|
||||
})
|
||||
|
||||
// Setup routes
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Add UI handlers for different frameworks
|
||||
// Swagger UI at /docs (most popular)
|
||||
SetupUIRoute(router, "/docs", UIConfig{
|
||||
UIType: SwaggerUI,
|
||||
SpecURL: "/openapi",
|
||||
Title: "My API - Swagger UI",
|
||||
Theme: "light",
|
||||
})
|
||||
|
||||
// RapiDoc at /rapidoc (modern alternative)
|
||||
SetupUIRoute(router, "/rapidoc", UIConfig{
|
||||
UIType: RapiDoc,
|
||||
SpecURL: "/openapi",
|
||||
Title: "My API - RapiDoc",
|
||||
})
|
||||
|
||||
// Redoc at /redoc (clean and responsive)
|
||||
SetupUIRoute(router, "/redoc", UIConfig{
|
||||
UIType: Redoc,
|
||||
SpecURL: "/openapi",
|
||||
Title: "My API - Redoc",
|
||||
})
|
||||
|
||||
// Scalar at /scalar (modern and sleek)
|
||||
SetupUIRoute(router, "/scalar", UIConfig{
|
||||
UIType: Scalar,
|
||||
SpecURL: "/openapi",
|
||||
Title: "My API - Scalar",
|
||||
Theme: "dark",
|
||||
})
|
||||
|
||||
// Now you can access:
|
||||
// http://localhost:8080/docs - Swagger UI
|
||||
// http://localhost:8080/rapidoc - RapiDoc
|
||||
// http://localhost:8080/redoc - Redoc
|
||||
// http://localhost:8080/scalar - Scalar
|
||||
// http://localhost:8080/openapi - Raw OpenAPI JSON
|
||||
|
||||
_ = router
|
||||
}
|
||||
|
||||
// ExampleCustomization shows advanced customization options
|
||||
func ExampleCustomization() {
|
||||
// Create registry and register models with descriptions using struct tags
|
||||
|
||||
294
pkg/openapi/ui_handler.go
Normal file
294
pkg/openapi/ui_handler.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// UIType represents the type of OpenAPI UI to serve
|
||||
type UIType string
|
||||
|
||||
const (
|
||||
// SwaggerUI is the most popular OpenAPI UI
|
||||
SwaggerUI UIType = "swagger-ui"
|
||||
// RapiDoc is a modern, customizable OpenAPI UI
|
||||
RapiDoc UIType = "rapidoc"
|
||||
// Redoc is a clean, responsive OpenAPI UI
|
||||
Redoc UIType = "redoc"
|
||||
// Scalar is a modern and sleek OpenAPI UI
|
||||
Scalar UIType = "scalar"
|
||||
)
|
||||
|
||||
// UIConfig holds configuration for the OpenAPI UI handler
|
||||
type UIConfig struct {
|
||||
// UIType specifies which UI framework to use (default: SwaggerUI)
|
||||
UIType UIType
|
||||
// SpecURL is the URL to the OpenAPI spec JSON (default: "/openapi")
|
||||
SpecURL string
|
||||
// Title is the page title (default: "API Documentation")
|
||||
Title string
|
||||
// FaviconURL is the URL to the favicon (optional)
|
||||
FaviconURL string
|
||||
// CustomCSS allows injecting custom CSS (optional)
|
||||
CustomCSS string
|
||||
// Theme for the UI (light/dark, depends on UI type)
|
||||
Theme string
|
||||
}
|
||||
|
||||
// UIHandler creates an HTTP handler that serves an OpenAPI UI
|
||||
func UIHandler(config UIConfig) http.HandlerFunc {
|
||||
// Set defaults
|
||||
if config.UIType == "" {
|
||||
config.UIType = SwaggerUI
|
||||
}
|
||||
if config.SpecURL == "" {
|
||||
config.SpecURL = "/openapi"
|
||||
}
|
||||
if config.Title == "" {
|
||||
config.Title = "API Documentation"
|
||||
}
|
||||
if config.Theme == "" {
|
||||
config.Theme = "light"
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var htmlContent string
|
||||
var err error
|
||||
|
||||
switch config.UIType {
|
||||
case SwaggerUI:
|
||||
htmlContent, err = generateSwaggerUI(config)
|
||||
case RapiDoc:
|
||||
htmlContent, err = generateRapiDoc(config)
|
||||
case Redoc:
|
||||
htmlContent, err = generateRedoc(config)
|
||||
case Scalar:
|
||||
htmlContent, err = generateScalar(config)
|
||||
default:
|
||||
http.Error(w, "Unsupported UI type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to generate UI: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err = w.Write([]byte(htmlContent))
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to write response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// templateData wraps UIConfig to properly handle CSS in templates
|
||||
type templateData struct {
|
||||
UIConfig
|
||||
SafeCustomCSS template.CSS
|
||||
}
|
||||
|
||||
// generateSwaggerUI generates the HTML for Swagger UI
|
||||
func generateSwaggerUI(config UIConfig) (string, error) {
|
||||
tmpl := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{{.Title}}</title>
|
||||
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||
<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css">
|
||||
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||
<style>
|
||||
html { box-sizing: border-box; overflow: -moz-scrollbars-vertical; overflow-y: scroll; }
|
||||
*, *:before, *:after { box-sizing: inherit; }
|
||||
body { margin: 0; padding: 0; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui"></div>
|
||||
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-standalone-preset.js"></script>
|
||||
<script>
|
||||
window.onload = function() {
|
||||
const ui = SwaggerUIBundle({
|
||||
url: "{{.SpecURL}}",
|
||||
dom_id: '#swagger-ui',
|
||||
deepLinking: true,
|
||||
presets: [
|
||||
SwaggerUIBundle.presets.apis,
|
||||
SwaggerUIStandalonePreset
|
||||
],
|
||||
plugins: [
|
||||
SwaggerUIBundle.plugins.DownloadUrl
|
||||
],
|
||||
layout: "StandaloneLayout",
|
||||
{{if eq .Theme "dark"}}
|
||||
syntaxHighlight: {
|
||||
activate: true,
|
||||
theme: "monokai"
|
||||
}
|
||||
{{end}}
|
||||
});
|
||||
window.ui = ui;
|
||||
};
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
t, err := template.New("swagger").Parse(tmpl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data := templateData{
|
||||
UIConfig: config,
|
||||
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// generateRapiDoc generates the HTML for RapiDoc
|
||||
func generateRapiDoc(config UIConfig) (string, error) {
|
||||
theme := "light"
|
||||
if config.Theme == "dark" {
|
||||
theme = "dark"
|
||||
}
|
||||
|
||||
tmpl := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{{.Title}}</title>
|
||||
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||
<script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
|
||||
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||
</head>
|
||||
<body>
|
||||
<rapi-doc
|
||||
spec-url="{{.SpecURL}}"
|
||||
theme="` + theme + `"
|
||||
render-style="read"
|
||||
show-header="true"
|
||||
show-info="true"
|
||||
allow-try="true"
|
||||
allow-server-selection="true"
|
||||
allow-authentication="true"
|
||||
api-key-name="Authorization"
|
||||
api-key-location="header"
|
||||
></rapi-doc>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
t, err := template.New("rapidoc").Parse(tmpl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data := templateData{
|
||||
UIConfig: config,
|
||||
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// generateRedoc generates the HTML for Redoc
|
||||
func generateRedoc(config UIConfig) (string, error) {
|
||||
tmpl := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{{.Title}}</title>
|
||||
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||
<style>
|
||||
body { margin: 0; padding: 0; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<redoc spec-url="{{.SpecURL}}" {{if eq .Theme "dark"}}theme='{"colors": {"primary": {"main": "#dd5522"}}}'{{end}}></redoc>
|
||||
<script src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
t, err := template.New("redoc").Parse(tmpl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data := templateData{
|
||||
UIConfig: config,
|
||||
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// generateScalar generates the HTML for Scalar
|
||||
func generateScalar(config UIConfig) (string, error) {
|
||||
tmpl := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>{{.Title}}</title>
|
||||
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||
<style>
|
||||
body { margin: 0; padding: 0; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<script id="api-reference" data-url="{{.SpecURL}}" {{if eq .Theme "dark"}}data-theme="dark"{{end}}></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
t, err := template.New("scalar").Parse(tmpl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data := templateData{
|
||||
UIConfig: config,
|
||||
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// SetupUIRoute adds the OpenAPI UI route to a mux router
|
||||
// This is a convenience function for the most common use case
|
||||
func SetupUIRoute(router *mux.Router, path string, config UIConfig) {
|
||||
router.Handle(path, UIHandler(config))
|
||||
}
|
||||
308
pkg/openapi/ui_handler_test.go
Normal file
308
pkg/openapi/ui_handler_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func TestUIHandler_SwaggerUI(t *testing.T) {
|
||||
config := UIConfig{
|
||||
UIType: SwaggerUI,
|
||||
SpecURL: "/openapi",
|
||||
Title: "Test API Docs",
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// Check for Swagger UI specific content
|
||||
if !strings.Contains(body, "swagger-ui") {
|
||||
t.Error("Expected Swagger UI content")
|
||||
}
|
||||
if !strings.Contains(body, "SwaggerUIBundle") {
|
||||
t.Error("Expected SwaggerUIBundle script")
|
||||
}
|
||||
if !strings.Contains(body, config.Title) {
|
||||
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||
}
|
||||
if !strings.Contains(body, config.SpecURL) {
|
||||
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||
}
|
||||
if !strings.Contains(body, "swagger-ui-dist") {
|
||||
t.Error("Expected Swagger UI CDN link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_RapiDoc(t *testing.T) {
|
||||
config := UIConfig{
|
||||
UIType: RapiDoc,
|
||||
SpecURL: "/api/spec",
|
||||
Title: "RapiDoc Test",
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// Check for RapiDoc specific content
|
||||
if !strings.Contains(body, "rapi-doc") {
|
||||
t.Error("Expected rapi-doc element")
|
||||
}
|
||||
if !strings.Contains(body, "rapidoc-min.js") {
|
||||
t.Error("Expected RapiDoc script")
|
||||
}
|
||||
if !strings.Contains(body, config.Title) {
|
||||
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||
}
|
||||
if !strings.Contains(body, config.SpecURL) {
|
||||
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_Redoc(t *testing.T) {
|
||||
config := UIConfig{
|
||||
UIType: Redoc,
|
||||
SpecURL: "/spec.json",
|
||||
Title: "Redoc Test",
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// Check for Redoc specific content
|
||||
if !strings.Contains(body, "<redoc") {
|
||||
t.Error("Expected redoc element")
|
||||
}
|
||||
if !strings.Contains(body, "redoc.standalone.js") {
|
||||
t.Error("Expected Redoc script")
|
||||
}
|
||||
if !strings.Contains(body, config.Title) {
|
||||
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||
}
|
||||
if !strings.Contains(body, config.SpecURL) {
|
||||
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_Scalar(t *testing.T) {
|
||||
config := UIConfig{
|
||||
UIType: Scalar,
|
||||
SpecURL: "/openapi.json",
|
||||
Title: "Scalar Test",
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// Check for Scalar specific content
|
||||
if !strings.Contains(body, "api-reference") {
|
||||
t.Error("Expected api-reference element")
|
||||
}
|
||||
if !strings.Contains(body, "@scalar/api-reference") {
|
||||
t.Error("Expected Scalar script")
|
||||
}
|
||||
if !strings.Contains(body, config.Title) {
|
||||
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||
}
|
||||
if !strings.Contains(body, config.SpecURL) {
|
||||
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_DefaultValues(t *testing.T) {
|
||||
// Test with empty config to check defaults
|
||||
config := UIConfig{}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// Should default to Swagger UI
|
||||
if !strings.Contains(body, "swagger-ui") {
|
||||
t.Error("Expected default to Swagger UI")
|
||||
}
|
||||
|
||||
// Should default to /openapi spec URL
|
||||
if !strings.Contains(body, "/openapi") {
|
||||
t.Error("Expected default spec URL '/openapi'")
|
||||
}
|
||||
|
||||
// Should default to "API Documentation" title
|
||||
if !strings.Contains(body, "API Documentation") {
|
||||
t.Error("Expected default title 'API Documentation'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_CustomCSS(t *testing.T) {
|
||||
customCSS := ".custom-class { color: red; }"
|
||||
config := UIConfig{
|
||||
UIType: SwaggerUI,
|
||||
CustomCSS: customCSS,
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
if !strings.Contains(body, customCSS) {
|
||||
t.Errorf("Expected custom CSS to be included. Body:\n%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_Favicon(t *testing.T) {
|
||||
faviconURL := "https://example.com/favicon.ico"
|
||||
config := UIConfig{
|
||||
UIType: SwaggerUI,
|
||||
FaviconURL: faviconURL,
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
if !strings.Contains(body, faviconURL) {
|
||||
t.Error("Expected favicon URL to be included")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_DarkTheme(t *testing.T) {
|
||||
config := UIConfig{
|
||||
UIType: SwaggerUI,
|
||||
Theme: "dark",
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// SwaggerUI uses monokai theme for dark mode
|
||||
if !strings.Contains(body, "monokai") {
|
||||
t.Error("Expected dark theme configuration for Swagger UI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_InvalidUIType(t *testing.T) {
|
||||
config := UIConfig{
|
||||
UIType: "invalid-ui-type",
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for invalid UI type, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIHandler_ContentType(t *testing.T) {
|
||||
config := UIConfig{
|
||||
UIType: SwaggerUI,
|
||||
}
|
||||
|
||||
handler := UIHandler(config)
|
||||
req := httptest.NewRequest("GET", "/docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(contentType, "text/html") {
|
||||
t.Errorf("Expected Content-Type to contain 'text/html', got '%s'", contentType)
|
||||
}
|
||||
if !strings.Contains(contentType, "charset=utf-8") {
|
||||
t.Errorf("Expected Content-Type to contain 'charset=utf-8', got '%s'", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupUIRoute(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
|
||||
config := UIConfig{
|
||||
UIType: SwaggerUI,
|
||||
}
|
||||
|
||||
SetupUIRoute(router, "/api-docs", config)
|
||||
|
||||
// Test that the route was added and works
|
||||
req := httptest.NewRequest("GET", "/api-docs", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Verify it returns HTML
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "swagger-ui") {
|
||||
t.Error("Expected Swagger UI content")
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
@@ -897,6 +899,368 @@ func GetRelationModel(model interface{}, fieldName string) interface{} {
|
||||
return currentModel
|
||||
}
|
||||
|
||||
// MapToStruct populates a struct from a map while preserving custom types
|
||||
// It uses reflection to set struct fields based on map keys, matching by:
|
||||
// 1. Bun tag column name
|
||||
// 2. Gorm tag column name
|
||||
// 3. JSON tag name
|
||||
// 4. Field name (case-insensitive)
|
||||
// This preserves custom types that implement driver.Valuer like SqlJSONB
|
||||
func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
||||
if dataMap == nil || target == nil {
|
||||
return fmt.Errorf("dataMap and target cannot be nil")
|
||||
}
|
||||
|
||||
targetValue := reflect.ValueOf(target)
|
||||
if targetValue.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("target must be a pointer to a struct")
|
||||
}
|
||||
|
||||
targetValue = targetValue.Elem()
|
||||
if targetValue.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("target must be a pointer to a struct")
|
||||
}
|
||||
|
||||
targetType := targetValue.Type()
|
||||
|
||||
// Create a map of column names to field indices for faster lookup
|
||||
columnToField := make(map[string]int)
|
||||
for i := 0; i < targetType.NumField(); i++ {
|
||||
field := targetType.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build list of possible column names for this field
|
||||
var columnNames []string
|
||||
|
||||
// 1. Bun tag
|
||||
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Gorm tag
|
||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. JSON tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
columnNames = append(columnNames, parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Field name variations
|
||||
columnNames = append(columnNames, field.Name)
|
||||
columnNames = append(columnNames, strings.ToLower(field.Name))
|
||||
columnNames = append(columnNames, ToSnakeCase(field.Name))
|
||||
|
||||
// Map all column name variations to this field index
|
||||
for _, colName := range columnNames {
|
||||
columnToField[strings.ToLower(colName)] = i
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate through the map and set struct fields
|
||||
for key, value := range dataMap {
|
||||
// Find the field index for this key
|
||||
fieldIndex, found := columnToField[strings.ToLower(key)]
|
||||
if !found {
|
||||
// Skip keys that don't map to any field
|
||||
continue
|
||||
}
|
||||
|
||||
field := targetValue.Field(fieldIndex)
|
||||
if !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Set the value, preserving custom types
|
||||
if err := setFieldValue(field, value); err != nil {
|
||||
return fmt.Errorf("failed to set field %s: %w", targetType.Field(fieldIndex).Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setFieldValue sets a reflect.Value from an interface{} value, handling type conversions
|
||||
func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
if value == nil {
|
||||
// Set zero value for nil
|
||||
field.Set(reflect.Zero(field.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
valueReflect := reflect.ValueOf(value)
|
||||
|
||||
// If types match exactly, just set it
|
||||
if valueReflect.Type().AssignableTo(field.Type()) {
|
||||
field.Set(valueReflect)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle pointer fields
|
||||
if field.Kind() == reflect.Ptr {
|
||||
if valueReflect.Kind() != reflect.Ptr {
|
||||
// Create a new pointer and set its value
|
||||
newPtr := reflect.New(field.Type().Elem())
|
||||
if err := setFieldValue(newPtr.Elem(), value); err != nil {
|
||||
return err
|
||||
}
|
||||
field.Set(newPtr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle conversions for basic types
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
if str, ok := value.(string); ok {
|
||||
field.SetString(str)
|
||||
return nil
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
if num, ok := convertToInt64(value); ok {
|
||||
if field.OverflowInt(num) {
|
||||
return fmt.Errorf("integer overflow")
|
||||
}
|
||||
field.SetInt(num)
|
||||
return nil
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if num, ok := convertToUint64(value); ok {
|
||||
if field.OverflowUint(num) {
|
||||
return fmt.Errorf("unsigned integer overflow")
|
||||
}
|
||||
field.SetUint(num)
|
||||
return nil
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if num, ok := convertToFloat64(value); ok {
|
||||
if field.OverflowFloat(num) {
|
||||
return fmt.Errorf("float overflow")
|
||||
}
|
||||
field.SetFloat(num)
|
||||
return nil
|
||||
}
|
||||
case reflect.Bool:
|
||||
if b, ok := value.(bool); ok {
|
||||
field.SetBool(b)
|
||||
return nil
|
||||
}
|
||||
case reflect.Slice:
|
||||
// Handle []byte specially (for types like SqlJSONB)
|
||||
if field.Type().Elem().Kind() == reflect.Uint8 {
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
field.SetBytes(v)
|
||||
return nil
|
||||
case string:
|
||||
field.SetBytes([]byte(v))
|
||||
return nil
|
||||
case map[string]interface{}, []interface{}:
|
||||
// Marshal complex types to JSON for SqlJSONB fields
|
||||
jsonBytes, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value to JSON: %w", err)
|
||||
}
|
||||
field.SetBytes(jsonBytes)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
|
||||
if field.Kind() == reflect.Struct {
|
||||
|
||||
// Handle datatypes.SqlNull[T] and wrapped types (SqlTimeStamp, SqlDate, SqlTime)
|
||||
// Check if the type has a Scan method (sql.Scanner interface)
|
||||
if field.CanAddr() {
|
||||
scanMethod := field.Addr().MethodByName("Scan")
|
||||
if scanMethod.IsValid() {
|
||||
// Call the Scan method with the value
|
||||
results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)})
|
||||
if len(results) > 0 {
|
||||
// Check if there was an error
|
||||
if err, ok := results[0].Interface().(error); ok && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle time.Time with ISO string fallback
|
||||
if field.Type() == reflect.TypeOf(time.Time{}) {
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
field.Set(reflect.ValueOf(v))
|
||||
return nil
|
||||
case string:
|
||||
// Try parsing as ISO 8601 / RFC3339
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
field.Set(reflect.ValueOf(t))
|
||||
return nil
|
||||
}
|
||||
// Try other common formats
|
||||
formats := []string{
|
||||
"2006-01-02T15:04:05.000-0700",
|
||||
"2006-01-02T15:04:05.000",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02",
|
||||
}
|
||||
for _, format := range formats {
|
||||
if t, err := time.Parse(format, v); err == nil {
|
||||
field.Set(reflect.ValueOf(t))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("cannot parse time string: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: Try to find a "Val" field (for SqlNull types) and set it directly
|
||||
valField := field.FieldByName("Val")
|
||||
if valField.IsValid() && valField.CanSet() {
|
||||
// Also set Valid field to true
|
||||
validField := field.FieldByName("Valid")
|
||||
if validField.IsValid() && validField.CanSet() && validField.Kind() == reflect.Bool {
|
||||
// Set the Val field
|
||||
if err := setFieldValue(valField, value); err != nil {
|
||||
return err
|
||||
}
|
||||
// Set Valid to true
|
||||
validField.SetBool(true)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// If we can convert the type, do it
|
||||
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||
field.Set(valueReflect.Convert(field.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
||||
}
|
||||
|
||||
// convertToInt64 attempts to convert various types to int64
|
||||
func convertToInt64(value interface{}) (int64, bool) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return int64(v), true
|
||||
case int8:
|
||||
return int64(v), true
|
||||
case int16:
|
||||
return int64(v), true
|
||||
case int32:
|
||||
return int64(v), true
|
||||
case int64:
|
||||
return v, true
|
||||
case uint:
|
||||
return int64(v), true
|
||||
case uint8:
|
||||
return int64(v), true
|
||||
case uint16:
|
||||
return int64(v), true
|
||||
case uint32:
|
||||
return int64(v), true
|
||||
case uint64:
|
||||
return int64(v), true
|
||||
case float32:
|
||||
return int64(v), true
|
||||
case float64:
|
||||
return int64(v), true
|
||||
case string:
|
||||
if num, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return num, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// convertToUint64 attempts to convert various types to uint64
|
||||
func convertToUint64(value interface{}) (uint64, bool) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return uint64(v), true
|
||||
case int8:
|
||||
return uint64(v), true
|
||||
case int16:
|
||||
return uint64(v), true
|
||||
case int32:
|
||||
return uint64(v), true
|
||||
case int64:
|
||||
return uint64(v), true
|
||||
case uint:
|
||||
return uint64(v), true
|
||||
case uint8:
|
||||
return uint64(v), true
|
||||
case uint16:
|
||||
return uint64(v), true
|
||||
case uint32:
|
||||
return uint64(v), true
|
||||
case uint64:
|
||||
return v, true
|
||||
case float32:
|
||||
return uint64(v), true
|
||||
case float64:
|
||||
return uint64(v), true
|
||||
case string:
|
||||
if num, err := strconv.ParseUint(v, 10, 64); err == nil {
|
||||
return num, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// convertToFloat64 attempts to convert various types to float64
|
||||
func convertToFloat64(value interface{}) (float64, bool) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return float64(v), true
|
||||
case int8:
|
||||
return float64(v), true
|
||||
case int16:
|
||||
return float64(v), true
|
||||
case int32:
|
||||
return float64(v), true
|
||||
case int64:
|
||||
return float64(v), true
|
||||
case uint:
|
||||
return float64(v), true
|
||||
case uint8:
|
||||
return float64(v), true
|
||||
case uint16:
|
||||
return float64(v), true
|
||||
case uint32:
|
||||
return float64(v), true
|
||||
case uint64:
|
||||
return float64(v), true
|
||||
case float32:
|
||||
return float64(v), true
|
||||
case float64:
|
||||
return v, true
|
||||
case string:
|
||||
if num, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
return num, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||
|
||||
266
pkg/reflection/model_utils_sqltypes_test.go
Normal file
266
pkg/reflection/model_utils_sqltypes_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package reflection_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||
)
|
||||
|
||||
func TestMapToStruct_SqlJSONB_PreservesDriverValuer(t *testing.T) {
|
||||
// Test that SqlJSONB type preserves driver.Valuer interface
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Meta spectypes.SqlJSONB `bun:"meta" json:"meta"`
|
||||
}
|
||||
|
||||
dataMap := map[string]interface{}{
|
||||
"id": int64(123),
|
||||
"meta": map[string]interface{}{
|
||||
"key": "value",
|
||||
"num": 42,
|
||||
},
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify the field was set
|
||||
if result.ID != 123 {
|
||||
t.Errorf("ID = %v, want 123", result.ID)
|
||||
}
|
||||
|
||||
// Verify SqlJSONB was populated
|
||||
if len(result.Meta) == 0 {
|
||||
t.Error("Meta is empty, want non-empty")
|
||||
}
|
||||
|
||||
// Most importantly: verify driver.Valuer interface works
|
||||
value, err := result.Meta.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Meta.Value() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Value should return a string representation of the JSON
|
||||
if value == nil {
|
||||
t.Error("Meta.Value() returned nil, want non-nil")
|
||||
}
|
||||
|
||||
// Check it's a valid JSON string
|
||||
if str, ok := value.(string); ok {
|
||||
if len(str) == 0 {
|
||||
t.Error("Meta.Value() returned empty string, want valid JSON")
|
||||
}
|
||||
t.Logf("SqlJSONB.Value() returned: %s", str)
|
||||
} else {
|
||||
t.Errorf("Meta.Value() returned type %T, want string", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToStruct_SqlJSONB_FromBytes(t *testing.T) {
|
||||
// Test that SqlJSONB can be set from []byte directly
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Meta spectypes.SqlJSONB `bun:"meta" json:"meta"`
|
||||
}
|
||||
|
||||
jsonBytes := []byte(`{"direct":"bytes"}`)
|
||||
dataMap := map[string]interface{}{
|
||||
"id": int64(456),
|
||||
"meta": jsonBytes,
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
if result.ID != 456 {
|
||||
t.Errorf("ID = %v, want 456", result.ID)
|
||||
}
|
||||
|
||||
if string(result.Meta) != string(jsonBytes) {
|
||||
t.Errorf("Meta = %s, want %s", string(result.Meta), string(jsonBytes))
|
||||
}
|
||||
|
||||
// Verify driver.Valuer works
|
||||
value, err := result.Meta.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Meta.Value() error = %v", err)
|
||||
}
|
||||
if value == nil {
|
||||
t.Error("Meta.Value() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToStruct_AllSqlTypes(t *testing.T) {
|
||||
// Test model with all SQL custom types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"`
|
||||
BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"`
|
||||
LoginTime spectypes.SqlTime `bun:"login_time" json:"login_time"`
|
||||
Meta spectypes.SqlJSONB `bun:"meta" json:"meta"`
|
||||
Tags spectypes.SqlJSONB `bun:"tags" json:"tags"`
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
birthDate := time.Date(1990, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
loginTime := time.Date(0, 1, 1, 14, 30, 0, 0, time.UTC)
|
||||
|
||||
dataMap := map[string]interface{}{
|
||||
"id": int64(100),
|
||||
"name": "Test User",
|
||||
"created_at": now,
|
||||
"birth_date": birthDate,
|
||||
"login_time": loginTime,
|
||||
"meta": map[string]interface{}{
|
||||
"role": "admin",
|
||||
"active": true,
|
||||
},
|
||||
"tags": []interface{}{"golang", "testing", "sql"},
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify basic fields
|
||||
if result.ID != 100 {
|
||||
t.Errorf("ID = %v, want 100", result.ID)
|
||||
}
|
||||
if result.Name != "Test User" {
|
||||
t.Errorf("Name = %v, want 'Test User'", result.Name)
|
||||
}
|
||||
|
||||
// Verify SqlTimeStamp
|
||||
if !result.CreatedAt.Valid {
|
||||
t.Error("CreatedAt.Valid = false, want true")
|
||||
}
|
||||
if !result.CreatedAt.Val.Equal(now) {
|
||||
t.Errorf("CreatedAt.Val = %v, want %v", result.CreatedAt.Val, now)
|
||||
}
|
||||
|
||||
// Verify driver.Valuer for SqlTimeStamp
|
||||
tsValue, err := result.CreatedAt.Value()
|
||||
if err != nil {
|
||||
t.Errorf("CreatedAt.Value() error = %v", err)
|
||||
}
|
||||
if tsValue == nil {
|
||||
t.Error("CreatedAt.Value() returned nil")
|
||||
}
|
||||
|
||||
// Verify SqlDate
|
||||
if !result.BirthDate.Valid {
|
||||
t.Error("BirthDate.Valid = false, want true")
|
||||
}
|
||||
if !result.BirthDate.Val.Equal(birthDate) {
|
||||
t.Errorf("BirthDate.Val = %v, want %v", result.BirthDate.Val, birthDate)
|
||||
}
|
||||
|
||||
// Verify driver.Valuer for SqlDate
|
||||
dateValue, err := result.BirthDate.Value()
|
||||
if err != nil {
|
||||
t.Errorf("BirthDate.Value() error = %v", err)
|
||||
}
|
||||
if dateValue == nil {
|
||||
t.Error("BirthDate.Value() returned nil")
|
||||
}
|
||||
|
||||
// Verify SqlTime
|
||||
if !result.LoginTime.Valid {
|
||||
t.Error("LoginTime.Valid = false, want true")
|
||||
}
|
||||
|
||||
// Verify driver.Valuer for SqlTime
|
||||
timeValue, err := result.LoginTime.Value()
|
||||
if err != nil {
|
||||
t.Errorf("LoginTime.Value() error = %v", err)
|
||||
}
|
||||
if timeValue == nil {
|
||||
t.Error("LoginTime.Value() returned nil")
|
||||
}
|
||||
|
||||
// Verify SqlJSONB for Meta
|
||||
if len(result.Meta) == 0 {
|
||||
t.Error("Meta is empty")
|
||||
}
|
||||
metaValue, err := result.Meta.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Meta.Value() error = %v", err)
|
||||
}
|
||||
if metaValue == nil {
|
||||
t.Error("Meta.Value() returned nil")
|
||||
}
|
||||
|
||||
// Verify SqlJSONB for Tags
|
||||
if len(result.Tags) == 0 {
|
||||
t.Error("Tags is empty")
|
||||
}
|
||||
tagsValue, err := result.Tags.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Tags.Value() error = %v", err)
|
||||
}
|
||||
if tagsValue == nil {
|
||||
t.Error("Tags.Value() returned nil")
|
||||
}
|
||||
|
||||
t.Logf("All SQL types successfully preserved driver.Valuer interface:")
|
||||
t.Logf(" - SqlTimeStamp: %v", tsValue)
|
||||
t.Logf(" - SqlDate: %v", dateValue)
|
||||
t.Logf(" - SqlTime: %v", timeValue)
|
||||
t.Logf(" - SqlJSONB (Meta): %v", metaValue)
|
||||
t.Logf(" - SqlJSONB (Tags): %v", tagsValue)
|
||||
}
|
||||
|
||||
func TestMapToStruct_SqlNull_NilValues(t *testing.T) {
|
||||
// Test that SqlNull types handle nil values correctly
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
UpdatedAt spectypes.SqlTimeStamp `bun:"updated_at" json:"updated_at"`
|
||||
DeletedAt spectypes.SqlTimeStamp `bun:"deleted_at" json:"deleted_at"`
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
dataMap := map[string]interface{}{
|
||||
"id": int64(200),
|
||||
"updated_at": now,
|
||||
"deleted_at": nil, // Explicitly nil
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
// UpdatedAt should be valid
|
||||
if !result.UpdatedAt.Valid {
|
||||
t.Error("UpdatedAt.Valid = false, want true")
|
||||
}
|
||||
if !result.UpdatedAt.Val.Equal(now) {
|
||||
t.Errorf("UpdatedAt.Val = %v, want %v", result.UpdatedAt.Val, now)
|
||||
}
|
||||
|
||||
// DeletedAt should be invalid (null)
|
||||
if result.DeletedAt.Valid {
|
||||
t.Error("DeletedAt.Valid = true, want false (null)")
|
||||
}
|
||||
|
||||
// Verify driver.Valuer for null SqlTimeStamp
|
||||
deletedValue, err := result.DeletedAt.Value()
|
||||
if err != nil {
|
||||
t.Errorf("DeletedAt.Value() error = %v", err)
|
||||
}
|
||||
if deletedValue != nil {
|
||||
t.Errorf("DeletedAt.Value() = %v, want nil", deletedValue)
|
||||
}
|
||||
}
|
||||
@@ -1687,3 +1687,201 @@ func TestGetRelationModel_WithTags(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToStruct(t *testing.T) {
|
||||
// Test model with various field types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
Age int `bun:"age" json:"age"`
|
||||
Active bool `bun:"active" json:"active"`
|
||||
Score float64 `bun:"score" json:"score"`
|
||||
Data []byte `bun:"data" json:"data"`
|
||||
MetaJSON []byte `bun:"meta_json" json:"meta_json"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dataMap map[string]interface{}
|
||||
expected TestModel
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Basic types conversion",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(123),
|
||||
"name": "Test User",
|
||||
"age": 30,
|
||||
"active": true,
|
||||
"score": 95.5,
|
||||
},
|
||||
expected: TestModel{
|
||||
ID: 123,
|
||||
Name: "Test User",
|
||||
Age: 30,
|
||||
Active: true,
|
||||
Score: 95.5,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Byte slice (SqlJSONB-like) from []byte",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(456),
|
||||
"name": "JSON Test",
|
||||
"data": []byte(`{"key":"value"}`),
|
||||
},
|
||||
expected: TestModel{
|
||||
ID: 456,
|
||||
Name: "JSON Test",
|
||||
Data: []byte(`{"key":"value"}`),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Byte slice from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(789),
|
||||
"data": "string data",
|
||||
},
|
||||
expected: TestModel{
|
||||
ID: 789,
|
||||
Data: []byte("string data"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Byte slice from map (JSON marshal)",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(999),
|
||||
"meta_json": map[string]interface{}{
|
||||
"field1": "value1",
|
||||
"field2": 42,
|
||||
},
|
||||
},
|
||||
expected: TestModel{
|
||||
ID: 999,
|
||||
MetaJSON: []byte(`{"field1":"value1","field2":42}`),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Byte slice from slice (JSON marshal)",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(111),
|
||||
"meta_json": []interface{}{"item1", "item2", 3},
|
||||
},
|
||||
expected: TestModel{
|
||||
ID: 111,
|
||||
MetaJSON: []byte(`["item1","item2",3]`),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Field matching by bun tag",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(222),
|
||||
"name": "Tagged Field",
|
||||
},
|
||||
expected: TestModel{
|
||||
ID: 222,
|
||||
Name: "Tagged Field",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil values",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(333),
|
||||
"data": nil,
|
||||
},
|
||||
expected: TestModel{
|
||||
ID: 333,
|
||||
Data: nil,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result TestModel
|
||||
err := MapToStruct(tt.dataMap, &result)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Compare fields individually for better error messages
|
||||
if result.ID != tt.expected.ID {
|
||||
t.Errorf("ID = %v, want %v", result.ID, tt.expected.ID)
|
||||
}
|
||||
if result.Name != tt.expected.Name {
|
||||
t.Errorf("Name = %v, want %v", result.Name, tt.expected.Name)
|
||||
}
|
||||
if result.Age != tt.expected.Age {
|
||||
t.Errorf("Age = %v, want %v", result.Age, tt.expected.Age)
|
||||
}
|
||||
if result.Active != tt.expected.Active {
|
||||
t.Errorf("Active = %v, want %v", result.Active, tt.expected.Active)
|
||||
}
|
||||
if result.Score != tt.expected.Score {
|
||||
t.Errorf("Score = %v, want %v", result.Score, tt.expected.Score)
|
||||
}
|
||||
|
||||
// For byte slices, compare as strings for JSON data
|
||||
if tt.expected.Data != nil {
|
||||
if string(result.Data) != string(tt.expected.Data) {
|
||||
t.Errorf("Data = %s, want %s", string(result.Data), string(tt.expected.Data))
|
||||
}
|
||||
}
|
||||
if tt.expected.MetaJSON != nil {
|
||||
if string(result.MetaJSON) != string(tt.expected.MetaJSON) {
|
||||
t.Errorf("MetaJSON = %s, want %s", string(result.MetaJSON), string(tt.expected.MetaJSON))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToStruct_Errors(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int `bun:"id" json:"id"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dataMap map[string]interface{}
|
||||
target interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Nil dataMap",
|
||||
dataMap: nil,
|
||||
target: &TestModel{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Nil target",
|
||||
dataMap: map[string]interface{}{"id": 1},
|
||||
target: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Non-pointer target",
|
||||
dataMap: map[string]interface{}{"id": 1},
|
||||
target: TestModel{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := MapToStruct(tt.dataMap, tt.target)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
703
pkg/resolvespec/README.md
Normal file
703
pkg/resolvespec/README.md
Normal file
@@ -0,0 +1,703 @@
|
||||
# ResolveSpec - Body-Based REST API
|
||||
|
||||
ResolveSpec provides a REST API where query options are passed in the JSON request body. This approach offers GraphQL-like flexibility while maintaining RESTful principles, making it ideal for complex queries and operations.
|
||||
|
||||
## Features
|
||||
|
||||
* **Body-Based Querying**: All query options passed via JSON request body
|
||||
* **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
|
||||
* **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
|
||||
* **Offset Pagination**: Traditional limit/offset pagination support
|
||||
* **Advanced Filtering**: Multiple operators, AND/OR logic, and custom SQL
|
||||
* **Relationship Preloading**: Load related entities with custom column selection and filters
|
||||
* **Recursive CRUD**: Automatically handle nested object graphs with foreign key resolution
|
||||
* **Computed Columns**: Define virtual columns with SQL expressions
|
||||
* **Database-Agnostic**: Works with GORM, Bun, or custom database adapters
|
||||
* **Router-Agnostic**: Integrates with any HTTP router through standard interfaces
|
||||
* **Type-Safe**: Strong type validation and conversion
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Setup with GORM
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
import "github.com/gorilla/mux"
|
||||
|
||||
// Create handler
|
||||
handler := resolvespec.NewHandlerWithGORM(db)
|
||||
|
||||
// IMPORTANT: Register models BEFORE setting up routes
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
handler.registry.RegisterModel("core.posts", &Post{})
|
||||
|
||||
// Setup routes
|
||||
router := mux.NewRouter()
|
||||
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Start server
|
||||
http.ListenAndServe(":8080", router)
|
||||
```
|
||||
|
||||
### Setup with Bun ORM
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
import "github.com/uptrace/bun"
|
||||
|
||||
// Create handler with Bun
|
||||
handler := resolvespec.NewHandlerWithBun(bunDB)
|
||||
|
||||
// Register models
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
|
||||
// Setup routes (same as GORM)
|
||||
router := mux.NewRouter()
|
||||
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Simple Read Request
|
||||
|
||||
```http
|
||||
POST /core/users HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
```
|
||||
|
||||
### With Preloading
|
||||
|
||||
```http
|
||||
POST /core/users HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
```
|
||||
|
||||
## Request Structure
|
||||
|
||||
### Request Format
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read|create|update|delete",
|
||||
"data": {
|
||||
// For create/update operations
|
||||
},
|
||||
"options": {
|
||||
"columns": [...],
|
||||
"preload": [...],
|
||||
"filters": [...],
|
||||
"sort": [...],
|
||||
"limit": number,
|
||||
"offset": number,
|
||||
"cursor_forward": "string",
|
||||
"cursor_backward": "string",
|
||||
"customOperators": [...],
|
||||
"computedColumns": [...]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Operations
|
||||
|
||||
| Operation | Description | Requires Data | Requires ID |
|
||||
|-----------|-------------|---------------|-------------|
|
||||
| `read` | Fetch records | No | Optional (single record) |
|
||||
| `create` | Create new record(s) | Yes | No |
|
||||
| `update` | Update existing record(s) | Yes | Yes (in URL) |
|
||||
| `delete` | Delete record(s) | No | Yes (in URL) |
|
||||
|
||||
### Options Fields
|
||||
|
||||
| Field | Type | Description | Example |
|
||||
|-------|------|-------------|---------|
|
||||
| `columns` | `[]string` | Columns to select | `["id", "name", "email"]` |
|
||||
| `preload` | `[]PreloadConfig` | Relations to load | See [Preloading](#preloading) |
|
||||
| `filters` | `[]Filter` | Filter conditions | See [Filtering](#filtering) |
|
||||
| `sort` | `[]Sort` | Sort criteria | `[{"column": "created_at", "direction": "desc"}]` |
|
||||
| `limit` | `int` | Max records to return | `50` |
|
||||
| `offset` | `int` | Number of records to skip | `100` |
|
||||
| `cursor_forward` | `string` | Cursor for next page | `"12345"` |
|
||||
| `cursor_backward` | `string` | Cursor for previous page | `"12300"` |
|
||||
| `customOperators` | `[]CustomOperator` | Custom SQL conditions | See [Custom Operators](#custom-operators) |
|
||||
| `computedColumns` | `[]ComputedColumn` | Virtual columns | See [Computed Columns](#computed-columns) |
|
||||
|
||||
## Filtering
|
||||
|
||||
### Available Operators
|
||||
|
||||
| Operator | Description | Example |
|
||||
|----------|-------------|---------|
|
||||
| `eq` | Equal | `{"column": "status", "operator": "eq", "value": "active"}` |
|
||||
| `neq` | Not Equal | `{"column": "status", "operator": "neq", "value": "deleted"}` |
|
||||
| `gt` | Greater Than | `{"column": "age", "operator": "gt", "value": 18}` |
|
||||
| `gte` | Greater Than or Equal | `{"column": "age", "operator": "gte", "value": 18}` |
|
||||
| `lt` | Less Than | `{"column": "price", "operator": "lt", "value": 100}` |
|
||||
| `lte` | Less Than or Equal | `{"column": "price", "operator": "lte", "value": 100}` |
|
||||
| `like` | LIKE pattern | `{"column": "name", "operator": "like", "value": "%john%"}` |
|
||||
| `ilike` | Case-insensitive LIKE | `{"column": "email", "operator": "ilike", "value": "%@example.com"}` |
|
||||
| `in` | IN clause | `{"column": "status", "operator": "in", "value": ["active", "pending"]}` |
|
||||
| `contains` | Contains string | `{"column": "description", "operator": "contains", "value": "important"}` |
|
||||
| `startswith` | Starts with string | `{"column": "name", "operator": "startswith", "value": "John"}` |
|
||||
| `endswith` | Ends with string | `{"column": "email", "operator": "endswith", "value": "@example.com"}` |
|
||||
| `between` | Between (exclusive) | `{"column": "age", "operator": "between", "value": [18, 65]}` |
|
||||
| `betweeninclusive` | Between (inclusive) | `{"column": "price", "operator": "betweeninclusive", "value": [10, 100]}` |
|
||||
| `empty` | IS NULL or empty | `{"column": "deleted_at", "operator": "empty"}` |
|
||||
| `notempty` | IS NOT NULL | `{"column": "email", "operator": "notempty"}` |
|
||||
|
||||
### Complex Filtering Example
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "age",
|
||||
"operator": "gte",
|
||||
"value": 18
|
||||
},
|
||||
{
|
||||
"column": "email",
|
||||
"operator": "ilike",
|
||||
"value": "%@company.com"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Preloading
|
||||
|
||||
Load related entities with custom configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"columns": ["id", "name", "email"],
|
||||
"preload": [
|
||||
{
|
||||
"relation": "posts",
|
||||
"columns": ["id", "title", "created_at"],
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "published"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"limit": 5
|
||||
},
|
||||
{
|
||||
"relation": "profile",
|
||||
"columns": ["bio", "website"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
Efficient pagination for large datasets:
|
||||
|
||||
### First Request (No Cursor)
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"sort": [
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "desc"
|
||||
},
|
||||
{
|
||||
"column": "id",
|
||||
"direction": "asc"
|
||||
}
|
||||
],
|
||||
"limit": 50
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Next Page (Forward Cursor)
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"sort": [
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "desc"
|
||||
},
|
||||
{
|
||||
"column": "id",
|
||||
"direction": "asc"
|
||||
}
|
||||
],
|
||||
"limit": 50,
|
||||
"cursor_forward": "12345"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Previous Page (Backward Cursor)
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"sort": [
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "desc"
|
||||
},
|
||||
{
|
||||
"column": "id",
|
||||
"direction": "asc"
|
||||
}
|
||||
],
|
||||
"limit": 50,
|
||||
"cursor_backward": "12300"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Benefits over offset pagination**:
|
||||
* Consistent results when data changes
|
||||
* Better performance for large offsets
|
||||
* Prevents "skipped" or duplicate records
|
||||
* Works with complex sort expressions
|
||||
|
||||
## Recursive CRUD Operations
|
||||
|
||||
Automatically handle nested object graphs with intelligent foreign key resolution.
|
||||
|
||||
### Creating Nested Objects
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "create",
|
||||
"data": {
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"posts": [
|
||||
{
|
||||
"title": "My First Post",
|
||||
"content": "Hello World",
|
||||
"tags": [
|
||||
{"name": "tech"},
|
||||
{"name": "programming"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Second Post",
|
||||
"content": "More content"
|
||||
}
|
||||
],
|
||||
"profile": {
|
||||
"bio": "Software Developer",
|
||||
"website": "https://example.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Per-Record Operation Control with `_request`
|
||||
|
||||
Control individual operations for each nested record:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "update",
|
||||
"data": {
|
||||
"name": "John Updated",
|
||||
"posts": [
|
||||
{
|
||||
"_request": "insert",
|
||||
"title": "New Post",
|
||||
"content": "Fresh content"
|
||||
},
|
||||
{
|
||||
"_request": "update",
|
||||
"id": 456,
|
||||
"title": "Updated Post Title"
|
||||
},
|
||||
{
|
||||
"_request": "delete",
|
||||
"id": 789
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Supported `_request` values**:
|
||||
* `insert` - Create a new related record
|
||||
* `update` - Update an existing related record
|
||||
* `delete` - Delete a related record
|
||||
* `upsert` - Create if doesn't exist, update if exists
|
||||
|
||||
**How It Works**:
|
||||
1. Automatic foreign key resolution - parent IDs propagate to children
|
||||
2. Recursive processing - handles nested relationships at any depth
|
||||
3. Transaction safety - all operations execute atomically
|
||||
4. Relationship detection - automatically detects belongsTo, hasMany, hasOne, many2many
|
||||
5. Flexible operations - mix create, update, and delete in one request
|
||||
|
||||
## Computed Columns
|
||||
|
||||
Define virtual columns using SQL expressions:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"columns": ["id", "first_name", "last_name"],
|
||||
"computedColumns": [
|
||||
{
|
||||
"name": "full_name",
|
||||
"expression": "CONCAT(first_name, ' ', last_name)"
|
||||
},
|
||||
{
|
||||
"name": "age_years",
|
||||
"expression": "EXTRACT(YEAR FROM AGE(birth_date))"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Custom Operators
|
||||
|
||||
Add custom SQL conditions when needed:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"condition": "LOWER(email) LIKE ?",
|
||||
"values": ["%@example.com"]
|
||||
},
|
||||
{
|
||||
"condition": "created_at > NOW() - INTERVAL '7 days'"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
Register hooks for all CRUD operations:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
|
||||
// Create handler
|
||||
handler := resolvespec.NewHandlerWithGORM(db)
|
||||
|
||||
// Register a before-read hook (e.g., for authorization)
|
||||
handler.Hooks().Register(resolvespec.BeforeRead, func(ctx *resolvespec.HookContext) error {
|
||||
// Check permissions
|
||||
if !userHasPermission(ctx.Context, ctx.Entity) {
|
||||
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
|
||||
}
|
||||
|
||||
// Modify query options
|
||||
if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
|
||||
ctx.Options.Limit = ptr(100) // Enforce max limit
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register an after-read hook (e.g., for data transformation)
|
||||
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
|
||||
// Transform or filter results
|
||||
if users, ok := ctx.Result.([]User); ok {
|
||||
for i := range users {
|
||||
users[i].Email = maskEmail(users[i].Email)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register a before-create hook (e.g., for validation)
|
||||
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
|
||||
// Validate data
|
||||
if user, ok := ctx.Data.(*User); ok {
|
||||
if user.Email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
// Add timestamps
|
||||
user.CreatedAt = time.Now()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
**Available Hook Types**:
|
||||
* `BeforeRead`, `AfterRead`
|
||||
* `BeforeCreate`, `AfterCreate`
|
||||
* `BeforeUpdate`, `AfterUpdate`
|
||||
* `BeforeDelete`, `AfterDelete`
|
||||
|
||||
**HookContext** provides:
|
||||
* `Context`: Request context
|
||||
* `Handler`: Access to handler, database, and registry
|
||||
* `Schema`, `Entity`, `TableName`: Request info
|
||||
* `Model`: The registered model type
|
||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||
* `ID`: Record ID (for single-record operations)
|
||||
* `Data`: Request data (for create/update)
|
||||
* `Result`: Operation result (for after hooks)
|
||||
* `Writer`: Response writer (allows hooks to modify response)
|
||||
|
||||
## Model Registration
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
|
||||
Profile *Profile `json:"profile,omitempty" gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
UserID uint `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
|
||||
}
|
||||
|
||||
// Schema.Table format
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
handler.registry.RegisterModel("core.posts", &Post{})
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
UserID uint `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Connect to database
|
||||
db, err := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Create handler
|
||||
handler := resolvespec.NewHandlerWithGORM(db)
|
||||
|
||||
// Register models
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
handler.registry.RegisterModel("core.posts", &Post{})
|
||||
|
||||
// Add hooks
|
||||
handler.Hooks().Register(resolvespec.BeforeRead, func(ctx *resolvespec.HookContext) error {
|
||||
log.Printf("Reading %s", ctx.Entity)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Setup routes
|
||||
router := mux.NewRouter()
|
||||
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Start server
|
||||
log.Println("Server starting on :8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
ResolveSpec is designed for testability:
|
||||
|
||||
```go
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserRead(t *testing.T) {
|
||||
handler := resolvespec.NewHandlerWithGORM(testDB)
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"operation": "read",
|
||||
"options": map[string]interface{}{
|
||||
"columns": []string{"id", "name"},
|
||||
"limit": 10,
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("POST", "/core/users", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Test your handler...
|
||||
}
|
||||
```
|
||||
|
||||
## Router Integration
|
||||
|
||||
### Gorilla Mux
|
||||
|
||||
```go
|
||||
router := mux.NewRouter()
|
||||
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||
```
|
||||
|
||||
### BunRouter
|
||||
|
||||
```go
|
||||
router := bunrouter.New()
|
||||
resolvespec.SetupBunRouterWithResolveSpec(router, handler)
|
||||
```
|
||||
|
||||
### Custom Routers
|
||||
|
||||
```go
|
||||
// Implement custom integration using common.Request and common.ResponseWriter
|
||||
router.POST("/:schema/:entity", func(w http.ResponseWriter, r *http.Request) {
|
||||
params := extractParams(r) // Your param extraction logic
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
})
|
||||
```
|
||||
|
||||
## Response Format
|
||||
|
||||
### Success Response
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
"metadata": {
|
||||
"total": 100,
|
||||
"filtered": 50,
|
||||
"limit": 10,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Error Response
|
||||
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"error": {
|
||||
"code": "validation_error",
|
||||
"message": "Invalid request",
|
||||
"details": "..."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
* [Main README](../../README.md) - ResolveSpec overview
|
||||
* [RestHeadSpec Package](../restheadspec/README.md) - Header-based API
|
||||
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
|
||||
|
||||
## License
|
||||
|
||||
This package is part of ResolveSpec and is licensed under the MIT License.
|
||||
```
|
||||
|
||||
## Response Format
|
||||
|
||||
### Success Response
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
"metadata": {
|
||||
"total": 100,
|
||||
"filtered": 50,
|
||||
"limit": 10,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Error Response
|
||||
|
||||
```json
|
||||
{
|
||||
"success": false,
|
||||
"error": {
|
||||
"code": "validation_error",
|
||||
"message": "Invalid request",
|
||||
"details": "..."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
* [Main README](../../README.md) - ResolveSpec overview
|
||||
* [RestHeadSpec Package](../restheadspec/README.md) - Header-based API
|
||||
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
|
||||
|
||||
## License
|
||||
|
||||
This package is part of ResolveSpec and is licensed under the MIT License.
|
||||
118
pkg/resolvespec/cache_helpers.go
Normal file
118
pkg/resolvespec/cache_helpers.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// queryCacheKey represents the components used to build a cache key for query total count
|
||||
type queryCacheKey struct {
|
||||
TableName string `json:"table_name"`
|
||||
Filters []common.FilterOption `json:"filters"`
|
||||
Sort []common.SortOption `json:"sort"`
|
||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||
CursorForward string `json:"cursor_forward,omitempty"`
|
||||
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||
}
|
||||
|
||||
// cachedTotal represents a cached total count
|
||||
type cachedTotal struct {
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// buildQueryCacheKey builds a cache key from query parameters for total count caching
|
||||
func buildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
||||
key := queryCacheKey{
|
||||
TableName: tableName,
|
||||
Filters: filters,
|
||||
Sort: sort,
|
||||
CustomSQLWhere: customWhere,
|
||||
CustomSQLOr: customOr,
|
||||
}
|
||||
|
||||
// Serialize to JSON for consistent hashing
|
||||
jsonData, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
// Fallback to simple string concatenation if JSON fails
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
||||
}
|
||||
|
||||
return hashString(string(jsonData))
|
||||
}
|
||||
|
||||
// buildExtendedQueryCacheKey builds a cache key for extended query options with cursor pagination
|
||||
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||
customWhere, customOr string, cursorFwd, cursorBwd string) string {
|
||||
|
||||
key := queryCacheKey{
|
||||
TableName: tableName,
|
||||
Filters: filters,
|
||||
Sort: sort,
|
||||
CustomSQLWhere: customWhere,
|
||||
CustomSQLOr: customOr,
|
||||
CursorForward: cursorFwd,
|
||||
CursorBackward: cursorBwd,
|
||||
}
|
||||
|
||||
// Serialize to JSON for consistent hashing
|
||||
jsonData, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
// Fallback to simple string concatenation if JSON fails
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%s_%s",
|
||||
tableName, filters, sort, customWhere, customOr, cursorFwd, cursorBwd))
|
||||
}
|
||||
|
||||
return hashString(string(jsonData))
|
||||
}
|
||||
|
||||
// hashString computes SHA256 hash of a string
|
||||
func hashString(s string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(s))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||
func getQueryTotalCacheKey(hash string) string {
|
||||
return fmt.Sprintf("query_total:%s", hash)
|
||||
}
|
||||
|
||||
// buildCacheTags creates cache tags from schema and table name
|
||||
func buildCacheTags(schema, tableName string) []string {
|
||||
return []string{
|
||||
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
|
||||
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
|
||||
}
|
||||
}
|
||||
|
||||
// setQueryTotalCache stores a query total in the cache with schema and table tags
|
||||
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
|
||||
c := cache.GetDefaultCache()
|
||||
cacheData := cachedTotal{Total: total}
|
||||
tags := buildCacheTags(schema, tableName)
|
||||
|
||||
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
|
||||
}
|
||||
|
||||
// invalidateCacheForTags removes all cached items matching the specified tags
|
||||
func invalidateCacheForTags(ctx context.Context, tags []string) error {
|
||||
c := cache.GetDefaultCache()
|
||||
|
||||
// Invalidate for each tag
|
||||
for _, tag := range tags {
|
||||
if err := c.DeleteByTag(ctx, tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -330,19 +331,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Use extended cache key if cursors are present
|
||||
var cacheKeyHash string
|
||||
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
||||
cacheKeyHash = cache.BuildExtendedQueryCacheKey(
|
||||
cacheKeyHash = buildExtendedQueryCacheKey(
|
||||
tableName,
|
||||
options.Filters,
|
||||
options.Sort,
|
||||
"", // No custom SQL WHERE in resolvespec
|
||||
"", // No custom SQL OR in resolvespec
|
||||
nil, // No expand options in resolvespec
|
||||
false, // distinct not used here
|
||||
"", // No custom SQL WHERE in resolvespec
|
||||
"", // No custom SQL OR in resolvespec
|
||||
options.CursorForward,
|
||||
options.CursorBackward,
|
||||
)
|
||||
} else {
|
||||
cacheKeyHash = cache.BuildQueryCacheKey(
|
||||
cacheKeyHash = buildQueryCacheKey(
|
||||
tableName,
|
||||
options.Filters,
|
||||
options.Sort,
|
||||
@@ -350,10 +349,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
"", // No custom SQL OR in resolvespec
|
||||
)
|
||||
}
|
||||
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
|
||||
cacheKey := getQueryTotalCacheKey(cacheKeyHash)
|
||||
|
||||
// Try to retrieve from cache
|
||||
var cachedTotal cache.CachedTotal
|
||||
var cachedTotal cachedTotal
|
||||
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||
if err == nil {
|
||||
total = cachedTotal.Total
|
||||
@@ -370,10 +369,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
total = count
|
||||
logger.Debug("Total records (from query): %d", total)
|
||||
|
||||
// Store in cache
|
||||
// Store in cache with schema and table tags
|
||||
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||
cacheData := cache.CachedTotal{Total: total}
|
||||
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
||||
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
|
||||
logger.Warn("Failed to cache query total: %v", err)
|
||||
// Don't fail the request if caching fails
|
||||
} else {
|
||||
@@ -463,6 +461,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, result.Data, nil)
|
||||
return
|
||||
}
|
||||
@@ -479,6 +482,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, v, nil)
|
||||
|
||||
case []map[string]interface{}:
|
||||
@@ -517,6 +525,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created %d records with nested data", len(results))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, results, nil)
|
||||
return
|
||||
}
|
||||
@@ -540,6 +553,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created %d records", len(v))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, v, nil)
|
||||
|
||||
case []interface{}:
|
||||
@@ -583,6 +601,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created %d records with nested data", len(results))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, results, nil)
|
||||
return
|
||||
}
|
||||
@@ -610,6 +633,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created %d records", len(v))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, list, nil)
|
||||
|
||||
default:
|
||||
@@ -660,6 +688,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, result.Data, nil)
|
||||
return
|
||||
}
|
||||
@@ -696,6 +729,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated %d records", result.RowsAffected())
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, data, nil)
|
||||
|
||||
case []map[string]interface{}:
|
||||
@@ -734,6 +772,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records with nested data", len(results))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, results, nil)
|
||||
return
|
||||
}
|
||||
@@ -757,6 +800,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records", len(updates))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, updates, nil)
|
||||
|
||||
case []interface{}:
|
||||
@@ -799,6 +847,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records with nested data", len(results))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, results, nil)
|
||||
return
|
||||
}
|
||||
@@ -826,6 +879,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records", len(list))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, list, nil)
|
||||
|
||||
default:
|
||||
@@ -872,6 +930,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully deleted %d records", len(v))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, map[string]interface{}{"deleted": len(v)}, nil)
|
||||
return
|
||||
|
||||
@@ -913,6 +976,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully deleted %d records", deletedCount)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||
return
|
||||
|
||||
@@ -939,6 +1007,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully deleted %d records", deletedCount)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||
return
|
||||
|
||||
@@ -957,7 +1030,29 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||
// Get primary key name
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// First, fetch the record that will be deleted
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
recordToDelete := reflect.New(modelType).Interface()
|
||||
|
||||
selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Warn("Record not found for delete: %s = %s", pkName, id)
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", err)
|
||||
return
|
||||
}
|
||||
logger.Error("Error fetching record for delete: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Error fetching record", err)
|
||||
return
|
||||
}
|
||||
|
||||
query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
@@ -966,14 +1061,21 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the record was actually deleted
|
||||
if result.RowsAffected() == 0 {
|
||||
logger.Warn("No record found to delete with ID: %s", id)
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil)
|
||||
logger.Warn("No rows deleted for ID: %s", id)
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "Record not found or already deleted", nil)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully deleted record with ID: %s", id)
|
||||
h.sendResponse(w, nil, nil)
|
||||
// Return the deleted record data
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, recordToDelete, nil)
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
||||
|
||||
@@ -56,6 +56,10 @@ type HookContext struct {
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
|
||||
// Tx provides access to the database/transaction for executing additional SQL
|
||||
// This allows hooks to run custom queries in addition to the main Query chain
|
||||
Tx common.Database
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
|
||||
445
pkg/restheadspec/README.md
Normal file
445
pkg/restheadspec/README.md
Normal file
@@ -0,0 +1,445 @@
|
||||
# RestHeadSpec - Header-Based REST API
|
||||
|
||||
RestHeadSpec provides a REST API where all query options are passed via HTTP headers instead of the request body. This provides cleaner separation between data and metadata, making it ideal for GET requests and RESTful architectures.
|
||||
|
||||
## Features
|
||||
|
||||
* **Header-Based Querying**: All query options via HTTP headers
|
||||
* **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
|
||||
* **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
|
||||
* **Advanced Filtering**: Field filters, search operators, AND/OR logic
|
||||
* **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
|
||||
* **Single Record as Object**: Automatically return single-element arrays as objects (default)
|
||||
* **Base64 Support**: Base64-encoded header values for complex queries
|
||||
* **Type-Aware Filtering**: Automatic type detection and conversion
|
||||
* **CORS Support**: Comprehensive CORS headers for cross-origin requests
|
||||
* **OPTIONS Method**: Full OPTIONS support for CORS preflight
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Setup with GORM
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
import "github.com/gorilla/mux"
|
||||
|
||||
// Create handler
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// IMPORTANT: Register models BEFORE setting up routes
|
||||
handler.Registry.RegisterModel("public.users", &User{})
|
||||
handler.Registry.RegisterModel("public.posts", &Post{})
|
||||
|
||||
// Setup routes
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Start server
|
||||
http.ListenAndServe(":8080", router)
|
||||
```
|
||||
|
||||
### Setup with Bun ORM
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
import "github.com/uptrace/bun"
|
||||
|
||||
// Create handler with Bun
|
||||
handler := restheadspec.NewHandlerWithBun(bunDB)
|
||||
|
||||
// Register models
|
||||
handler.Registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Setup routes (same as GORM)
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
```
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Simple GET Request
|
||||
|
||||
```http
|
||||
GET /public/users HTTP/1.1
|
||||
Host: api.example.com
|
||||
X-Select-Fields: id,name,email
|
||||
X-FieldFilter-Status: active
|
||||
X-Sort: -created_at
|
||||
X-Limit: 50
|
||||
```
|
||||
|
||||
### With Preloading
|
||||
|
||||
```http
|
||||
GET /public/users HTTP/1.1
|
||||
X-Select-Fields: id,name,email,department_id
|
||||
X-Preload: department:id,name
|
||||
X-FieldFilter-Status: active
|
||||
X-Limit: 50
|
||||
```
|
||||
|
||||
## Common Headers
|
||||
|
||||
| Header | Description | Example |
|
||||
|--------|-------------|---------|
|
||||
| `X-Select-Fields` | Columns to include | `id,name,email` |
|
||||
| `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` |
|
||||
| `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` |
|
||||
| `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` |
|
||||
| `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` |
|
||||
| `X-Preload` | Preload relations | `posts:id,title` |
|
||||
| `X-Sort` | Sort columns | `-created_at,+name` |
|
||||
| `X-Limit` | Limit results | `50` |
|
||||
| `X-Offset` | Offset for pagination | `100` |
|
||||
| `X-Clean-JSON` | Remove null/empty fields | `true` |
|
||||
| `X-Single-Record-As-Object` | Return single records as objects | `false` |
|
||||
|
||||
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
|
||||
|
||||
For complete header documentation, see [HEADERS.md](HEADERS.md).
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
RestHeadSpec supports lifecycle hooks for all CRUD operations:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
|
||||
// Create handler
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// Register a before-read hook (e.g., for authorization)
|
||||
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
|
||||
// Check permissions
|
||||
if !userHasPermission(ctx.Context, ctx.Entity) {
|
||||
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
|
||||
}
|
||||
|
||||
// Modify query options
|
||||
ctx.Options.Limit = ptr(100) // Enforce max limit
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register an after-read hook (e.g., for data transformation)
|
||||
handler.Hooks.Register(restheadspec.AfterRead, func(ctx *restheadspec.HookContext) error {
|
||||
// Transform or filter results
|
||||
if users, ok := ctx.Result.([]User); ok {
|
||||
for i := range users {
|
||||
users[i].Email = maskEmail(users[i].Email)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register a before-create hook (e.g., for validation)
|
||||
handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookContext) error {
|
||||
// Validate data
|
||||
if user, ok := ctx.Data.(*User); ok {
|
||||
if user.Email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
// Add timestamps
|
||||
user.CreatedAt = time.Now()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
**Available Hook Types**:
|
||||
* `BeforeRead`, `AfterRead`
|
||||
* `BeforeCreate`, `AfterCreate`
|
||||
* `BeforeUpdate`, `AfterUpdate`
|
||||
* `BeforeDelete`, `AfterDelete`
|
||||
|
||||
**HookContext** provides:
|
||||
* `Context`: Request context
|
||||
* `Handler`: Access to handler, database, and registry
|
||||
* `Schema`, `Entity`, `TableName`: Request info
|
||||
* `Model`: The registered model type
|
||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||
* `ID`: Record ID (for single-record operations)
|
||||
* `Data`: Request data (for create/update)
|
||||
* `Result`: Operation result (for after hooks)
|
||||
* `Writer`: Response writer (allows hooks to modify response)
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
RestHeadSpec supports efficient cursor-based pagination for large datasets:
|
||||
|
||||
```http
|
||||
GET /public/posts HTTP/1.1
|
||||
X-Sort: -created_at,+id
|
||||
X-Limit: 50
|
||||
X-Cursor-Forward: <cursor_token>
|
||||
```
|
||||
|
||||
**How it works**:
|
||||
1. First request returns results + cursor token in response
|
||||
2. Subsequent requests use `X-Cursor-Forward` or `X-Cursor-Backward`
|
||||
3. Cursor maintains consistent ordering even with data changes
|
||||
4. Supports complex multi-column sorting
|
||||
|
||||
**Benefits over offset pagination**:
|
||||
* Consistent results when data changes
|
||||
* Better performance for large offsets
|
||||
* Prevents "skipped" or duplicate records
|
||||
* Works with complex sort expressions
|
||||
|
||||
**Example with hooks**:
|
||||
|
||||
```go
|
||||
// Enable cursor pagination in a hook
|
||||
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
|
||||
// For large tables, enforce cursor pagination
|
||||
if ctx.Entity == "posts" && ctx.Options.Offset != nil && *ctx.Options.Offset > 1000 {
|
||||
return fmt.Errorf("use cursor pagination for large offsets")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
## Response Formats
|
||||
|
||||
RestHeadSpec supports multiple response formats:
|
||||
|
||||
**1. Simple Format** (`X-SimpleApi: true`):
|
||||
|
||||
```json
|
||||
[
|
||||
{ "id": 1, "name": "John" },
|
||||
{ "id": 2, "name": "Jane" }
|
||||
]
|
||||
```
|
||||
|
||||
**2. Detail Format** (`X-DetailApi: true`, default):
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
"metadata": {
|
||||
"total": 100,
|
||||
"filtered": 100,
|
||||
"limit": 50,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Syncfusion Format** (`X-Syncfusion: true`):
|
||||
|
||||
```json
|
||||
{
|
||||
"result": [...],
|
||||
"count": 100
|
||||
}
|
||||
```
|
||||
|
||||
## Single Record as Object (Default Behavior)
|
||||
|
||||
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses.
|
||||
|
||||
**Default behavior (enabled)**:
|
||||
|
||||
```http
|
||||
GET /public/users/123
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": { "id": 123, "name": "John", "email": "john@example.com" }
|
||||
}
|
||||
```
|
||||
|
||||
**To disable** (force arrays):
|
||||
|
||||
```http
|
||||
GET /public/users/123
|
||||
X-Single-Record-As-Object: false
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
|
||||
}
|
||||
```
|
||||
|
||||
**How it works**:
|
||||
* When a query returns exactly **one record**, it's returned as an object
|
||||
* When a query returns **multiple records**, they're returned as an array
|
||||
* Set `X-Single-Record-As-Object: false` to always receive arrays
|
||||
* Works with all response formats (simple, detail, syncfusion)
|
||||
* Applies to both read operations and create/update returning clauses
|
||||
|
||||
## CORS & OPTIONS Support
|
||||
|
||||
RestHeadSpec includes comprehensive CORS support for cross-origin requests:
|
||||
|
||||
**OPTIONS Method**:
|
||||
|
||||
```http
|
||||
OPTIONS /public/users HTTP/1.1
|
||||
```
|
||||
|
||||
Returns metadata with appropriate CORS headers:
|
||||
|
||||
```http
|
||||
Access-Control-Allow-Origin: *
|
||||
Access-Control-Allow-Methods: GET, POST, OPTIONS
|
||||
Access-Control-Allow-Headers: Content-Type, Authorization, X-Select-Fields, X-FieldFilter-*, ...
|
||||
Access-Control-Max-Age: 86400
|
||||
Access-Control-Allow-Credentials: true
|
||||
```
|
||||
|
||||
**Key Features**:
|
||||
* OPTIONS returns model metadata (same as GET metadata endpoint)
|
||||
* All HTTP methods include CORS headers automatically
|
||||
* OPTIONS requests don't require authentication (CORS preflight)
|
||||
* Supports all HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.)
|
||||
* 24-hour max age to reduce preflight requests
|
||||
|
||||
**Configuration**:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
|
||||
// Get default CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
|
||||
// Customize if needed
|
||||
corsConfig.AllowedOrigins = []string{"https://example.com"}
|
||||
corsConfig.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Base64 Encoding
|
||||
|
||||
For complex header values, use base64 encoding:
|
||||
|
||||
```http
|
||||
GET /public/users HTTP/1.1
|
||||
X-Select-Fields-Base64: aWQsbmFtZSxlbWFpbA==
|
||||
```
|
||||
|
||||
### AND/OR Logic
|
||||
|
||||
Combine multiple filters with AND/OR logic:
|
||||
|
||||
```http
|
||||
GET /public/users HTTP/1.1
|
||||
X-FieldFilter-Status: active
|
||||
X-SearchOp-Gte-Age: 18
|
||||
X-Filter-Logic: AND
|
||||
```
|
||||
|
||||
### Complex Preloading
|
||||
|
||||
Load nested relationships:
|
||||
|
||||
```http
|
||||
GET /public/users HTTP/1.1
|
||||
X-Preload: posts:id,title,comments:id,text,author:name
|
||||
```
|
||||
|
||||
## Model Registration
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Posts []Post `json:"posts,omitempty" gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
// Schema.Table format
|
||||
handler.Registry.RegisterModel("public.users", &User{})
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Connect to database
|
||||
db, err := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Create handler
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// Register models
|
||||
handler.Registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Add hooks
|
||||
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
|
||||
log.Printf("Reading %s", ctx.Entity)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Setup routes
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Start server
|
||||
log.Println("Server starting on :8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
RestHeadSpec is designed for testability:
|
||||
|
||||
```go
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserRead(t *testing.T) {
|
||||
handler := restheadspec.NewHandlerWithGORM(testDB)
|
||||
handler.Registry.RegisterModel("public.users", &User{})
|
||||
|
||||
req := httptest.NewRequest("GET", "/public/users", nil)
|
||||
req.Header.Set("X-Select-Fields", "id,name")
|
||||
req.Header.Set("X-Limit", "10")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
// Test your handler...
|
||||
}
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
* [HEADERS.md](HEADERS.md) - Complete header reference
|
||||
* [Main README](../../README.md) - ResolveSpec overview
|
||||
* [ResolveSpec Package](../resolvespec/README.md) - Body-based API
|
||||
* [StaticWeb Package](../server/staticweb/README.md) - Static file server
|
||||
|
||||
## License
|
||||
|
||||
This package is part of ResolveSpec and is licensed under the MIT License.
|
||||
@@ -1,4 +1,4 @@
|
||||
package cache
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,56 +7,42 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// QueryCacheKey represents the components used to build a cache key for query total count
|
||||
type QueryCacheKey struct {
|
||||
// expandOptionKey represents expand options for cache key
|
||||
type expandOptionKey struct {
|
||||
Relation string `json:"relation"`
|
||||
Where string `json:"where,omitempty"`
|
||||
}
|
||||
|
||||
// queryCacheKey represents the components used to build a cache key for query total count
|
||||
type queryCacheKey struct {
|
||||
TableName string `json:"table_name"`
|
||||
Filters []common.FilterOption `json:"filters"`
|
||||
Sort []common.SortOption `json:"sort"`
|
||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||
Expand []ExpandOptionKey `json:"expand,omitempty"`
|
||||
Expand []expandOptionKey `json:"expand,omitempty"`
|
||||
Distinct bool `json:"distinct,omitempty"`
|
||||
CursorForward string `json:"cursor_forward,omitempty"`
|
||||
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||
}
|
||||
|
||||
// ExpandOptionKey represents expand options for cache key
|
||||
type ExpandOptionKey struct {
|
||||
Relation string `json:"relation"`
|
||||
Where string `json:"where,omitempty"`
|
||||
// cachedTotal represents a cached total count
|
||||
type cachedTotal struct {
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
|
||||
// This is used to cache the total count of records matching a query
|
||||
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
||||
key := QueryCacheKey{
|
||||
TableName: tableName,
|
||||
Filters: filters,
|
||||
Sort: sort,
|
||||
CustomSQLWhere: customWhere,
|
||||
CustomSQLOr: customOr,
|
||||
}
|
||||
|
||||
// Serialize to JSON for consistent hashing
|
||||
jsonData, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
// Fallback to simple string concatenation if JSON fails
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
||||
}
|
||||
|
||||
return hashString(string(jsonData))
|
||||
}
|
||||
|
||||
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||
// Includes expand, distinct, and cursor pagination options
|
||||
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||
|
||||
key := QueryCacheKey{
|
||||
key := queryCacheKey{
|
||||
TableName: tableName,
|
||||
Filters: filters,
|
||||
Sort: sort,
|
||||
@@ -69,11 +55,11 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
||||
|
||||
// Convert expand options to cache key format
|
||||
if len(expandOpts) > 0 {
|
||||
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
|
||||
key.Expand = make([]expandOptionKey, 0, len(expandOpts))
|
||||
for _, exp := range expandOpts {
|
||||
// Type assert to get the expand option fields we care about for caching
|
||||
if expMap, ok := exp.(map[string]interface{}); ok {
|
||||
expKey := ExpandOptionKey{}
|
||||
expKey := expandOptionKey{}
|
||||
if rel, ok := expMap["relation"].(string); ok {
|
||||
expKey.Relation = rel
|
||||
}
|
||||
@@ -83,7 +69,6 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
||||
key.Expand = append(key.Expand, expKey)
|
||||
}
|
||||
}
|
||||
// Sort expand options for consistent hashing (already sorted by relation name above)
|
||||
}
|
||||
|
||||
// Serialize to JSON for consistent hashing
|
||||
@@ -104,24 +89,38 @@ func hashString(s string) string {
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||
func GetQueryTotalCacheKey(hash string) string {
|
||||
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||
func getQueryTotalCacheKey(hash string) string {
|
||||
return fmt.Sprintf("query_total:%s", hash)
|
||||
}
|
||||
|
||||
// CachedTotal represents a cached total count
|
||||
type CachedTotal struct {
|
||||
Total int `json:"total"`
|
||||
// buildCacheTags creates cache tags from schema and table name
|
||||
func buildCacheTags(schema, tableName string) []string {
|
||||
return []string{
|
||||
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
|
||||
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCacheForTable removes all cached totals for a specific table
|
||||
// This should be called when data in the table changes (insert/update/delete)
|
||||
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
|
||||
cache := GetDefaultCache()
|
||||
// setQueryTotalCache stores a query total in the cache with schema and table tags
|
||||
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
|
||||
c := cache.GetDefaultCache()
|
||||
cacheData := cachedTotal{Total: total}
|
||||
tags := buildCacheTags(schema, tableName)
|
||||
|
||||
// Build a pattern to match all query totals for this table
|
||||
// Note: This requires pattern matching support in the provider
|
||||
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
|
||||
|
||||
return cache.DeleteByPattern(ctx, pattern)
|
||||
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
|
||||
}
|
||||
|
||||
// invalidateCacheForTags removes all cached items matching the specified tags
|
||||
func invalidateCacheForTags(ctx context.Context, tags []string) error {
|
||||
c := cache.GetDefaultCache()
|
||||
|
||||
// Invalidate for each tag
|
||||
for _, tag := range tags {
|
||||
if err := c.DeleteByTag(ctx, tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
193
pkg/restheadspec/empty_result_test.go
Normal file
193
pkg/restheadspec/empty_result_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Test that normalizeResultArray returns empty array when no records found without ID
|
||||
func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
shouldBeEmptyArr bool
|
||||
}{
|
||||
{
|
||||
name: "nil should return empty array",
|
||||
input: nil,
|
||||
shouldBeEmptyArr: true,
|
||||
},
|
||||
{
|
||||
name: "empty slice should return empty array",
|
||||
input: []*EmptyTestModel{},
|
||||
shouldBeEmptyArr: true,
|
||||
},
|
||||
{
|
||||
name: "single element should return the element",
|
||||
input: []*EmptyTestModel{{ID: 1, Name: "test"}},
|
||||
shouldBeEmptyArr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple elements should return the slice",
|
||||
input: []*EmptyTestModel{
|
||||
{ID: 1, Name: "test1"},
|
||||
{ID: 2, Name: "test2"},
|
||||
},
|
||||
shouldBeEmptyArr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.normalizeResultArray(tt.input)
|
||||
|
||||
// For cases that should return empty array
|
||||
if tt.shouldBeEmptyArr {
|
||||
emptyArr, ok := result.([]interface{})
|
||||
if !ok {
|
||||
t.Errorf("Expected empty array []interface{}{}, got %T: %v", result, result)
|
||||
return
|
||||
}
|
||||
if len(emptyArr) != 0 {
|
||||
t.Errorf("Expected empty array with length 0, got length %d", len(emptyArr))
|
||||
}
|
||||
|
||||
// Verify it serializes to [] and not null
|
||||
jsonBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to marshal result: %v", err)
|
||||
return
|
||||
}
|
||||
if string(jsonBytes) != "[]" {
|
||||
t.Errorf("Expected JSON '[]', got '%s'", string(jsonBytes))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that sendFormattedResponse adds X-No-Data-Found header
|
||||
func TestSendFormattedResponse_NoDataFoundHeader(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Mock ResponseWriter
|
||||
mockWriter := &MockTestResponseWriter{
|
||||
headers: make(map[string]string),
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: 0,
|
||||
Count: 0,
|
||||
Filtered: 0,
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
}
|
||||
|
||||
options := ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{},
|
||||
}
|
||||
|
||||
// Test with empty data
|
||||
emptyData := []interface{}{}
|
||||
handler.sendFormattedResponse(mockWriter, emptyData, metadata, options)
|
||||
|
||||
// Check if X-No-Data-Found header was set
|
||||
if mockWriter.headers["X-No-Data-Found"] != "true" {
|
||||
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
|
||||
}
|
||||
|
||||
// Verify the body is an empty array
|
||||
if mockWriter.body == nil {
|
||||
t.Error("Expected body to be set, got nil")
|
||||
} else {
|
||||
bodyBytes, err := json.Marshal(mockWriter.body)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to marshal body: %v", err)
|
||||
}
|
||||
// The body should be wrapped in a Response object with "data" field
|
||||
bodyStr := string(bodyBytes)
|
||||
if !strings.Contains(bodyStr, `"data":[]`) && !strings.Contains(bodyStr, `"result":[]`) {
|
||||
t.Errorf("Expected body to contain empty array, got: %s", bodyStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test that sendResponseWithOptions adds X-No-Data-Found header
|
||||
func TestSendResponseWithOptions_NoDataFoundHeader(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Mock ResponseWriter
|
||||
mockWriter := &MockTestResponseWriter{
|
||||
headers: make(map[string]string),
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{}
|
||||
options := &ExtendedRequestOptions{}
|
||||
|
||||
// Test with nil data
|
||||
handler.sendResponseWithOptions(mockWriter, nil, metadata, options)
|
||||
|
||||
// Check if X-No-Data-Found header was set
|
||||
if mockWriter.headers["X-No-Data-Found"] != "true" {
|
||||
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
|
||||
}
|
||||
|
||||
// Check status code is 200
|
||||
if mockWriter.statusCode != 200 {
|
||||
t.Errorf("Expected status code 200, got %d", mockWriter.statusCode)
|
||||
}
|
||||
|
||||
// Verify the body is an empty array
|
||||
if mockWriter.body == nil {
|
||||
t.Error("Expected body to be set, got nil")
|
||||
} else {
|
||||
bodyBytes, err := json.Marshal(mockWriter.body)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to marshal body: %v", err)
|
||||
}
|
||||
bodyStr := string(bodyBytes)
|
||||
if bodyStr != "[]" {
|
||||
t.Errorf("Expected body to be '[]', got: %s", bodyStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MockTestResponseWriter for testing
|
||||
type MockTestResponseWriter struct {
|
||||
headers map[string]string
|
||||
statusCode int
|
||||
body interface{}
|
||||
}
|
||||
|
||||
func (m *MockTestResponseWriter) SetHeader(key, value string) {
|
||||
m.headers[key] = value
|
||||
}
|
||||
|
||||
func (m *MockTestResponseWriter) WriteHeader(statusCode int) {
|
||||
m.statusCode = statusCode
|
||||
}
|
||||
|
||||
func (m *MockTestResponseWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (m *MockTestResponseWriter) WriteJSON(data interface{}) error {
|
||||
m.body = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTestResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmptyTestModel for testing
|
||||
type EmptyTestModel struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -300,6 +301,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
Options: options,
|
||||
ID: id,
|
||||
Writer: w,
|
||||
Tx: h.db,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
@@ -480,8 +482,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply custom SQL WHERE clause (AND condition)
|
||||
if options.CustomSQLWhere != "" {
|
||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
// First add table prefixes to unqualified columns (but skip columns inside function calls)
|
||||
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
if sanitizedWhere != "" {
|
||||
query = query.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -490,8 +494,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply custom SQL WHERE clause (OR condition)
|
||||
if options.CustomSQLOr != "" {
|
||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
if sanitizedOr != "" {
|
||||
query = query.WhereOr(sanitizedOr)
|
||||
}
|
||||
@@ -512,14 +517,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
direction = "DESC"
|
||||
}
|
||||
logger.Debug("Applying sort: %s %s", sort.Column, direction)
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
|
||||
// Check if it's an expression (enclosed in brackets) - use directly without quoting
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// For expressions, pass as raw SQL to prevent auto-quoting
|
||||
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
} else {
|
||||
// Regular column - let Bun handle quoting
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
}
|
||||
|
||||
// Get total count before pagination (unless skip count is requested)
|
||||
var total int
|
||||
if !options.SkipCount {
|
||||
// Try to get from cache first (unless SkipCache is true)
|
||||
var cachedTotal *cache.CachedTotal
|
||||
var cachedTotalData *cachedTotal
|
||||
var cacheKey string
|
||||
|
||||
if !options.SkipCache {
|
||||
@@ -533,7 +546,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
}
|
||||
|
||||
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
|
||||
cacheKeyHash := buildExtendedQueryCacheKey(
|
||||
tableName,
|
||||
options.Filters,
|
||||
options.Sort,
|
||||
@@ -544,22 +557,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
options.CursorForward,
|
||||
options.CursorBackward,
|
||||
)
|
||||
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
|
||||
cacheKey = getQueryTotalCacheKey(cacheKeyHash)
|
||||
|
||||
// Try to retrieve from cache
|
||||
cachedTotal = &cache.CachedTotal{}
|
||||
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
|
||||
cachedTotalData = &cachedTotal{}
|
||||
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotalData)
|
||||
if err == nil {
|
||||
total = cachedTotal.Total
|
||||
total = cachedTotalData.Total
|
||||
logger.Debug("Total records (from cache): %d", total)
|
||||
} else {
|
||||
logger.Debug("Cache miss for query total")
|
||||
cachedTotal = nil
|
||||
cachedTotalData = nil
|
||||
}
|
||||
}
|
||||
|
||||
// If not in cache or cache skip, execute count query
|
||||
if cachedTotal == nil {
|
||||
if cachedTotalData == nil {
|
||||
count, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Error counting records: %v", err)
|
||||
@@ -569,11 +582,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
total = count
|
||||
logger.Debug("Total records (from query): %d", total)
|
||||
|
||||
// Store in cache (if caching is enabled)
|
||||
// Store in cache with schema and table tags (if caching is enabled)
|
||||
if !options.SkipCache && cacheKey != "" {
|
||||
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||
cacheData := &cache.CachedTotal{Total: total}
|
||||
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
||||
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
|
||||
logger.Warn("Failed to cache query total: %v", err)
|
||||
// Don't fail the request if caching fails
|
||||
} else {
|
||||
@@ -652,6 +664,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
return
|
||||
}
|
||||
|
||||
// Check if a specific ID was requested but no record was found
|
||||
resultCount := reflection.Len(modelPtr)
|
||||
if id != "" && resultCount == 0 {
|
||||
logger.Warn("Record not found for ID: %s", id)
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil)
|
||||
return
|
||||
}
|
||||
|
||||
limit := 0
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
@@ -666,7 +686,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
Count: int64(reflection.Len(modelPtr)),
|
||||
Count: int64(resultCount),
|
||||
Filtered: int64(total),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
@@ -745,9 +765,42 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
|
||||
// Apply ComputedQL fields if any
|
||||
if len(preload.ComputedQL) > 0 {
|
||||
// Get the base table name from the related model
|
||||
baseTableName := getTableNameFromModel(relatedModel)
|
||||
|
||||
// Convert the preload relation path to the appropriate alias format
|
||||
// This is ORM-specific. Currently we only support Bun's format.
|
||||
// TODO: Add support for other ORMs if needed
|
||||
preloadAlias := ""
|
||||
if h.db.GetUnderlyingDB() != nil {
|
||||
// Check if we're using Bun by checking the type name
|
||||
underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB())
|
||||
if strings.Contains(underlyingType, "bun.DB") {
|
||||
// Use Bun's alias format: lowercase with double underscores
|
||||
preloadAlias = relationPathToBunAlias(preload.Relation)
|
||||
}
|
||||
// For GORM: GORM doesn't use the same alias format, and this fix
|
||||
// may not be needed since GORM handles preloads differently
|
||||
}
|
||||
|
||||
logger.Debug("Applying computed columns to preload %s (alias: %s, base table: %s)",
|
||||
preload.Relation, preloadAlias, baseTableName)
|
||||
|
||||
for colName, colExpr := range preload.ComputedQL {
|
||||
// Replace table references in the expression with the preload alias
|
||||
// This fixes the ambiguous column reference issue when there are multiple
|
||||
// levels of recursive/nested preloads
|
||||
adjustedExpr := colExpr
|
||||
if baseTableName != "" && preloadAlias != "" {
|
||||
adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
||||
if adjustedExpr != colExpr {
|
||||
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
|
||||
colName, colExpr, adjustedExpr)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
||||
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
||||
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", adjustedExpr, colName))
|
||||
// Remove the computed column from selected columns to avoid duplication
|
||||
for colIndex := range preload.Columns {
|
||||
if preload.Columns[colIndex] == colName {
|
||||
@@ -793,7 +846,14 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
// Apply sorting
|
||||
if len(preload.Sort) > 0 {
|
||||
for _, sort := range preload.Sort {
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
// Check if it's an expression (enclosed in brackets) - use directly without quoting
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// For expressions, pass as raw SQL to prevent auto-quoting
|
||||
sq = sq.OrderExpr(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
} else {
|
||||
// Regular column - let ORM handle quoting
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -801,7 +861,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
if len(preload.Where) > 0 {
|
||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||
preloadOpts := &common.RequestOptions{Preload: allPreloads}
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
// Then sanitize and allow preload table prefixes
|
||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -840,6 +903,73 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
return query
|
||||
}
|
||||
|
||||
// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def"
|
||||
// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores
|
||||
func relationPathToBunAlias(relationPath string) string {
|
||||
if relationPath == "" {
|
||||
return ""
|
||||
}
|
||||
// Convert to lowercase and replace dots with double underscores
|
||||
alias := strings.ToLower(relationPath)
|
||||
alias = strings.ReplaceAll(alias, ".", "__")
|
||||
return alias
|
||||
}
|
||||
|
||||
// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression
|
||||
// with the appropriate alias for the current preload level
|
||||
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
|
||||
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
|
||||
func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
|
||||
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
|
||||
return sqlExpr
|
||||
}
|
||||
|
||||
// Replace both quoted and unquoted table references
|
||||
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
|
||||
|
||||
// Pattern 1: tablename.column (unquoted)
|
||||
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
|
||||
|
||||
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
|
||||
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getTableNameFromModel extracts the table name from a model
|
||||
// It checks the bun tag first, then falls back to converting the struct name to snake_case
|
||||
func getTableNameFromModel(model interface{}) string {
|
||||
if model == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers
|
||||
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Look for bun tag on embedded BaseModel
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
if field.Anonymous {
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.HasPrefix(bunTag, "table:") {
|
||||
return strings.TrimPrefix(bunTag, "table:")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: convert struct name to lowercase (simple heuristic)
|
||||
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||
return strings.ToLower(modelType.Name())
|
||||
}
|
||||
|
||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||
// Capture panics and return error response
|
||||
defer func() {
|
||||
@@ -866,6 +996,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
Options: options,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
Tx: h.db,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
@@ -955,6 +1086,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
Data: modelValue,
|
||||
Writer: w,
|
||||
Query: query,
|
||||
Tx: tx,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
|
||||
@@ -1022,6 +1154,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||
}
|
||||
|
||||
@@ -1047,6 +1184,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Tx: h.db,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: id,
|
||||
@@ -1116,12 +1254,24 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
// Ensure ID is in the data map for the update
|
||||
dataMap[pkName] = targetID
|
||||
|
||||
// Create update query
|
||||
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
||||
// Populate model instance from dataMap to preserve custom types (like SqlJSONB)
|
||||
// Get the type of the model, handling both pointer and non-pointer types
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
modelInstance := reflect.New(modelType).Interface()
|
||||
if err := reflection.MapToStruct(dataMap, modelInstance); err != nil {
|
||||
return fmt.Errorf("failed to populate model from data: %w", err)
|
||||
}
|
||||
|
||||
// Create update query using Model() to preserve custom types and driver.Valuer interfaces
|
||||
query := tx.NewUpdate().Model(modelInstance)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
hookCtx.Query = query
|
||||
hookCtx.Tx = tx
|
||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
||||
}
|
||||
@@ -1180,6 +1330,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
||||
}
|
||||
|
||||
@@ -1217,6 +1372,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
Model: model,
|
||||
ID: itemID,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
@@ -1247,6 +1403,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully deleted %d records", deletedCount)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||
return
|
||||
|
||||
@@ -1285,6 +1446,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
Model: model,
|
||||
ID: itemIDStr,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
@@ -1314,6 +1476,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully deleted %d records", deletedCount)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||
return
|
||||
|
||||
@@ -1337,6 +1504,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
Model: model,
|
||||
ID: itemIDStr,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
@@ -1367,6 +1535,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully deleted %d records", deletedCount)
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||
return
|
||||
|
||||
@@ -1380,7 +1553,34 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
// Single delete with URL ID
|
||||
// Execute BeforeDelete hooks
|
||||
if id == "" {
|
||||
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for delete", nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Get primary key name
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// First, fetch the record that will be deleted
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
recordToDelete := reflect.New(modelType).Interface()
|
||||
|
||||
selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Warn("Record not found for delete: %s = %s", pkName, id)
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", err)
|
||||
return
|
||||
}
|
||||
logger.Error("Error fetching record for delete: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Error fetching record", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute BeforeDelete hooks with the record data
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
@@ -1390,6 +1590,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
Model: model,
|
||||
ID: id,
|
||||
Writer: w,
|
||||
Tx: h.db,
|
||||
Data: recordToDelete,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
@@ -1399,13 +1601,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
query := h.db.NewDelete().Table(tableName)
|
||||
|
||||
if id == "" {
|
||||
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for delete", nil)
|
||||
return
|
||||
}
|
||||
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
hookCtx.Query = query
|
||||
@@ -1427,11 +1623,15 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
// Execute AfterDelete hooks
|
||||
responseData := map[string]interface{}{
|
||||
"deleted": result.RowsAffected(),
|
||||
// Check if the record was actually deleted
|
||||
if result.RowsAffected() == 0 {
|
||||
logger.Warn("No rows deleted for ID: %s", id)
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "Record not found or already deleted", nil)
|
||||
return
|
||||
}
|
||||
hookCtx.Result = responseData
|
||||
|
||||
// Execute AfterDelete hooks with the deleted record data
|
||||
hookCtx.Result = recordToDelete
|
||||
hookCtx.Error = nil
|
||||
|
||||
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
|
||||
@@ -1440,7 +1640,13 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(w, responseData, nil)
|
||||
// Return the deleted record data
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, recordToDelete, nil)
|
||||
}
|
||||
|
||||
// mergeRecordWithRequest merges a database record with the original request data
|
||||
@@ -1936,14 +2142,30 @@ func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metada
|
||||
|
||||
// sendResponseWithOptions sends a response with optional formatting
|
||||
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
|
||||
w.SetHeader("Content-Type", "application/json")
|
||||
|
||||
// Handle nil data - convert to empty array
|
||||
if data == nil {
|
||||
data = []interface{}{}
|
||||
}
|
||||
|
||||
// Calculate data length after nil conversion
|
||||
dataLen := reflection.Len(data)
|
||||
|
||||
// Add X-No-Data-Found header when no records were found
|
||||
if dataLen == 0 {
|
||||
w.SetHeader("X-No-Data-Found", "true")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Normalize single-record arrays to objects if requested
|
||||
if options != nil && options.SingleRecordAsObject {
|
||||
data = h.normalizeResultArray(data)
|
||||
}
|
||||
|
||||
// Return data as-is without wrapping in common.Response
|
||||
w.SetHeader("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := w.WriteJSON(data); err != nil {
|
||||
logger.Error("Failed to write JSON response: %v", err)
|
||||
}
|
||||
@@ -1953,7 +2175,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
||||
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||
if data == nil {
|
||||
return nil
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
// Use reflection to check if data is a slice or array
|
||||
@@ -1962,18 +2184,50 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||
dataValue = dataValue.Elem()
|
||||
}
|
||||
|
||||
// Check if it's a slice or array with exactly one element
|
||||
if (dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array) && dataValue.Len() == 1 {
|
||||
// Return the single element
|
||||
return dataValue.Index(0).Interface()
|
||||
// Check if it's a slice or array
|
||||
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
||||
if dataValue.Len() == 1 {
|
||||
// Return the single element
|
||||
return dataValue.Index(0).Interface()
|
||||
} else if dataValue.Len() == 0 {
|
||||
// Keep empty array as empty array, don't convert to empty object
|
||||
return []interface{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if dataValue.Kind() == reflect.String {
|
||||
str := dataValue.String()
|
||||
if str == "" || str == "null" {
|
||||
return []interface{}{}
|
||||
}
|
||||
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// sendFormattedResponse sends response with formatting options
|
||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
||||
// Normalize single-record arrays to objects if requested
|
||||
httpStatus := http.StatusOK
|
||||
|
||||
// Handle nil data - convert to empty array
|
||||
if data == nil {
|
||||
data = []interface{}{}
|
||||
}
|
||||
|
||||
// Calculate data length after nil conversion
|
||||
// Note: This is done BEFORE normalization because X-No-Data-Found indicates
|
||||
// whether data was found in the database, not the final response format
|
||||
dataLen := reflection.Len(data)
|
||||
|
||||
// Add X-No-Data-Found header when no records were found
|
||||
if dataLen == 0 {
|
||||
w.SetHeader("X-No-Data-Found", "true")
|
||||
}
|
||||
|
||||
// Apply normalization after header is set
|
||||
// normalizeResultArray may convert single-element arrays to objects,
|
||||
// but the X-No-Data-Found header reflects the original query result
|
||||
if options.SingleRecordAsObject {
|
||||
data = h.normalizeResultArray(data)
|
||||
}
|
||||
@@ -1992,7 +2246,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
||||
switch options.ResponseFormat {
|
||||
case "simple":
|
||||
// Simple format: just return the data array
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.WriteHeader(httpStatus)
|
||||
if err := w.WriteJSON(data); err != nil {
|
||||
logger.Error("Failed to write JSON response: %v", err)
|
||||
}
|
||||
@@ -2004,7 +2258,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
||||
if metadata != nil {
|
||||
response["count"] = metadata.Total
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.WriteHeader(httpStatus)
|
||||
if err := w.WriteJSON(response); err != nil {
|
||||
logger.Error("Failed to write JSON response: %v", err)
|
||||
}
|
||||
@@ -2015,7 +2269,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
||||
Data: data,
|
||||
Metadata: metadata,
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.WriteHeader(httpStatus)
|
||||
if err := w.WriteJSON(response); err != nil {
|
||||
logger.Error("Failed to write JSON response: %v", err)
|
||||
}
|
||||
@@ -2073,7 +2327,14 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
||||
|
||||
// Check if it's an expression (enclosed in brackets) - use directly without table prefix
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
} else {
|
||||
// Regular column - add table prefix
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
||||
}
|
||||
}
|
||||
sortSQL = strings.Join(sortParts, ", ")
|
||||
} else {
|
||||
@@ -2282,6 +2543,55 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
||||
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
||||
// Filter columns using the related model's validator
|
||||
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||
|
||||
// Filter sort columns in the expand Sort string
|
||||
if expand.Sort != "" {
|
||||
sortFields := strings.Split(expand.Sort, ",")
|
||||
validSortFields := make([]string, 0, len(sortFields))
|
||||
for _, sortField := range sortFields {
|
||||
sortField = strings.TrimSpace(sortField)
|
||||
if sortField == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract column name (remove direction prefixes/suffixes)
|
||||
colName := sortField
|
||||
direction := ""
|
||||
|
||||
if strings.HasPrefix(sortField, "-") {
|
||||
direction = "-"
|
||||
colName = strings.TrimPrefix(sortField, "-")
|
||||
} else if strings.HasPrefix(sortField, "+") {
|
||||
direction = "+"
|
||||
colName = strings.TrimPrefix(sortField, "+")
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.ToLower(colName), " desc") {
|
||||
direction = " desc"
|
||||
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
|
||||
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
|
||||
direction = " asc"
|
||||
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
|
||||
}
|
||||
|
||||
colName = strings.TrimSpace(colName)
|
||||
|
||||
// Validate the column name
|
||||
if expandValidator.IsValidColumn(colName) {
|
||||
validSortFields = append(validSortFields, direction+colName)
|
||||
} else if strings.HasPrefix(colName, "(") && strings.HasSuffix(colName, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if common.IsSafeSortExpression(colName) {
|
||||
validSortFields = append(validSortFields, direction+colName)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression in expand '%s' removed: '%s'", expand.Relation, colName)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in expand '%s' sort '%s' removed", expand.Relation, colName)
|
||||
}
|
||||
}
|
||||
filteredExpand.Sort = strings.Join(validSortFields, ",")
|
||||
}
|
||||
} else {
|
||||
// If we can't find the relationship, log a warning and skip column filtering
|
||||
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
|
||||
|
||||
@@ -529,19 +529,47 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
||||
}
|
||||
|
||||
// parseCommaSeparated parses comma-separated values and trims whitespace
|
||||
// It respects bracket nesting and only splits on commas outside of parentheses
|
||||
func (h *Handler) parseCommaSeparated(value string) []string {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(value, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
result := make([]string, 0)
|
||||
var current strings.Builder
|
||||
nestingLevel := 0
|
||||
|
||||
for _, char := range value {
|
||||
switch char {
|
||||
case '(':
|
||||
nestingLevel++
|
||||
current.WriteRune(char)
|
||||
case ')':
|
||||
nestingLevel--
|
||||
current.WriteRune(char)
|
||||
case ',':
|
||||
if nestingLevel == 0 {
|
||||
// We're outside all brackets, so split here
|
||||
part := strings.TrimSpace(current.String())
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
current.Reset()
|
||||
} else {
|
||||
// Inside brackets, keep the comma
|
||||
current.WriteRune(char)
|
||||
}
|
||||
default:
|
||||
current.WriteRune(char)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last part
|
||||
part := strings.TrimSpace(current.String())
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -907,7 +935,16 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
// Add WHERE clause if SQL conditions specified
|
||||
whereConditions := make([]string, 0)
|
||||
if len(xfile.SqlAnd) > 0 {
|
||||
whereConditions = append(whereConditions, xfile.SqlAnd...)
|
||||
// Process each SQL condition: add table prefixes and sanitize
|
||||
for _, sqlCond := range xfile.SqlAnd {
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName)
|
||||
// Then sanitize the condition
|
||||
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
|
||||
if sanitizedCond != "" {
|
||||
whereConditions = append(whereConditions, sanitizedCond)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(whereConditions) > 0 {
|
||||
preloadOpt.Where = strings.Join(whereConditions, " AND ")
|
||||
|
||||
@@ -55,6 +55,10 @@ type HookContext struct {
|
||||
|
||||
// Response writer - allows hooks to modify response
|
||||
Writer common.ResponseWriter
|
||||
|
||||
// Tx provides access to the database/transaction for executing additional SQL
|
||||
// This allows hooks to run custom queries in addition to the main Query chain
|
||||
Tx common.Database
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
|
||||
@@ -150,6 +150,50 @@ func ExampleRelatedDataHook(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleTxHook demonstrates using the Tx field to execute additional SQL queries
|
||||
// The Tx field provides access to the database/transaction for custom queries
|
||||
func ExampleTxHook(ctx *HookContext) error {
|
||||
// Example: Execute additional SQL operations alongside the main query
|
||||
// This is useful for maintaining data consistency, updating related records, etc.
|
||||
|
||||
if ctx.Entity == "orders" && ctx.Data != nil {
|
||||
// Example: Update inventory when an order is created
|
||||
// Extract product ID and quantity from the order data
|
||||
// dataMap, ok := ctx.Data.(map[string]interface{})
|
||||
// if !ok {
|
||||
// return fmt.Errorf("invalid data format")
|
||||
// }
|
||||
// productID := dataMap["product_id"]
|
||||
// quantity := dataMap["quantity"]
|
||||
|
||||
// Use ctx.Tx to execute additional SQL queries
|
||||
// The Tx field contains the same database/transaction as the main operation
|
||||
// If inside a transaction, your queries will be part of the same transaction
|
||||
// query := ctx.Tx.NewUpdate().
|
||||
// Table("inventory").
|
||||
// Set("quantity = quantity - ?", quantity).
|
||||
// Where("product_id = ?", productID)
|
||||
//
|
||||
// if _, err := query.Exec(ctx.Context); err != nil {
|
||||
// logger.Error("Failed to update inventory: %v", err)
|
||||
// return fmt.Errorf("failed to update inventory: %w", err)
|
||||
// }
|
||||
|
||||
// You can also execute raw SQL using ctx.Tx
|
||||
// var result []map[string]interface{}
|
||||
// err := ctx.Tx.Query(ctx.Context, &result,
|
||||
// "INSERT INTO order_history (order_id, status) VALUES (?, ?)",
|
||||
// orderID, "pending")
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to insert order history: %w", err)
|
||||
// }
|
||||
|
||||
logger.Debug("Executed additional SQL for order entity")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupExampleHooks demonstrates how to register hooks on a handler
|
||||
func SetupExampleHooks(handler *Handler) {
|
||||
hooks := handler.Hooks()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package restheadspec
|
||||
@@ -21,12 +22,12 @@ import (
|
||||
|
||||
// Test models
|
||||
type TestUser struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Email string `gorm:"uniqueIndex;not null" json:"email"`
|
||||
Age int `json:"age"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Email string `gorm:"uniqueIndex;not null" json:"email"`
|
||||
Age int `json:"age"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
|
||||
}
|
||||
|
||||
@@ -35,13 +36,13 @@ func (TestUser) TableName() string {
|
||||
}
|
||||
|
||||
type TestPost struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null" json:"user_id"`
|
||||
Title string `gorm:"not null" json:"title"`
|
||||
Content string `json:"content"`
|
||||
Published bool `gorm:"default:false" json:"published"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null" json:"user_id"`
|
||||
Title string `gorm:"not null" json:"title"`
|
||||
Content string `json:"content"`
|
||||
Published bool `gorm:"default:false" json:"published"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
|
||||
}
|
||||
|
||||
@@ -54,7 +55,7 @@ type TestComment struct {
|
||||
PostID uint `gorm:"not null" json:"post_id"`
|
||||
Content string `gorm:"not null" json:"content"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
|
||||
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
|
||||
}
|
||||
|
||||
func (TestComment) TableName() string {
|
||||
@@ -401,7 +402,7 @@ func TestIntegration_GetMetadata(t *testing.T) {
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
@@ -492,7 +493,7 @@ func TestIntegration_QueryParamsOverHeaders(t *testing.T) {
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
|
||||
@@ -296,7 +296,7 @@ func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName
|
||||
}
|
||||
|
||||
func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (reflect.Value, error) {
|
||||
defer logger.CatchPanic("ApplyColumnSecurity")
|
||||
defer logger.CatchPanic("ApplyColumnSecurity")()
|
||||
|
||||
if m.ColumnSecurity == nil {
|
||||
return records, fmt.Errorf("security not initialized")
|
||||
@@ -437,7 +437,7 @@ func (m *SecurityList) LoadRowSecurity(ctx context.Context, pUserID int, pSchema
|
||||
}
|
||||
|
||||
func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
|
||||
defer logger.CatchPanic("GetRowSecurityTemplate")
|
||||
defer logger.CatchPanic("GetRowSecurityTemplate")()
|
||||
|
||||
if m.RowSecurity == nil {
|
||||
return RowSecurity{}, fmt.Errorf("security not initialized")
|
||||
|
||||
@@ -1,233 +1,314 @@
|
||||
# Server Package
|
||||
|
||||
Graceful HTTP server with request draining and shutdown coordination.
|
||||
Production-ready HTTP server manager with graceful shutdown, request draining, and comprehensive TLS/HTTPS support.
|
||||
|
||||
## Features
|
||||
|
||||
✅ **Multiple Server Management** - Run multiple HTTP/HTTPS servers concurrently
|
||||
✅ **Graceful Shutdown** - Handles SIGINT/SIGTERM with request draining
|
||||
✅ **Automatic Request Rejection** - New requests get 503 during shutdown
|
||||
✅ **Health & Readiness Endpoints** - Kubernetes-ready health checks
|
||||
✅ **Shutdown Callbacks** - Register cleanup functions (DB, cache, metrics)
|
||||
✅ **Comprehensive TLS Support**:
|
||||
- Certificate files (production)
|
||||
- Self-signed certificates (development/testing)
|
||||
- Let's Encrypt / AutoTLS (automatic certificate management)
|
||||
✅ **GZIP Compression** - Optional response compression
|
||||
✅ **Panic Recovery** - Automatic panic recovery middleware
|
||||
✅ **Configurable Timeouts** - Read, write, idle, drain, and shutdown timeouts
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Single Server
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
|
||||
// Create server
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8080",
|
||||
Handler: router,
|
||||
// Create server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Add server
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: myRouter,
|
||||
GZIP: true,
|
||||
})
|
||||
|
||||
// Start server (blocks until shutdown signal)
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
// Start and wait for shutdown signal
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
### Multiple Servers
|
||||
|
||||
✅ Graceful shutdown on SIGINT/SIGTERM
|
||||
✅ Request draining (waits for in-flight requests)
|
||||
✅ Automatic request rejection during shutdown
|
||||
✅ Health and readiness endpoints
|
||||
✅ Shutdown callbacks for cleanup
|
||||
✅ Configurable timeouts
|
||||
```go
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Public API
|
||||
mgr.Add(server.Config{
|
||||
Name: "public-api",
|
||||
Port: 8080,
|
||||
Handler: publicRouter,
|
||||
})
|
||||
|
||||
// Admin API
|
||||
mgr.Add(server.Config{
|
||||
Name: "admin-api",
|
||||
Port: 8081,
|
||||
Handler: adminRouter,
|
||||
})
|
||||
|
||||
// Start all and wait
|
||||
mgr.ServeWithGracefulShutdown()
|
||||
```
|
||||
|
||||
## HTTPS/TLS Configuration
|
||||
|
||||
### Option 1: Certificate Files (Production)
|
||||
|
||||
```go
|
||||
mgr.Add(server.Config{
|
||||
Name: "https-server",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
SSLCert: "/etc/ssl/certs/server.crt",
|
||||
SSLKey: "/etc/ssl/private/server.key",
|
||||
})
|
||||
```
|
||||
|
||||
### Option 2: Self-Signed Certificate (Development)
|
||||
|
||||
```go
|
||||
mgr.Add(server.Config{
|
||||
Name: "dev-server",
|
||||
Host: "localhost",
|
||||
Port: 8443,
|
||||
Handler: handler,
|
||||
SelfSignedSSL: true, // Auto-generates certificate
|
||||
})
|
||||
```
|
||||
|
||||
### Option 3: Let's Encrypt / AutoTLS (Production)
|
||||
|
||||
```go
|
||||
mgr.Add(server.Config{
|
||||
Name: "prod-server",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
AutoTLSCacheDir: "./certs-cache", // Certificate cache directory
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
```go
|
||||
config := server.Config{
|
||||
// Server address
|
||||
Addr: ":8080",
|
||||
server.Config{
|
||||
// Basic configuration
|
||||
Name: "my-server", // Server name (required)
|
||||
Host: "0.0.0.0", // Bind address
|
||||
Port: 8080, // Port (required)
|
||||
Handler: myRouter, // HTTP handler (required)
|
||||
Description: "My API server", // Optional description
|
||||
|
||||
// HTTP handler
|
||||
Handler: myRouter,
|
||||
// Features
|
||||
GZIP: true, // Enable GZIP compression
|
||||
|
||||
// Maximum time for graceful shutdown (default: 30s)
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
// TLS/HTTPS (choose one option)
|
||||
SSLCert: "/path/to/cert.pem", // Certificate file
|
||||
SSLKey: "/path/to/key.pem", // Key file
|
||||
SelfSignedSSL: false, // Auto-generate self-signed cert
|
||||
AutoTLS: false, // Let's Encrypt
|
||||
AutoTLSDomains: []string{}, // Domains for AutoTLS
|
||||
AutoTLSEmail: "", // Email for Let's Encrypt
|
||||
AutoTLSCacheDir: "./certs-cache", // Cert cache directory
|
||||
|
||||
// Time to wait for in-flight requests (default: 25s)
|
||||
DrainTimeout: 25 * time.Second,
|
||||
|
||||
// Request read timeout (default: 10s)
|
||||
ReadTimeout: 10 * time.Second,
|
||||
|
||||
// Response write timeout (default: 10s)
|
||||
WriteTimeout: 10 * time.Second,
|
||||
|
||||
// Idle connection timeout (default: 120s)
|
||||
IdleTimeout: 120 * time.Second,
|
||||
// Timeouts
|
||||
ShutdownTimeout: 30 * time.Second, // Max shutdown time
|
||||
DrainTimeout: 25 * time.Second, // Request drain timeout
|
||||
ReadTimeout: 15 * time.Second, // Request read timeout
|
||||
WriteTimeout: 15 * time.Second, // Response write timeout
|
||||
IdleTimeout: 60 * time.Second, // Idle connection timeout
|
||||
}
|
||||
|
||||
srv := server.NewGracefulServer(config)
|
||||
```
|
||||
|
||||
## Shutdown Behavior
|
||||
## Graceful Shutdown
|
||||
|
||||
**Signal received (SIGINT/SIGTERM):**
|
||||
### Automatic (Recommended)
|
||||
|
||||
1. **Mark as shutting down** - New requests get 503
|
||||
2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests
|
||||
3. **Shutdown server** - Close listeners and connections
|
||||
4. **Execute callbacks** - Run registered cleanup functions
|
||||
```go
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Add servers...
|
||||
|
||||
// This blocks until SIGINT/SIGTERM
|
||||
mgr.ServeWithGracefulShutdown()
|
||||
```
|
||||
Time Event
|
||||
─────────────────────────────────────────
|
||||
0s Signal received: SIGTERM
|
||||
├─ Mark as shutting down
|
||||
├─ Reject new requests (503)
|
||||
└─ Start draining...
|
||||
|
||||
1s In-flight: 50 requests
|
||||
2s In-flight: 32 requests
|
||||
3s In-flight: 12 requests
|
||||
4s In-flight: 3 requests
|
||||
5s In-flight: 0 requests ✓
|
||||
└─ All requests drained
|
||||
### Manual Control
|
||||
|
||||
5s Execute shutdown callbacks
|
||||
6s Shutdown complete
|
||||
```go
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Add and start servers
|
||||
mgr.StartAll()
|
||||
|
||||
// Later... stop gracefully
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := mgr.StopAllWithContext(ctx); err != nil {
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### Shutdown Callbacks
|
||||
|
||||
Register cleanup functions to run during shutdown:
|
||||
|
||||
```go
|
||||
// Close database
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Closing database...")
|
||||
return db.Close()
|
||||
})
|
||||
|
||||
// Flush metrics
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Flushing metrics...")
|
||||
return metrics.Flush(ctx)
|
||||
})
|
||||
|
||||
// Close cache
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Closing cache...")
|
||||
return cache.Close()
|
||||
})
|
||||
```
|
||||
|
||||
## Health Checks
|
||||
|
||||
### Health Endpoint
|
||||
|
||||
Returns 200 when healthy, 503 when shutting down:
|
||||
### Adding Health Endpoints
|
||||
|
||||
```go
|
||||
router.HandleFunc("/health", srv.HealthCheckHandler())
|
||||
instance, _ := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Port: 8080,
|
||||
Handler: router,
|
||||
})
|
||||
|
||||
// Add health endpoints to your router
|
||||
router.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
router.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
```
|
||||
|
||||
**Response (healthy):**
|
||||
### Health Endpoint
|
||||
|
||||
Returns server health status:
|
||||
|
||||
**Healthy (200 OK):**
|
||||
```json
|
||||
{"status":"healthy"}
|
||||
```
|
||||
|
||||
**Response (shutting down):**
|
||||
**Shutting Down (503 Service Unavailable):**
|
||||
```json
|
||||
{"status":"shutting_down"}
|
||||
```
|
||||
|
||||
### Readiness Endpoint
|
||||
|
||||
Includes in-flight request count:
|
||||
Returns readiness with in-flight request count:
|
||||
|
||||
```go
|
||||
router.HandleFunc("/ready", srv.ReadinessHandler())
|
||||
```
|
||||
|
||||
**Response:**
|
||||
**Ready (200 OK):**
|
||||
```json
|
||||
{"ready":true,"in_flight_requests":12}
|
||||
```
|
||||
|
||||
**During shutdown:**
|
||||
**Not Ready (503 Service Unavailable):**
|
||||
```json
|
||||
{"ready":false,"reason":"shutting_down"}
|
||||
```
|
||||
|
||||
## Shutdown Callbacks
|
||||
## Shutdown Behavior
|
||||
|
||||
Register cleanup functions to run during shutdown:
|
||||
When a shutdown signal (SIGINT/SIGTERM) is received:
|
||||
|
||||
```go
|
||||
// Close database
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
logger.Info("Closing database connection...")
|
||||
return db.Close()
|
||||
})
|
||||
1. **Mark as shutting down** → New requests get 503
|
||||
2. **Execute callbacks** → Run cleanup functions
|
||||
3. **Drain requests** → Wait up to `DrainTimeout` for in-flight requests
|
||||
4. **Shutdown servers** → Close listeners and connections
|
||||
|
||||
// Flush metrics
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
logger.Info("Flushing metrics...")
|
||||
return metricsProvider.Flush(ctx)
|
||||
})
|
||||
```
|
||||
Time Event
|
||||
─────────────────────────────────────────
|
||||
0s Signal received: SIGTERM
|
||||
├─ Mark servers as shutting down
|
||||
├─ Reject new requests (503)
|
||||
└─ Execute shutdown callbacks
|
||||
|
||||
// Close cache
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
logger.Info("Closing cache...")
|
||||
return cache.Close()
|
||||
})
|
||||
1s Callbacks complete
|
||||
└─ Start draining requests...
|
||||
|
||||
2s In-flight: 50 requests
|
||||
3s In-flight: 32 requests
|
||||
4s In-flight: 12 requests
|
||||
5s In-flight: 3 requests
|
||||
6s In-flight: 0 requests ✓
|
||||
└─ All requests drained
|
||||
|
||||
6s Shutdown servers
|
||||
7s All servers stopped ✓
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
## Server Management
|
||||
|
||||
### Get Server Instance
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize metrics
|
||||
metricsProvider := metrics.NewPrometheusProvider()
|
||||
metrics.SetProvider(metricsProvider)
|
||||
|
||||
// Create router
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Apply middleware
|
||||
rateLimiter := middleware.NewRateLimiter(100, 20)
|
||||
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB)
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
|
||||
router.Use(rateLimiter.Middleware)
|
||||
router.Use(sizeLimiter.Middleware)
|
||||
router.Use(sanitizer.Middleware)
|
||||
router.Use(metricsProvider.Middleware)
|
||||
|
||||
// API routes
|
||||
router.HandleFunc("/api/data", dataHandler)
|
||||
|
||||
// Create graceful server
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8080",
|
||||
Handler: router,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
DrainTimeout: 25 * time.Second,
|
||||
})
|
||||
|
||||
// Health checks
|
||||
router.HandleFunc("/health", srv.HealthCheckHandler())
|
||||
router.HandleFunc("/ready", srv.ReadinessHandler())
|
||||
|
||||
// Metrics endpoint
|
||||
router.Handle("/metrics", metricsProvider.Handler())
|
||||
|
||||
// Register shutdown callbacks
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Cleanup: Flushing metrics...")
|
||||
return nil
|
||||
})
|
||||
|
||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Cleanup: Closing database...")
|
||||
// return db.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start server (blocks until shutdown)
|
||||
log.Printf("Starting server on :8080")
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for shutdown to complete
|
||||
srv.Wait()
|
||||
log.Println("Server stopped")
|
||||
instance, err := mgr.Get("api-server")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Your handler logic
|
||||
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"message":"success"}`))
|
||||
// Check status
|
||||
fmt.Printf("Address: %s\n", instance.Addr())
|
||||
fmt.Printf("Name: %s\n", instance.Name())
|
||||
fmt.Printf("In-flight: %d\n", instance.InFlightRequests())
|
||||
fmt.Printf("Shutting down: %v\n", instance.IsShuttingDown())
|
||||
```
|
||||
|
||||
### List All Servers
|
||||
|
||||
```go
|
||||
instances := mgr.List()
|
||||
for _, instance := range instances {
|
||||
fmt.Printf("Server: %s at %s\n", instance.Name(), instance.Addr())
|
||||
}
|
||||
```
|
||||
|
||||
### Remove Server
|
||||
|
||||
```go
|
||||
// Stop and remove a server
|
||||
if err := mgr.Remove("api-server"); err != nil {
|
||||
log.Printf("Error removing server: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### Restart All Servers
|
||||
|
||||
```go
|
||||
// Gracefully restart all servers
|
||||
if err := mgr.RestartAll(); err != nil {
|
||||
log.Printf("Error restarting: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
@@ -250,23 +331,21 @@ spec:
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
|
||||
# Liveness probe - is app running?
|
||||
# Liveness probe
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8080
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
timeoutSeconds: 5
|
||||
|
||||
# Readiness probe - can app handle traffic?
|
||||
# Readiness probe
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /ready
|
||||
port: 8080
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 5
|
||||
timeoutSeconds: 3
|
||||
|
||||
# Graceful shutdown
|
||||
lifecycle:
|
||||
@@ -274,26 +353,12 @@ spec:
|
||||
exec:
|
||||
command: ["/bin/sh", "-c", "sleep 5"]
|
||||
|
||||
# Environment
|
||||
env:
|
||||
- name: SHUTDOWN_TIMEOUT
|
||||
value: "30"
|
||||
```
|
||||
|
||||
### Service
|
||||
|
||||
```yaml
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: myapp
|
||||
spec:
|
||||
selector:
|
||||
app: myapp
|
||||
ports:
|
||||
- port: 80
|
||||
targetPort: 8080
|
||||
type: LoadBalancer
|
||||
# Allow time for graceful shutdown
|
||||
terminationGracePeriodSeconds: 35
|
||||
```
|
||||
|
||||
## Docker Compose
|
||||
@@ -312,8 +377,70 @@ services:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
stop_grace_period: 35s # Slightly longer than shutdown timeout
|
||||
stop_grace_period: 35s
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Register shutdown callbacks
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
log.Println("Cleanup: Closing database...")
|
||||
// return db.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
// Create router
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("/api/data", dataHandler)
|
||||
|
||||
// Add server
|
||||
instance, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
Handler: router,
|
||||
GZIP: true,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
DrainTimeout: 25 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Add health endpoints
|
||||
router.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
router.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
|
||||
// Start and wait for shutdown
|
||||
log.Println("Starting server on :8080")
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
log.Printf("Server stopped: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Server shutdown complete")
|
||||
}
|
||||
|
||||
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"message":"success"}`))
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Graceful Shutdown
|
||||
@@ -330,7 +457,7 @@ SERVER_PID=$!
|
||||
# Wait for server to start
|
||||
sleep 2
|
||||
|
||||
# Send some requests
|
||||
# Send requests
|
||||
for i in {1..10}; do
|
||||
curl http://localhost:8080/api/data &
|
||||
done
|
||||
@@ -341,7 +468,7 @@ sleep 1
|
||||
# Send shutdown signal
|
||||
kill -TERM $SERVER_PID
|
||||
|
||||
# Try to send more requests (should get 503)
|
||||
# Try more requests (should get 503)
|
||||
curl -v http://localhost:8080/api/data
|
||||
|
||||
# Wait for server to stop
|
||||
@@ -349,101 +476,13 @@ wait $SERVER_PID
|
||||
echo "Server stopped gracefully"
|
||||
```
|
||||
|
||||
### Expected Output
|
||||
|
||||
```
|
||||
Starting server on :8080
|
||||
Received signal: terminated, initiating graceful shutdown
|
||||
Starting graceful shutdown...
|
||||
Waiting for 8 in-flight requests to complete...
|
||||
Waiting for 4 in-flight requests to complete...
|
||||
Waiting for 1 in-flight requests to complete...
|
||||
All requests drained in 2.3s
|
||||
Cleanup: Flushing metrics...
|
||||
Cleanup: Closing database...
|
||||
Shutting down HTTP server...
|
||||
Graceful shutdown complete
|
||||
Server stopped
|
||||
```
|
||||
|
||||
## Monitoring In-Flight Requests
|
||||
|
||||
```go
|
||||
// Get current in-flight count
|
||||
count := srv.InFlightRequests()
|
||||
fmt.Printf("In-flight requests: %d\n", count)
|
||||
|
||||
// Check if shutting down
|
||||
if srv.IsShuttingDown() {
|
||||
fmt.Println("Server is shutting down")
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Shutdown Logic
|
||||
|
||||
```go
|
||||
// Implement custom shutdown
|
||||
go func() {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
<-sigChan
|
||||
log.Println("Shutdown signal received")
|
||||
|
||||
// Custom pre-shutdown logic
|
||||
log.Println("Running custom cleanup...")
|
||||
|
||||
// Shutdown with callbacks
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.ShutdownWithCallbacks(ctx); err != nil {
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start server
|
||||
srv.server.ListenAndServe()
|
||||
```
|
||||
|
||||
### Multiple Servers
|
||||
|
||||
```go
|
||||
// HTTP server
|
||||
httpSrv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8080",
|
||||
Handler: httpRouter,
|
||||
})
|
||||
|
||||
// HTTPS server
|
||||
httpsSrv := server.NewGracefulServer(server.Config{
|
||||
Addr: ":8443",
|
||||
Handler: httpsRouter,
|
||||
})
|
||||
|
||||
// Start both
|
||||
go httpSrv.ListenAndServe()
|
||||
go httpsSrv.ListenAndServe()
|
||||
|
||||
// Shutdown both on signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
ctx := context.Background()
|
||||
httpSrv.Shutdown(ctx)
|
||||
httpsSrv.Shutdown(ctx)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Set appropriate timeouts**
|
||||
- `DrainTimeout` < `ShutdownTimeout`
|
||||
- `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds`
|
||||
|
||||
2. **Register cleanup callbacks** for:
|
||||
2. **Use shutdown callbacks** for:
|
||||
- Database connections
|
||||
- Message queues
|
||||
- Metrics flushing
|
||||
@@ -458,7 +497,12 @@ httpsSrv.Shutdown(ctx)
|
||||
- Set `preStop` hook in Kubernetes (5-10s delay)
|
||||
- Allows load balancer to deregister before shutdown
|
||||
|
||||
5. **Monitoring**
|
||||
5. **HTTPS in production**
|
||||
- Use AutoTLS for public-facing services
|
||||
- Use certificate files for enterprise PKI
|
||||
- Use self-signed only for development/testing
|
||||
|
||||
6. **Monitoring**
|
||||
- Track in-flight requests in metrics
|
||||
- Alert on slow drains
|
||||
- Monitor shutdown duration
|
||||
@@ -470,24 +514,63 @@ httpsSrv.Shutdown(ctx)
|
||||
```go
|
||||
// Increase drain timeout
|
||||
config.DrainTimeout = 60 * time.Second
|
||||
config.ShutdownTimeout = 65 * time.Second
|
||||
```
|
||||
|
||||
### Requests Still Timing Out
|
||||
### Requests Timing Out
|
||||
|
||||
```go
|
||||
// Increase write timeout
|
||||
config.WriteTimeout = 30 * time.Second
|
||||
```
|
||||
|
||||
### Force Shutdown Not Working
|
||||
|
||||
The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed.
|
||||
|
||||
### Debugging Shutdown
|
||||
### Certificate Issues
|
||||
|
||||
```go
|
||||
// Verify certificate files exist and are readable
|
||||
if _, err := os.Stat(config.SSLCert); err != nil {
|
||||
log.Fatalf("Certificate not found: %v", err)
|
||||
}
|
||||
|
||||
// For AutoTLS, ensure:
|
||||
// - Port 443 is accessible
|
||||
// - Domains resolve to server IP
|
||||
// - Cache directory is writable
|
||||
```
|
||||
|
||||
### Debug Logging
|
||||
|
||||
```go
|
||||
// Enable debug logging
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
|
||||
// Enable debug logging
|
||||
logger.SetLevel("debug")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Manager Methods
|
||||
|
||||
- `NewManager()` - Create new server manager
|
||||
- `Add(cfg Config)` - Register server instance
|
||||
- `Get(name string)` - Get server by name
|
||||
- `Remove(name string)` - Stop and remove server
|
||||
- `StartAll()` - Start all registered servers
|
||||
- `StopAll()` - Stop all servers gracefully
|
||||
- `StopAllWithContext(ctx)` - Stop with timeout
|
||||
- `RestartAll()` - Restart all servers
|
||||
- `List()` - Get all server instances
|
||||
- `ServeWithGracefulShutdown()` - Start and block until shutdown
|
||||
- `RegisterShutdownCallback(cb)` - Register cleanup function
|
||||
|
||||
### Instance Methods
|
||||
|
||||
- `Start()` - Start the server
|
||||
- `Stop(ctx)` - Stop gracefully
|
||||
- `Addr()` - Get server address
|
||||
- `Name()` - Get server name
|
||||
- `HealthCheckHandler()` - Get health handler
|
||||
- `ReadinessHandler()` - Get readiness handler
|
||||
- `InFlightRequests()` - Get in-flight count
|
||||
- `IsShuttingDown()` - Check shutdown status
|
||||
- `Wait()` - Block until shutdown complete
|
||||
|
||||
294
pkg/server/example_test.go
Normal file
294
pkg/server/example_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
)
|
||||
|
||||
// ExampleManager_basic demonstrates basic server manager usage
|
||||
func ExampleManager_basic() {
|
||||
// Create a server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Define a simple handler
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintln(w, "Hello from server!")
|
||||
})
|
||||
|
||||
// Add an HTTP server
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: handler,
|
||||
GZIP: true, // Enable GZIP compression
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Server is now running...
|
||||
// When done, stop gracefully
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleManager_https demonstrates HTTPS configurations
|
||||
func ExampleManager_https() {
|
||||
mgr := server.NewManager()
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Secure connection!")
|
||||
})
|
||||
|
||||
// Option 1: Use certificate files
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "https-server-files",
|
||||
Host: "localhost",
|
||||
Port: 8443,
|
||||
Handler: handler,
|
||||
SSLCert: "/path/to/cert.pem",
|
||||
SSLKey: "/path/to/key.pem",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Option 2: Self-signed certificate (for development)
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "https-server-self-signed",
|
||||
Host: "localhost",
|
||||
Port: 8444,
|
||||
Handler: handler,
|
||||
SelfSignedSSL: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Option 3: Let's Encrypt / AutoTLS (for production)
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "https-server-letsencrypt",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
AutoTLSCacheDir: "./certs-cache",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
mgr.StopAll()
|
||||
}
|
||||
|
||||
// ExampleManager_gracefulShutdown demonstrates graceful shutdown with callbacks
|
||||
func ExampleManager_gracefulShutdown() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Register shutdown callbacks for cleanup tasks
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
fmt.Println("Closing database connections...")
|
||||
// Close your database here
|
||||
return nil
|
||||
})
|
||||
|
||||
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
fmt.Println("Flushing metrics...")
|
||||
// Flush metrics here
|
||||
return nil
|
||||
})
|
||||
|
||||
// Add server with custom timeouts
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate some work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
fmt.Fprintln(w, "Done!")
|
||||
})
|
||||
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: handler,
|
||||
ShutdownTimeout: 30 * time.Second, // Max time for shutdown
|
||||
DrainTimeout: 25 * time.Second, // Time to wait for in-flight requests
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
|
||||
// This will automatically handle graceful shutdown with callbacks
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
fmt.Printf("Shutdown completed: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleManager_healthChecks demonstrates health and readiness endpoints
|
||||
func ExampleManager_healthChecks() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Create a router with health endpoints
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Data endpoint")
|
||||
})
|
||||
|
||||
// Add server
|
||||
instance, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: mux,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Add health and readiness endpoints
|
||||
mux.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
mux.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
|
||||
// Start the server
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Health check returns:
|
||||
// - 200 OK with {"status":"healthy"} when healthy
|
||||
// - 503 Service Unavailable with {"status":"shutting_down"} when shutting down
|
||||
|
||||
// Readiness check returns:
|
||||
// - 200 OK with {"ready":true,"in_flight_requests":N} when ready
|
||||
// - 503 Service Unavailable with {"ready":false,"reason":"shutting_down"} when shutting down
|
||||
|
||||
// Cleanup
|
||||
mgr.StopAll()
|
||||
}
|
||||
|
||||
// ExampleManager_multipleServers demonstrates running multiple servers
|
||||
func ExampleManager_multipleServers() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Public API server
|
||||
publicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Public API")
|
||||
})
|
||||
_, err := mgr.Add(server.Config{
|
||||
Name: "public-api",
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
Handler: publicHandler,
|
||||
GZIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Admin API server (different port)
|
||||
adminHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Admin API")
|
||||
})
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "admin-api",
|
||||
Host: "localhost",
|
||||
Port: 8081,
|
||||
Handler: adminHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Metrics server (internal only)
|
||||
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Metrics data")
|
||||
})
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "metrics",
|
||||
Host: "127.0.0.1",
|
||||
Port: 9090,
|
||||
Handler: metricsHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start all servers at once
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Get specific server instance
|
||||
publicInstance, err := mgr.Get("public-api")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
|
||||
|
||||
// List all servers
|
||||
instances := mgr.List()
|
||||
fmt.Printf("Running %d servers\n", len(instances))
|
||||
|
||||
// Stop all servers gracefully (in parallel)
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleManager_monitoring demonstrates monitoring server state
|
||||
func ExampleManager_monitoring() {
|
||||
mgr := server.NewManager()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||
fmt.Fprintln(w, "Done")
|
||||
})
|
||||
|
||||
instance, err := mgr.Add(server.Config{
|
||||
Name: "api-server",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Handler: handler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Check server status
|
||||
fmt.Printf("Server address: %s\n", instance.Addr())
|
||||
fmt.Printf("Server name: %s\n", instance.Name())
|
||||
fmt.Printf("Is shutting down: %v\n", instance.IsShuttingDown())
|
||||
fmt.Printf("In-flight requests: %d\n", instance.InFlightRequests())
|
||||
|
||||
// Cleanup
|
||||
mgr.StopAll()
|
||||
|
||||
// Wait for complete shutdown
|
||||
instance.Wait()
|
||||
}
|
||||
137
pkg/server/interfaces.go
Normal file
137
pkg/server/interfaces.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds the configuration for a single web server instance.
|
||||
type Config struct {
|
||||
Name string
|
||||
Host string
|
||||
Port int
|
||||
Description string
|
||||
|
||||
// Handler is the http.Handler (e.g., a router) to be served.
|
||||
Handler http.Handler
|
||||
|
||||
// GZIP compression support
|
||||
GZIP bool
|
||||
|
||||
// TLS/HTTPS configuration options (mutually exclusive)
|
||||
// Option 1: Provide certificate and key files directly
|
||||
SSLCert string
|
||||
SSLKey string
|
||||
|
||||
// Option 2: Use self-signed certificate (for development/testing)
|
||||
// Generates a self-signed certificate automatically if no SSLCert/SSLKey provided
|
||||
SelfSignedSSL bool
|
||||
|
||||
// Option 3: Use Let's Encrypt / Certbot for automatic TLS
|
||||
// AutoTLS enables automatic certificate management via Let's Encrypt
|
||||
AutoTLS bool
|
||||
// AutoTLSDomains specifies the domains for Let's Encrypt certificates
|
||||
AutoTLSDomains []string
|
||||
// AutoTLSCacheDir specifies where to cache certificates (default: "./certs-cache")
|
||||
AutoTLSCacheDir string
|
||||
// AutoTLSEmail is the email for Let's Encrypt registration (optional but recommended)
|
||||
AutoTLSEmail string
|
||||
|
||||
// Graceful shutdown configuration
|
||||
// ShutdownTimeout is the maximum time to wait for graceful shutdown
|
||||
// Default: 30 seconds
|
||||
ShutdownTimeout time.Duration
|
||||
|
||||
// DrainTimeout is the time to wait for in-flight requests to complete
|
||||
// before forcing shutdown. Default: 25 seconds
|
||||
DrainTimeout time.Duration
|
||||
|
||||
// ReadTimeout is the maximum duration for reading the entire request
|
||||
// Default: 15 seconds
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// WriteTimeout is the maximum duration before timing out writes of the response
|
||||
// Default: 15 seconds
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// IdleTimeout is the maximum amount of time to wait for the next request
|
||||
// Default: 60 seconds
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// Instance defines the interface for a single server instance.
|
||||
// It abstracts the underlying http.Server, allowing for easier management and testing.
|
||||
type Instance interface {
|
||||
// Start begins serving requests. This method should be non-blocking and
|
||||
// run the server in a separate goroutine.
|
||||
Start() error
|
||||
|
||||
// Stop gracefully shuts down the server without interrupting any active connections.
|
||||
// It accepts a context to allow for a timeout.
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// Addr returns the network address the server is listening on.
|
||||
Addr() string
|
||||
|
||||
// Name returns the server instance name.
|
||||
Name() string
|
||||
|
||||
// HealthCheckHandler returns a handler that responds to health checks.
|
||||
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down.
|
||||
HealthCheckHandler() http.HandlerFunc
|
||||
|
||||
// ReadinessHandler returns a handler for readiness checks.
|
||||
// Includes in-flight request count.
|
||||
ReadinessHandler() http.HandlerFunc
|
||||
|
||||
// InFlightRequests returns the current number of in-flight requests.
|
||||
InFlightRequests() int64
|
||||
|
||||
// IsShuttingDown returns true if the server is shutting down.
|
||||
IsShuttingDown() bool
|
||||
|
||||
// Wait blocks until shutdown is complete.
|
||||
Wait()
|
||||
}
|
||||
|
||||
// Manager defines the interface for a server manager.
|
||||
// It is responsible for managing the lifecycle of multiple server instances.
|
||||
type Manager interface {
|
||||
// Add registers a new server instance based on the provided configuration.
|
||||
// The server is not started until StartAll or Start is called on the instance.
|
||||
Add(cfg Config) (Instance, error)
|
||||
|
||||
// Get returns a server instance by its name.
|
||||
Get(name string) (Instance, error)
|
||||
|
||||
// Remove stops and removes a server instance by its name.
|
||||
Remove(name string) error
|
||||
|
||||
// StartAll starts all registered server instances that are not already running.
|
||||
StartAll() error
|
||||
|
||||
// StopAll gracefully shuts down all running server instances.
|
||||
// Executes shutdown callbacks and drains in-flight requests.
|
||||
StopAll() error
|
||||
|
||||
// StopAllWithContext gracefully shuts down all running server instances with a context.
|
||||
StopAllWithContext(ctx context.Context) error
|
||||
|
||||
// RestartAll gracefully restarts all running server instances.
|
||||
RestartAll() error
|
||||
|
||||
// List returns all registered server instances.
|
||||
List() []Instance
|
||||
|
||||
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
|
||||
// It handles SIGINT and SIGTERM signals and performs graceful shutdown with callbacks.
|
||||
ServeWithGracefulShutdown() error
|
||||
|
||||
// RegisterShutdownCallback registers a callback to be called during shutdown.
|
||||
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
|
||||
RegisterShutdownCallback(cb ShutdownCallback)
|
||||
}
|
||||
|
||||
// ShutdownCallback is a function called during graceful shutdown.
|
||||
type ShutdownCallback func(context.Context) error
|
||||
601
pkg/server/manager.go
Normal file
601
pkg/server/manager.go
Normal file
@@ -0,0 +1,601 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/gzhttp"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
)
|
||||
|
||||
// gracefulServer wraps http.Server with graceful shutdown capabilities (internal type)
|
||||
type gracefulServer struct {
|
||||
server *http.Server
|
||||
shutdownTimeout time.Duration
|
||||
drainTimeout time.Duration
|
||||
inFlightRequests atomic.Int64
|
||||
isShuttingDown atomic.Bool
|
||||
shutdownOnce sync.Once
|
||||
shutdownComplete chan struct{}
|
||||
}
|
||||
|
||||
// trackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
|
||||
func (gs *gracefulServer) trackRequestsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if shutting down
|
||||
if gs.isShuttingDown.Load() {
|
||||
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Increment in-flight counter
|
||||
gs.inFlightRequests.Add(1)
|
||||
defer gs.inFlightRequests.Add(-1)
|
||||
|
||||
// Serve the request
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// shutdown performs graceful shutdown with request draining
|
||||
func (gs *gracefulServer) shutdown(ctx context.Context) error {
|
||||
var shutdownErr error
|
||||
|
||||
gs.shutdownOnce.Do(func() {
|
||||
logger.Info("Starting graceful shutdown...")
|
||||
|
||||
// Mark as shutting down (new requests will be rejected)
|
||||
gs.isShuttingDown.Store(true)
|
||||
|
||||
// Create context with timeout
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Wait for in-flight requests to complete (with drain timeout)
|
||||
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
|
||||
defer drainCancel()
|
||||
|
||||
shutdownErr = gs.drainRequests(drainCtx)
|
||||
if shutdownErr != nil {
|
||||
logger.Error("Error draining requests: %v", shutdownErr)
|
||||
}
|
||||
|
||||
// Shutdown the server
|
||||
logger.Info("Shutting down HTTP server...")
|
||||
if err := gs.server.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error("Error shutting down server: %v", err)
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Graceful shutdown complete")
|
||||
close(gs.shutdownComplete)
|
||||
})
|
||||
|
||||
return shutdownErr
|
||||
}
|
||||
|
||||
// drainRequests waits for in-flight requests to complete
|
||||
func (gs *gracefulServer) drainRequests(ctx context.Context) error {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
inFlight := gs.inFlightRequests.Load()
|
||||
|
||||
if inFlight == 0 {
|
||||
logger.Info("All requests drained in %v", time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
|
||||
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
|
||||
case <-ticker.C:
|
||||
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inFlightRequests returns the current number of in-flight requests
|
||||
func (gs *gracefulServer) inFlightRequestsCount() int64 {
|
||||
return gs.inFlightRequests.Load()
|
||||
}
|
||||
|
||||
// isShutdown returns true if the server is shutting down
|
||||
func (gs *gracefulServer) isShutdown() bool {
|
||||
return gs.isShuttingDown.Load()
|
||||
}
|
||||
|
||||
// wait blocks until shutdown is complete
|
||||
func (gs *gracefulServer) wait() {
|
||||
<-gs.shutdownComplete
|
||||
}
|
||||
|
||||
// healthCheckHandler returns a handler that responds to health checks
|
||||
func (gs *gracefulServer) healthCheckHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.isShutdown() {
|
||||
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write health check response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readinessHandler returns a handler for readiness checks
|
||||
func (gs *gracefulServer) readinessHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.isShutdown() {
|
||||
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
inFlight := gs.inFlightRequestsCount()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
// serverManager manages a collection of server instances with graceful shutdown support.
|
||||
type serverManager struct {
|
||||
instances map[string]Instance
|
||||
mu sync.RWMutex
|
||||
shutdownCallbacks []ShutdownCallback
|
||||
callbacksMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new server manager.
|
||||
func NewManager() Manager {
|
||||
return &serverManager{
|
||||
instances: make(map[string]Instance),
|
||||
shutdownCallbacks: make([]ShutdownCallback, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Add registers a new server instance.
|
||||
func (sm *serverManager) Add(cfg Config) (Instance, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("server name cannot be empty")
|
||||
}
|
||||
if _, exists := sm.instances[cfg.Name]; exists {
|
||||
return nil, fmt.Errorf("server with name '%s' already exists", cfg.Name)
|
||||
}
|
||||
|
||||
instance, err := newInstance(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sm.instances[cfg.Name] = instance
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Get returns a server instance by its name.
|
||||
func (sm *serverManager) Get(name string) (Instance, error) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
instance, exists := sm.instances[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("server with name '%s' not found", name)
|
||||
}
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Remove stops and removes a server instance by its name.
|
||||
func (sm *serverManager) Remove(name string) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
instance, exists := sm.instances[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("server with name '%s' not found", name)
|
||||
}
|
||||
|
||||
// Stop the server if it's running. Prefer the server's configured shutdownTimeout
|
||||
// when available, and fall back to a sensible default.
|
||||
timeout := 10 * time.Second
|
||||
if si, ok := instance.(*serverInstance); ok && si.gracefulServer != nil && si.gracefulServer.shutdownTimeout > 0 {
|
||||
timeout = si.gracefulServer.shutdownTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
if err := instance.Stop(ctx); err != nil {
|
||||
logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err)
|
||||
}
|
||||
|
||||
delete(sm.instances, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartAll starts all registered server instances.
|
||||
func (sm *serverManager) StartAll() error {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var startErrors []error
|
||||
for name, instance := range sm.instances {
|
||||
if err := instance.Start(); err != nil {
|
||||
startErrors = append(startErrors, fmt.Errorf("failed to start server '%s': %w", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(startErrors) > 0 {
|
||||
return fmt.Errorf("encountered errors while starting servers: %v", startErrors)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopAll gracefully shuts down all running server instances.
|
||||
func (sm *serverManager) StopAll() error {
|
||||
return sm.StopAllWithContext(context.Background())
|
||||
}
|
||||
|
||||
// StopAllWithContext gracefully shuts down all running server instances with a context.
|
||||
func (sm *serverManager) StopAllWithContext(ctx context.Context) error {
|
||||
sm.mu.RLock()
|
||||
instancesToStop := make([]Instance, 0, len(sm.instances))
|
||||
for _, instance := range sm.instances {
|
||||
instancesToStop = append(instancesToStop, instance)
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
logger.Info("Shutting down all servers...")
|
||||
|
||||
// Execute shutdown callbacks first
|
||||
sm.callbacksMu.Lock()
|
||||
callbacks := make([]ShutdownCallback, len(sm.shutdownCallbacks))
|
||||
copy(callbacks, sm.shutdownCallbacks)
|
||||
sm.callbacksMu.Unlock()
|
||||
|
||||
if len(callbacks) > 0 {
|
||||
logger.Info("Executing %d shutdown callbacks...", len(callbacks))
|
||||
for i, cb := range callbacks {
|
||||
if err := cb(ctx); err != nil {
|
||||
logger.Error("Shutdown callback %d failed: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop all instances in parallel
|
||||
var shutdownErrors []error
|
||||
var wg sync.WaitGroup
|
||||
var errorsMu sync.Mutex
|
||||
|
||||
for _, instance := range instancesToStop {
|
||||
wg.Add(1)
|
||||
go func(inst Instance) {
|
||||
defer wg.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
if err := inst.Stop(shutdownCtx); err != nil {
|
||||
errorsMu.Lock()
|
||||
shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Name(), err))
|
||||
errorsMu.Unlock()
|
||||
}
|
||||
}(instance)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(shutdownErrors) > 0 {
|
||||
return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors)
|
||||
}
|
||||
logger.Info("All servers stopped gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartAll gracefully restarts all running server instances.
|
||||
func (sm *serverManager) RestartAll() error {
|
||||
logger.Info("Restarting all servers...")
|
||||
if err := sm.StopAll(); err != nil {
|
||||
return fmt.Errorf("failed to stop servers during restart: %w", err)
|
||||
}
|
||||
|
||||
// Retry starting all servers with exponential backoff instead of a fixed sleep.
|
||||
const (
|
||||
maxAttempts = 5
|
||||
initialBackoff = 100 * time.Millisecond
|
||||
maxBackoff = 2 * time.Second
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
backoff := initialBackoff
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if err := sm.StartAll(); err != nil {
|
||||
lastErr = err
|
||||
if attempt == maxAttempts {
|
||||
break
|
||||
}
|
||||
logger.Warn("Attempt %d to start servers during restart failed: %v; retrying in %s", attempt, err, backoff)
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("All servers restarted successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to start servers during restart after %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
// List returns all registered server instances.
|
||||
func (sm *serverManager) List() []Instance {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
instances := make([]Instance, 0, len(sm.instances))
|
||||
for _, instance := range sm.instances {
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
return instances
|
||||
}
|
||||
|
||||
// RegisterShutdownCallback registers a callback to be called during shutdown.
|
||||
func (sm *serverManager) RegisterShutdownCallback(cb ShutdownCallback) {
|
||||
sm.callbacksMu.Lock()
|
||||
defer sm.callbacksMu.Unlock()
|
||||
sm.shutdownCallbacks = append(sm.shutdownCallbacks, cb)
|
||||
}
|
||||
|
||||
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
|
||||
func (sm *serverManager) ServeWithGracefulShutdown() error {
|
||||
// Start all servers
|
||||
if err := sm.StartAll(); err != nil {
|
||||
return fmt.Errorf("failed to start servers: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("All servers started. Waiting for shutdown signal...")
|
||||
|
||||
// Wait for interrupt signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
sig := <-sigChan
|
||||
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
|
||||
|
||||
// Create context with timeout for shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return sm.StopAllWithContext(ctx)
|
||||
}
|
||||
|
||||
// serverInstance is a concrete implementation of the Instance interface.
|
||||
// It wraps gracefulServer to provide graceful shutdown capabilities.
|
||||
type serverInstance struct {
|
||||
cfg Config
|
||||
gracefulServer *gracefulServer
|
||||
certFile string // Path to certificate file (may be persistent for self-signed)
|
||||
keyFile string // Path to key file (may be persistent for self-signed)
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
serverErr chan error
|
||||
}
|
||||
|
||||
// newInstance creates a new, unstarted server instance from a config.
|
||||
func newInstance(cfg Config) (*serverInstance, error) {
|
||||
if cfg.Handler == nil {
|
||||
return nil, fmt.Errorf("handler cannot be nil")
|
||||
}
|
||||
|
||||
// Set default timeouts
|
||||
if cfg.ShutdownTimeout == 0 {
|
||||
cfg.ShutdownTimeout = 30 * time.Second
|
||||
}
|
||||
if cfg.DrainTimeout == 0 {
|
||||
cfg.DrainTimeout = 25 * time.Second
|
||||
}
|
||||
if cfg.ReadTimeout == 0 {
|
||||
cfg.ReadTimeout = 15 * time.Second
|
||||
}
|
||||
if cfg.WriteTimeout == 0 {
|
||||
cfg.WriteTimeout = 15 * time.Second
|
||||
}
|
||||
if cfg.IdleTimeout == 0 {
|
||||
cfg.IdleTimeout = 60 * time.Second
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||
var handler = cfg.Handler
|
||||
|
||||
// Wrap with GZIP handler if enabled
|
||||
if cfg.GZIP {
|
||||
gz, err := gzhttp.NewWrapper()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err)
|
||||
}
|
||||
handler = gz(handler)
|
||||
}
|
||||
|
||||
// Wrap with the panic recovery middleware
|
||||
handler = middleware.PanicRecovery(handler)
|
||||
|
||||
// Configure TLS if any TLS option is enabled
|
||||
tlsConfig, certFile, keyFile, err := configureTLS(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to configure TLS: %w", err)
|
||||
}
|
||||
|
||||
// Create gracefulServer
|
||||
gracefulSrv := &gracefulServer{
|
||||
server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
IdleTimeout: cfg.IdleTimeout,
|
||||
TLSConfig: tlsConfig,
|
||||
},
|
||||
shutdownTimeout: cfg.ShutdownTimeout,
|
||||
drainTimeout: cfg.DrainTimeout,
|
||||
shutdownComplete: make(chan struct{}),
|
||||
}
|
||||
|
||||
return &serverInstance{
|
||||
cfg: cfg,
|
||||
gracefulServer: gracefulSrv,
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverErr: make(chan error, 1),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins serving requests in a new goroutine.
|
||||
func (s *serverInstance) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("server '%s' is already running", s.cfg.Name)
|
||||
}
|
||||
|
||||
// Determine if we're using TLS
|
||||
useTLS := s.cfg.SSLCert != "" || s.cfg.SSLKey != "" || s.cfg.SelfSignedSSL || s.cfg.AutoTLS
|
||||
|
||||
// Wrap handler with request tracking
|
||||
s.gracefulServer.server.Handler = s.gracefulServer.trackRequestsMiddleware(s.gracefulServer.server.Handler)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
s.mu.Unlock()
|
||||
logger.Info("Server '%s' stopped.", s.cfg.Name)
|
||||
}()
|
||||
|
||||
var err error
|
||||
protocol := "HTTP"
|
||||
|
||||
if useTLS {
|
||||
protocol = "HTTPS"
|
||||
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||
|
||||
// For AutoTLS, we need to use a TLS listener
|
||||
if s.cfg.AutoTLS {
|
||||
// Create listener
|
||||
ln, lnErr := net.Listen("tcp", s.gracefulServer.server.Addr)
|
||||
if lnErr != nil {
|
||||
err = fmt.Errorf("failed to create listener: %w", lnErr)
|
||||
} else {
|
||||
// Wrap with TLS
|
||||
tlsListener := tls.NewListener(ln, s.gracefulServer.server.TLSConfig)
|
||||
err = s.gracefulServer.server.Serve(tlsListener)
|
||||
}
|
||||
} else {
|
||||
// Use certificate files (regular SSL or self-signed)
|
||||
err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||
err = s.gracefulServer.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// If the server stopped for a reason other than a graceful shutdown, log and report the error.
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("Server '%s' failed: %v", s.cfg.Name, err)
|
||||
select {
|
||||
case s.serverErr <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.running = true
|
||||
// A small delay to allow the goroutine to start and potentially fail on binding.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check if the server failed to start
|
||||
select {
|
||||
case err := <-s.serverErr:
|
||||
s.running = false
|
||||
return err
|
||||
default:
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the server.
|
||||
func (s *serverInstance) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running {
|
||||
return nil // Already stopped
|
||||
}
|
||||
|
||||
logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name)
|
||||
err := s.gracefulServer.shutdown(ctx)
|
||||
if err == nil {
|
||||
s.running = false
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Addr returns the network address the server is listening on.
|
||||
func (s *serverInstance) Addr() string {
|
||||
return s.gracefulServer.server.Addr
|
||||
}
|
||||
|
||||
// Name returns the server instance name.
|
||||
func (s *serverInstance) Name() string {
|
||||
return s.cfg.Name
|
||||
}
|
||||
|
||||
// HealthCheckHandler returns a handler that responds to health checks.
|
||||
func (s *serverInstance) HealthCheckHandler() http.HandlerFunc {
|
||||
return s.gracefulServer.healthCheckHandler()
|
||||
}
|
||||
|
||||
// ReadinessHandler returns a handler for readiness checks.
|
||||
func (s *serverInstance) ReadinessHandler() http.HandlerFunc {
|
||||
return s.gracefulServer.readinessHandler()
|
||||
}
|
||||
|
||||
// InFlightRequests returns the current number of in-flight requests.
|
||||
func (s *serverInstance) InFlightRequests() int64 {
|
||||
return s.gracefulServer.inFlightRequestsCount()
|
||||
}
|
||||
|
||||
// IsShuttingDown returns true if the server is shutting down.
|
||||
func (s *serverInstance) IsShuttingDown() bool {
|
||||
return s.gracefulServer.isShutdown()
|
||||
}
|
||||
|
||||
// Wait blocks until shutdown is complete.
|
||||
func (s *serverInstance) Wait() {
|
||||
s.gracefulServer.wait()
|
||||
}
|
||||
399
pkg/server/manager_test.go
Normal file
399
pkg/server/manager_test.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// getFreePort asks the kernel for a free open port that is ready to use.
|
||||
func getFreePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func TestServerManagerLifecycle(t *testing.T) {
|
||||
// Initialize logger for test output
|
||||
logger.Init(true)
|
||||
|
||||
// Create a new server manager
|
||||
sm := NewManager()
|
||||
|
||||
// Define a simple test handler
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
// Get a free port for the server to listen on to avoid conflicts
|
||||
testPort := getFreePort(t)
|
||||
|
||||
// Add a new server configuration
|
||||
serverConfig := Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: testHandler,
|
||||
}
|
||||
instance, err := sm.Add(serverConfig)
|
||||
require.NoError(t, err, "should be able to add a new server")
|
||||
require.NotNil(t, instance, "added instance should not be nil")
|
||||
|
||||
// --- Test StartAll ---
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err, "StartAll should not return an error")
|
||||
|
||||
// Give the server a moment to start up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// --- Verify Server is Running ---
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
resp, err := client.Get(url)
|
||||
require.NoError(t, err, "should be able to make a request to the running server")
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "expected status OK from the test server")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
assert.Equal(t, "Hello, World!", string(body), "response body should match expected value")
|
||||
|
||||
// --- Test Get ---
|
||||
retrievedInstance, err := sm.Get("TestServer")
|
||||
require.NoError(t, err, "should be able to get server by name")
|
||||
assert.Equal(t, instance.Addr(), retrievedInstance.Addr(), "retrieved instance should be the same")
|
||||
|
||||
// --- Test List ---
|
||||
instanceList := sm.List()
|
||||
require.Len(t, instanceList, 1, "list should contain one instance")
|
||||
assert.Equal(t, instance.Addr(), instanceList[0].Addr(), "listed instance should be the same")
|
||||
|
||||
// --- Test StopAll ---
|
||||
err = sm.StopAll()
|
||||
require.NoError(t, err, "StopAll should not return an error")
|
||||
|
||||
// Give the server a moment to shut down
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// --- Verify Server is Stopped ---
|
||||
_, err = client.Get(url)
|
||||
require.Error(t, err, "should not be able to make a request to a stopped server")
|
||||
|
||||
// --- Test Remove ---
|
||||
err = sm.Remove("TestServer")
|
||||
require.NoError(t, err, "should be able to remove a server")
|
||||
|
||||
_, err = sm.Get("TestServer")
|
||||
require.Error(t, err, "should not be able to get a removed server")
|
||||
}
|
||||
|
||||
func TestManagerErrorCases(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
testPort := getFreePort(t)
|
||||
|
||||
// --- Test Add Duplicate Name ---
|
||||
config1 := Config{Name: "Duplicate", Host: "localhost", Port: testPort, Handler: http.NewServeMux()}
|
||||
_, err := sm.Add(config1)
|
||||
require.NoError(t, err)
|
||||
|
||||
config2 := Config{Name: "Duplicate", Host: "localhost", Port: getFreePort(t), Handler: http.NewServeMux()}
|
||||
_, err = sm.Add(config2)
|
||||
require.Error(t, err, "should not be able to add a server with a duplicate name")
|
||||
|
||||
// --- Test Get Non-existent ---
|
||||
_, err = sm.Get("NonExistent")
|
||||
require.Error(t, err, "should get an error for a non-existent server")
|
||||
|
||||
// --- Test Add with Nil Handler ---
|
||||
config3 := Config{Name: "NilHandler", Host: "localhost", Port: getFreePort(t), Handler: nil}
|
||||
_, err = sm.Add(config3)
|
||||
require.Error(t, err, "should not be able to add a server with a nil handler")
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
requestsHandled := 0
|
||||
var requestsMu sync.Mutex
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestsMu.Lock()
|
||||
requestsHandled++
|
||||
requestsMu.Unlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
testPort := getFreePort(t)
|
||||
instance, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: handler,
|
||||
DrainTimeout: 2 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Send some concurrent requests
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
resp, err := client.Get(url)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait a bit for requests to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check in-flight requests
|
||||
inFlight := instance.InFlightRequests()
|
||||
assert.Greater(t, inFlight, int64(0), "Should have in-flight requests")
|
||||
|
||||
// Stop the server
|
||||
err = sm.StopAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for all requests to complete
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests were handled
|
||||
requestsMu.Lock()
|
||||
handled := requestsHandled
|
||||
requestsMu.Unlock()
|
||||
assert.GreaterOrEqual(t, handled, 1, "At least some requests should have been handled")
|
||||
|
||||
// Verify no in-flight requests
|
||||
assert.Equal(t, int64(0), instance.InFlightRequests(), "Should have no in-flight requests after shutdown")
|
||||
}
|
||||
|
||||
func TestHealthAndReadinessEndpoints(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
testPort := getFreePort(t)
|
||||
|
||||
instance, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: mux,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add health and readiness endpoints
|
||||
mux.HandleFunc("/health", instance.HealthCheckHandler())
|
||||
mux.HandleFunc("/ready", instance.ReadinessHandler())
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
baseURL := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
|
||||
// Test health endpoint
|
||||
resp, err := client.Get(baseURL + "/health")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
assert.Contains(t, string(body), "healthy")
|
||||
|
||||
// Test readiness endpoint
|
||||
resp, err = client.Get(baseURL + "/ready")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, _ = io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
assert.Contains(t, string(body), "ready")
|
||||
assert.Contains(t, string(body), "in_flight_requests")
|
||||
|
||||
// Stop the server
|
||||
sm.StopAll()
|
||||
}
|
||||
|
||||
func TestRequestRejectionDuringShutdown(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
testPort := getFreePort(t)
|
||||
_, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: handler,
|
||||
DrainTimeout: 1 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Start shutdown in background
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
sm.StopAll()
|
||||
}()
|
||||
|
||||
// Give shutdown time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to make a request after shutdown started
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||
resp, err := client.Get(url)
|
||||
|
||||
// The request should either fail (connection refused) or get 503
|
||||
if err == nil {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Should get 503 during shutdown")
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownCallbacks(t *testing.T) {
|
||||
logger.Init(true)
|
||||
sm := NewManager()
|
||||
|
||||
callbackExecuted := false
|
||||
var callbackMu sync.Mutex
|
||||
|
||||
sm.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
callbackMu.Lock()
|
||||
callbackExecuted = true
|
||||
callbackMu.Unlock()
|
||||
return nil
|
||||
})
|
||||
|
||||
testPort := getFreePort(t)
|
||||
_, err := sm.Add(Config{
|
||||
Name: "TestServer",
|
||||
Host: "localhost",
|
||||
Port: testPort,
|
||||
Handler: http.NewServeMux(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sm.StartAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = sm.StopAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
callbackMu.Lock()
|
||||
executed := callbackExecuted
|
||||
callbackMu.Unlock()
|
||||
|
||||
assert.True(t, executed, "Shutdown callback should have been executed")
|
||||
}
|
||||
|
||||
func TestSelfSignedSSLCertificateReuse(t *testing.T) {
|
||||
logger.Init(true)
|
||||
|
||||
// Get expected cert directory location
|
||||
cacheDir, err := os.UserCacheDir()
|
||||
require.NoError(t, err)
|
||||
certDir := filepath.Join(cacheDir, "resolvespec", "certs")
|
||||
|
||||
host := "localhost"
|
||||
certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host))
|
||||
keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host))
|
||||
|
||||
// Clean up any existing cert files from previous tests
|
||||
os.Remove(certFile)
|
||||
os.Remove(keyFile)
|
||||
|
||||
// First server creation - should generate new certificates
|
||||
sm1 := NewManager()
|
||||
testPort1 := getFreePort(t)
|
||||
_, err = sm1.Add(Config{
|
||||
Name: "SSLTestServer1",
|
||||
Host: host,
|
||||
Port: testPort1,
|
||||
Handler: http.NewServeMux(),
|
||||
SelfSignedSSL: true,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify certificates were created
|
||||
_, err = os.Stat(certFile)
|
||||
require.NoError(t, err, "certificate file should exist after first creation")
|
||||
_, err = os.Stat(keyFile)
|
||||
require.NoError(t, err, "key file should exist after first creation")
|
||||
|
||||
// Get modification time of cert file
|
||||
info1, err := os.Stat(certFile)
|
||||
require.NoError(t, err)
|
||||
modTime1 := info1.ModTime()
|
||||
|
||||
// Wait a bit to ensure different modification times
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second server creation - should reuse existing certificates
|
||||
sm2 := NewManager()
|
||||
testPort2 := getFreePort(t)
|
||||
_, err = sm2.Add(Config{
|
||||
Name: "SSLTestServer2",
|
||||
Host: host,
|
||||
Port: testPort2,
|
||||
Handler: http.NewServeMux(),
|
||||
SelfSignedSSL: true,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get modification time of cert file after second creation
|
||||
info2, err := os.Stat(certFile)
|
||||
require.NoError(t, err)
|
||||
modTime2 := info2.ModTime()
|
||||
|
||||
// Verify the certificate was reused (same modification time)
|
||||
assert.Equal(t, modTime1, modTime2, "certificate should be reused, not regenerated")
|
||||
|
||||
// Clean up
|
||||
sm1.StopAll()
|
||||
sm2.StopAll()
|
||||
}
|
||||
@@ -1,296 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// GracefulServer wraps http.Server with graceful shutdown capabilities
|
||||
type GracefulServer struct {
|
||||
server *http.Server
|
||||
shutdownTimeout time.Duration
|
||||
drainTimeout time.Duration
|
||||
inFlightRequests atomic.Int64
|
||||
isShuttingDown atomic.Bool
|
||||
shutdownOnce sync.Once
|
||||
shutdownComplete chan struct{}
|
||||
}
|
||||
|
||||
// Config holds configuration for the graceful server
|
||||
type Config struct {
|
||||
// Addr is the server address (e.g., ":8080")
|
||||
Addr string
|
||||
|
||||
// Handler is the HTTP handler
|
||||
Handler http.Handler
|
||||
|
||||
// ShutdownTimeout is the maximum time to wait for graceful shutdown
|
||||
// Default: 30 seconds
|
||||
ShutdownTimeout time.Duration
|
||||
|
||||
// DrainTimeout is the time to wait for in-flight requests to complete
|
||||
// before forcing shutdown. Default: 25 seconds
|
||||
DrainTimeout time.Duration
|
||||
|
||||
// ReadTimeout is the maximum duration for reading the entire request
|
||||
ReadTimeout time.Duration
|
||||
|
||||
// WriteTimeout is the maximum duration before timing out writes of the response
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// IdleTimeout is the maximum amount of time to wait for the next request
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewGracefulServer creates a new graceful server
|
||||
func NewGracefulServer(config Config) *GracefulServer {
|
||||
if config.ShutdownTimeout == 0 {
|
||||
config.ShutdownTimeout = 30 * time.Second
|
||||
}
|
||||
if config.DrainTimeout == 0 {
|
||||
config.DrainTimeout = 25 * time.Second
|
||||
}
|
||||
if config.ReadTimeout == 0 {
|
||||
config.ReadTimeout = 10 * time.Second
|
||||
}
|
||||
if config.WriteTimeout == 0 {
|
||||
config.WriteTimeout = 10 * time.Second
|
||||
}
|
||||
if config.IdleTimeout == 0 {
|
||||
config.IdleTimeout = 120 * time.Second
|
||||
}
|
||||
|
||||
gs := &GracefulServer{
|
||||
server: &http.Server{
|
||||
Addr: config.Addr,
|
||||
Handler: config.Handler,
|
||||
ReadTimeout: config.ReadTimeout,
|
||||
WriteTimeout: config.WriteTimeout,
|
||||
IdleTimeout: config.IdleTimeout,
|
||||
},
|
||||
shutdownTimeout: config.ShutdownTimeout,
|
||||
drainTimeout: config.DrainTimeout,
|
||||
shutdownComplete: make(chan struct{}),
|
||||
}
|
||||
|
||||
return gs
|
||||
}
|
||||
|
||||
// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
|
||||
func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if shutting down
|
||||
if gs.isShuttingDown.Load() {
|
||||
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Increment in-flight counter
|
||||
gs.inFlightRequests.Add(1)
|
||||
defer gs.inFlightRequests.Add(-1)
|
||||
|
||||
// Serve the request
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ListenAndServe starts the server and handles graceful shutdown
|
||||
func (gs *GracefulServer) ListenAndServe() error {
|
||||
// Wrap handler with request tracking
|
||||
gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler)
|
||||
|
||||
// Start server in goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
logger.Info("Starting server on %s", gs.server.Addr)
|
||||
if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
serverErr <- err
|
||||
}
|
||||
close(serverErr)
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
select {
|
||||
case err := <-serverErr:
|
||||
return err
|
||||
case sig := <-sigChan:
|
||||
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
|
||||
return gs.Shutdown(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown performs graceful shutdown with request draining
|
||||
func (gs *GracefulServer) Shutdown(ctx context.Context) error {
|
||||
var shutdownErr error
|
||||
|
||||
gs.shutdownOnce.Do(func() {
|
||||
logger.Info("Starting graceful shutdown...")
|
||||
|
||||
// Mark as shutting down (new requests will be rejected)
|
||||
gs.isShuttingDown.Store(true)
|
||||
|
||||
// Create context with timeout
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Wait for in-flight requests to complete (with drain timeout)
|
||||
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
|
||||
defer drainCancel()
|
||||
|
||||
shutdownErr = gs.drainRequests(drainCtx)
|
||||
if shutdownErr != nil {
|
||||
logger.Error("Error draining requests: %v", shutdownErr)
|
||||
}
|
||||
|
||||
// Shutdown the server
|
||||
logger.Info("Shutting down HTTP server...")
|
||||
if err := gs.server.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error("Error shutting down server: %v", err)
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Graceful shutdown complete")
|
||||
close(gs.shutdownComplete)
|
||||
})
|
||||
|
||||
return shutdownErr
|
||||
}
|
||||
|
||||
// drainRequests waits for in-flight requests to complete
|
||||
func (gs *GracefulServer) drainRequests(ctx context.Context) error {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
inFlight := gs.inFlightRequests.Load()
|
||||
|
||||
if inFlight == 0 {
|
||||
logger.Info("All requests drained in %v", time.Since(startTime))
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
|
||||
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
|
||||
case <-ticker.C:
|
||||
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InFlightRequests returns the current number of in-flight requests
|
||||
func (gs *GracefulServer) InFlightRequests() int64 {
|
||||
return gs.inFlightRequests.Load()
|
||||
}
|
||||
|
||||
// IsShuttingDown returns true if the server is shutting down
|
||||
func (gs *GracefulServer) IsShuttingDown() bool {
|
||||
return gs.isShuttingDown.Load()
|
||||
}
|
||||
|
||||
// Wait blocks until shutdown is complete
|
||||
func (gs *GracefulServer) Wait() {
|
||||
<-gs.shutdownComplete
|
||||
}
|
||||
|
||||
// HealthCheckHandler returns a handler that responds to health checks
|
||||
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down
|
||||
func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.IsShuttingDown() {
|
||||
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write. %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadinessHandler returns a handler for readiness checks
|
||||
// Includes in-flight request count
|
||||
func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if gs.IsShuttingDown() {
|
||||
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
inFlight := gs.InFlightRequests()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
// ShutdownCallback is a function called during shutdown
|
||||
type ShutdownCallback func(context.Context) error
|
||||
|
||||
// shutdownCallbacks stores registered shutdown callbacks
|
||||
var (
|
||||
shutdownCallbacks []ShutdownCallback
|
||||
shutdownCallbacksMu sync.Mutex
|
||||
)
|
||||
|
||||
// RegisterShutdownCallback registers a callback to be called during shutdown
|
||||
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
|
||||
func RegisterShutdownCallback(cb ShutdownCallback) {
|
||||
shutdownCallbacksMu.Lock()
|
||||
defer shutdownCallbacksMu.Unlock()
|
||||
shutdownCallbacks = append(shutdownCallbacks, cb)
|
||||
}
|
||||
|
||||
// executeShutdownCallbacks runs all registered shutdown callbacks
|
||||
func executeShutdownCallbacks(ctx context.Context) error {
|
||||
shutdownCallbacksMu.Lock()
|
||||
callbacks := make([]ShutdownCallback, len(shutdownCallbacks))
|
||||
copy(callbacks, shutdownCallbacks)
|
||||
shutdownCallbacksMu.Unlock()
|
||||
|
||||
var errors []error
|
||||
for i, cb := range callbacks {
|
||||
logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks))
|
||||
if err := cb(ctx); err != nil {
|
||||
logger.Error("Shutdown callback %d failed: %v", i+1, err)
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("shutdown callbacks failed: %v", errors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShutdownWithCallbacks performs shutdown and executes all registered callbacks
|
||||
func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error {
|
||||
// Execute callbacks first
|
||||
if err := executeShutdownCallbacks(ctx); err != nil {
|
||||
logger.Error("Error executing shutdown callbacks: %v", err)
|
||||
}
|
||||
|
||||
// Then shutdown the server
|
||||
return gs.Shutdown(ctx)
|
||||
}
|
||||
@@ -1,231 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGracefulServerTrackRequests(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
})
|
||||
|
||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||
|
||||
// Start some requests
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait a bit for requests to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Check in-flight count
|
||||
inFlight := srv.InFlightRequests()
|
||||
if inFlight == 0 {
|
||||
t.Error("Should have in-flight requests")
|
||||
}
|
||||
|
||||
// Wait for all requests to complete
|
||||
wg.Wait()
|
||||
|
||||
// Check that counter is back to zero
|
||||
inFlight = srv.InFlightRequests()
|
||||
if inFlight != 0 {
|
||||
t.Errorf("In-flight requests should be 0, got %d", inFlight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
})
|
||||
|
||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
||||
|
||||
// Mark as shutting down
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
// Try to make a request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// Should get 503
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckHandler(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
})
|
||||
|
||||
handler := srv.HealthCheckHandler()
|
||||
|
||||
// Healthy
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if w.Body.String() != `{"status":"healthy"}` {
|
||||
t.Errorf("Unexpected body: %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
// Shutting down
|
||||
t.Run("ShuttingDown", func(t *testing.T) {
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadinessHandler(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
})
|
||||
|
||||
handler := srv.ReadinessHandler()
|
||||
|
||||
// Ready with no in-flight requests
|
||||
t.Run("Ready", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if body != `{"ready":true,"in_flight_requests":0}` {
|
||||
t.Errorf("Unexpected body: %s", body)
|
||||
}
|
||||
})
|
||||
|
||||
// Not ready during shutdown
|
||||
t.Run("NotReady", func(t *testing.T) {
|
||||
srv.isShuttingDown.Store(true)
|
||||
|
||||
req := httptest.NewRequest("GET", "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShutdownCallbacks(t *testing.T) {
|
||||
callbackExecuted := false
|
||||
|
||||
RegisterShutdownCallback(func(ctx context.Context) error {
|
||||
callbackExecuted = true
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := executeShutdownCallbacks(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("executeShutdownCallbacks() error = %v", err)
|
||||
}
|
||||
|
||||
if !callbackExecuted {
|
||||
t.Error("Shutdown callback was not executed")
|
||||
}
|
||||
|
||||
// Reset for other tests
|
||||
shutdownCallbacks = nil
|
||||
}
|
||||
|
||||
func TestDrainRequests(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
DrainTimeout: 1 * time.Second,
|
||||
})
|
||||
|
||||
// Simulate in-flight requests
|
||||
srv.inFlightRequests.Add(3)
|
||||
|
||||
// Start draining in background
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Simulate requests completing
|
||||
srv.inFlightRequests.Add(-3)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := srv.drainRequests(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("drainRequests() error = %v", err)
|
||||
}
|
||||
|
||||
if srv.InFlightRequests() != 0 {
|
||||
t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainRequestsTimeout(t *testing.T) {
|
||||
srv := NewGracefulServer(Config{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
DrainTimeout: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
// Simulate in-flight requests that don't complete
|
||||
srv.inFlightRequests.Add(5)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := srv.drainRequests(ctx)
|
||||
if err == nil {
|
||||
t.Error("drainRequests() should timeout with error")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
srv.inFlightRequests.Add(-5)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
// This test is in ratelimit_test.go since getClientIP is used by rate limiter
|
||||
// Including here for completeness of server tests
|
||||
}
|
||||
439
pkg/server/staticweb/README.md
Normal file
439
pkg/server/staticweb/README.md
Normal file
@@ -0,0 +1,439 @@
|
||||
# StaticWeb - Interface-Driven Static File Server
|
||||
|
||||
StaticWeb is a flexible, interface-driven Go package for serving static files over HTTP. It supports multiple filesystem backends (local, zip, embedded) and provides pluggable policies for caching, MIME types, and fallback strategies.
|
||||
|
||||
## Features
|
||||
|
||||
- **Router-Agnostic**: Works with any HTTP router through standard `http.Handler`
|
||||
- **Multiple Filesystem Providers**: Local directories, zip files, embedded filesystems
|
||||
- **Pluggable Policies**: Customizable cache, MIME type, and fallback strategies
|
||||
- **Thread-Safe**: Safe for concurrent use
|
||||
- **Resource Management**: Proper lifecycle management with `Close()` methods
|
||||
- **Extensible**: Easy to add new providers and policies
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/bitechdev/ResolveSpec/pkg/server/staticweb
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
Serve files from a local directory:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/server/staticweb"
|
||||
|
||||
// Create service
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// Mount a local directory
|
||||
provider, _ := staticweb.LocalProvider("./public")
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: provider,
|
||||
})
|
||||
|
||||
// Use with any router
|
||||
router.PathPrefix("/").Handler(service.Handler())
|
||||
```
|
||||
|
||||
### Single Page Application (SPA)
|
||||
|
||||
Serve an SPA with HTML fallback routing:
|
||||
|
||||
```go
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
provider, _ := staticweb.LocalProvider("./dist")
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: provider,
|
||||
FallbackStrategy: staticweb.HTMLFallback("index.html"),
|
||||
})
|
||||
|
||||
// API routes take precedence (registered first)
|
||||
router.HandleFunc("/api/users", usersHandler)
|
||||
router.HandleFunc("/api/posts", postsHandler)
|
||||
|
||||
// Static files handle all other routes
|
||||
router.PathPrefix("/").Handler(service.Handler())
|
||||
```
|
||||
|
||||
## Filesystem Providers
|
||||
|
||||
### Local Directory
|
||||
|
||||
Serve files from a local filesystem directory:
|
||||
|
||||
```go
|
||||
provider, err := staticweb.LocalProvider("/var/www/static")
|
||||
```
|
||||
|
||||
### Zip File
|
||||
|
||||
Serve files from a zip archive:
|
||||
|
||||
```go
|
||||
provider, err := staticweb.ZipProvider("./static.zip")
|
||||
```
|
||||
|
||||
### Embedded Filesystem
|
||||
|
||||
Serve files from Go's embedded filesystem:
|
||||
|
||||
```go
|
||||
//go:embed assets
|
||||
var assets embed.FS
|
||||
|
||||
// Direct embedded FS
|
||||
provider, err := staticweb.EmbedProvider(&assets, "")
|
||||
|
||||
// Or from a zip file within embedded FS
|
||||
provider, err := staticweb.EmbedProvider(&assets, "assets.zip")
|
||||
```
|
||||
|
||||
## Cache Policies
|
||||
|
||||
### Simple Cache
|
||||
|
||||
Single TTL for all files:
|
||||
|
||||
```go
|
||||
cachePolicy := staticweb.SimpleCache(3600) // 1 hour
|
||||
```
|
||||
|
||||
### Extension-Based Cache
|
||||
|
||||
Different TTLs per file type:
|
||||
|
||||
```go
|
||||
rules := map[string]int{
|
||||
".html": 3600, // 1 hour
|
||||
".js": 86400, // 1 day
|
||||
".css": 86400, // 1 day
|
||||
".png": 604800, // 1 week
|
||||
}
|
||||
|
||||
cachePolicy := staticweb.ExtensionCache(rules, 3600) // default 1 hour
|
||||
```
|
||||
|
||||
### No Cache
|
||||
|
||||
Disable caching entirely:
|
||||
|
||||
```go
|
||||
cachePolicy := staticweb.NoCache()
|
||||
```
|
||||
|
||||
## Fallback Strategies
|
||||
|
||||
### HTML Fallback
|
||||
|
||||
Serve index.html for non-asset requests (SPA routing):
|
||||
|
||||
```go
|
||||
fallback := staticweb.HTMLFallback("index.html")
|
||||
```
|
||||
|
||||
### Extension-Based Fallback
|
||||
|
||||
Skip fallback for known static assets:
|
||||
|
||||
```go
|
||||
fallback := staticweb.DefaultExtensionFallback("index.html")
|
||||
```
|
||||
|
||||
Custom extensions:
|
||||
|
||||
```go
|
||||
staticExts := []string{".js", ".css", ".png", ".jpg"}
|
||||
fallback := staticweb.ExtensionFallback(staticExts, "index.html")
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Service Configuration
|
||||
|
||||
```go
|
||||
config := &staticweb.ServiceConfig{
|
||||
DefaultCacheTime: 3600,
|
||||
DefaultMIMETypes: map[string]string{
|
||||
".webp": "image/webp",
|
||||
".wasm": "application/wasm",
|
||||
},
|
||||
}
|
||||
|
||||
service := staticweb.NewService(config)
|
||||
```
|
||||
|
||||
### Mount Configuration
|
||||
|
||||
```go
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: provider,
|
||||
CachePolicy: cachePolicy, // Optional
|
||||
MIMEResolver: mimeResolver, // Optional
|
||||
FallbackStrategy: fallbackStrategy, // Optional
|
||||
})
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Multiple Mount Points
|
||||
|
||||
Serve different directories at different URL prefixes with different policies:
|
||||
|
||||
```go
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// Long-lived assets
|
||||
assetsProvider, _ := staticweb.LocalProvider("./assets")
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/assets",
|
||||
Provider: assetsProvider,
|
||||
CachePolicy: staticweb.SimpleCache(604800), // 1 week
|
||||
})
|
||||
|
||||
// Frequently updated HTML
|
||||
htmlProvider, _ := staticweb.LocalProvider("./public")
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: htmlProvider,
|
||||
CachePolicy: staticweb.SimpleCache(300), // 5 minutes
|
||||
})
|
||||
```
|
||||
|
||||
### Custom MIME Types
|
||||
|
||||
```go
|
||||
mimeResolver := staticweb.DefaultMIMEResolver()
|
||||
mimeResolver.RegisterMIMEType(".webp", "image/webp")
|
||||
mimeResolver.RegisterMIMEType(".wasm", "application/wasm")
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: provider,
|
||||
MIMEResolver: mimeResolver,
|
||||
})
|
||||
```
|
||||
|
||||
### Resource Cleanup
|
||||
|
||||
Always close the service when done:
|
||||
|
||||
```go
|
||||
service := staticweb.NewService(nil)
|
||||
defer service.Close()
|
||||
|
||||
// ... mount and use service ...
|
||||
```
|
||||
|
||||
Or unmount individual mount points:
|
||||
|
||||
```go
|
||||
service.Unmount("/static")
|
||||
```
|
||||
|
||||
### Reloading/Refreshing Content
|
||||
|
||||
Reload providers to pick up changes from the underlying filesystem. This is particularly useful for zip files in development:
|
||||
|
||||
```go
|
||||
// When zip file or directory contents change
|
||||
err := service.Reload()
|
||||
if err != nil {
|
||||
log.Printf("Failed to reload: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
Providers that support reloading:
|
||||
- **ZipFSProvider**: Reopens the zip file to pick up changes
|
||||
- **LocalFSProvider**: Refreshes the directory view (automatically picks up changes)
|
||||
- **EmbedFSProvider**: Not reloadable (embedded at compile time)
|
||||
|
||||
You can also reload individual providers:
|
||||
|
||||
```go
|
||||
if reloadable, ok := provider.(staticweb.ReloadableProvider); ok {
|
||||
err := reloadable.Reload()
|
||||
if err != nil {
|
||||
log.Printf("Failed to reload: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Development Workflow Example:**
|
||||
|
||||
```go
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
provider, _ := staticweb.ZipProvider("./dist.zip")
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/app",
|
||||
Provider: provider,
|
||||
})
|
||||
|
||||
// In development, reload when dist.zip is rebuilt
|
||||
go func() {
|
||||
watcher := fsnotify.NewWatcher()
|
||||
watcher.Add("./dist.zip")
|
||||
|
||||
for range watcher.Events {
|
||||
log.Println("Reloading static files...")
|
||||
if err := service.Reload(); err != nil {
|
||||
log.Printf("Reload failed: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
```
|
||||
|
||||
## Router Integration
|
||||
|
||||
### Gorilla Mux
|
||||
|
||||
```go
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users", usersHandler)
|
||||
router.PathPrefix("/").Handler(service.Handler())
|
||||
```
|
||||
|
||||
### Standard http.ServeMux
|
||||
|
||||
```go
|
||||
http.Handle("/api/", apiHandler)
|
||||
http.Handle("/", service.Handler())
|
||||
```
|
||||
|
||||
### BunRouter
|
||||
|
||||
```go
|
||||
router.GET("/api/users", usersHandler)
|
||||
router.GET("/*path", bunrouter.HTTPHandlerFunc(service.Handler()))
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Interfaces
|
||||
|
||||
#### FileSystemProvider
|
||||
|
||||
Abstracts the source of files:
|
||||
|
||||
```go
|
||||
type FileSystemProvider interface {
|
||||
Open(name string) (fs.File, error)
|
||||
Close() error
|
||||
Type() string
|
||||
}
|
||||
```
|
||||
|
||||
Implementations:
|
||||
- `LocalFSProvider` - Local directories
|
||||
- `ZipFSProvider` - Zip archives
|
||||
- `EmbedFSProvider` - Embedded filesystems
|
||||
|
||||
#### CachePolicy
|
||||
|
||||
Defines caching behavior:
|
||||
|
||||
```go
|
||||
type CachePolicy interface {
|
||||
GetCacheTime(path string) int
|
||||
GetCacheHeaders(path string) map[string]string
|
||||
}
|
||||
```
|
||||
|
||||
Implementations:
|
||||
- `SimpleCachePolicy` - Single TTL
|
||||
- `ExtensionBasedCachePolicy` - Per-extension TTL
|
||||
- `NoCachePolicy` - Disable caching
|
||||
|
||||
#### FallbackStrategy
|
||||
|
||||
Handles missing files:
|
||||
|
||||
```go
|
||||
type FallbackStrategy interface {
|
||||
ShouldFallback(path string) bool
|
||||
GetFallbackPath(path string) string
|
||||
}
|
||||
```
|
||||
|
||||
Implementations:
|
||||
- `NoFallback` - Return 404
|
||||
- `HTMLFallbackStrategy` - SPA routing
|
||||
- `ExtensionBasedFallback` - Skip known assets
|
||||
|
||||
#### MIMETypeResolver
|
||||
|
||||
Determines Content-Type:
|
||||
|
||||
```go
|
||||
type MIMETypeResolver interface {
|
||||
GetMIMEType(path string) string
|
||||
RegisterMIMEType(extension, mimeType string)
|
||||
}
|
||||
```
|
||||
|
||||
Implementations:
|
||||
- `DefaultMIMEResolver` - Common web types
|
||||
- `ConfigurableMIMEResolver` - Custom mappings
|
||||
|
||||
## Testing
|
||||
|
||||
### Mock Providers
|
||||
|
||||
```go
|
||||
import staticwebtesting "github.com/bitechdev/ResolveSpec/pkg/server/staticweb/testing"
|
||||
|
||||
provider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"index.html": []byte("<html>test</html>"),
|
||||
"app.js": []byte("console.log('test')"),
|
||||
})
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: provider,
|
||||
})
|
||||
```
|
||||
|
||||
### Test Helpers
|
||||
|
||||
```go
|
||||
req := httptest.NewRequest("GET", "/index.html", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
service.Handler().ServeHTTP(rec, req)
|
||||
|
||||
// Assert response
|
||||
assert.Equal(t, 200, rec.Code)
|
||||
```
|
||||
|
||||
## Future Features
|
||||
|
||||
The interface-driven design allows for easy extensibility:
|
||||
|
||||
### Planned Providers
|
||||
|
||||
- **HTTPFSProvider**: Fetch files from remote HTTP servers with local caching
|
||||
- **S3FSProvider**: Serve files from S3-compatible storage
|
||||
- **CompositeProvider**: Fallback chain across multiple providers
|
||||
- **MemoryProvider**: In-memory filesystem for testing
|
||||
|
||||
### Planned Policies
|
||||
|
||||
- **TimedCachePolicy**: Different cache times by time of day
|
||||
- **ConditionalCachePolicy**: Smart cache based on file size/type
|
||||
- **RegexFallbackStrategy**: Pattern-based routing
|
||||
|
||||
## License
|
||||
|
||||
See the main repository for license information.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! The interface-driven design makes it easy to add new providers and policies without modifying existing code.
|
||||
99
pkg/server/staticweb/config.go
Normal file
99
pkg/server/staticweb/config.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package staticweb
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb/policies"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb/providers"
|
||||
)
|
||||
|
||||
// ServiceConfig configures the static file service.
|
||||
type ServiceConfig struct {
|
||||
// DefaultCacheTime is the default cache duration in seconds.
|
||||
// Used when a mount point doesn't specify a custom CachePolicy.
|
||||
// Default: 172800 (48 hours)
|
||||
DefaultCacheTime int
|
||||
|
||||
// DefaultMIMETypes is a map of file extensions to MIME types.
|
||||
// These are added to the default MIME resolver.
|
||||
// Extensions should include the leading dot (e.g., ".webp").
|
||||
DefaultMIMETypes map[string]string
|
||||
}
|
||||
|
||||
// DefaultServiceConfig returns a ServiceConfig with sensible defaults.
|
||||
func DefaultServiceConfig() *ServiceConfig {
|
||||
return &ServiceConfig{
|
||||
DefaultCacheTime: 172800, // 48 hours
|
||||
DefaultMIMETypes: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the ServiceConfig is valid.
|
||||
func (c *ServiceConfig) Validate() error {
|
||||
if c.DefaultCacheTime < 0 {
|
||||
return fmt.Errorf("DefaultCacheTime cannot be negative")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper constructor functions for providers
|
||||
|
||||
// LocalProvider creates a FileSystemProvider for a local directory.
|
||||
func LocalProvider(path string) (FileSystemProvider, error) {
|
||||
return providers.NewLocalFSProvider(path)
|
||||
}
|
||||
|
||||
// ZipProvider creates a FileSystemProvider for a zip file.
|
||||
func ZipProvider(zipPath string) (FileSystemProvider, error) {
|
||||
return providers.NewZipFSProvider(zipPath)
|
||||
}
|
||||
|
||||
// EmbedProvider creates a FileSystemProvider for an embedded filesystem.
|
||||
// If zipFile is empty, the embedded FS is used directly.
|
||||
// If zipFile is specified, it's treated as a path to a zip file within the embedded FS.
|
||||
// The embedFS parameter can be any fs.FS, but is typically *embed.FS.
|
||||
func EmbedProvider(embedFS *embed.FS, zipFile string) (FileSystemProvider, error) {
|
||||
return providers.NewEmbedFSProvider(embedFS, zipFile)
|
||||
}
|
||||
|
||||
// Policy constructor functions
|
||||
|
||||
// SimpleCache creates a simple cache policy with the given TTL in seconds.
|
||||
func SimpleCache(seconds int) CachePolicy {
|
||||
return policies.NewSimpleCachePolicy(seconds)
|
||||
}
|
||||
|
||||
// ExtensionCache creates an extension-based cache policy.
|
||||
// rules maps file extensions (with leading dot) to cache times in seconds.
|
||||
// defaultTime is used for files that don't match any rule.
|
||||
func ExtensionCache(rules map[string]int, defaultTime int) CachePolicy {
|
||||
return policies.NewExtensionBasedCachePolicy(rules, defaultTime)
|
||||
}
|
||||
|
||||
// NoCache creates a cache policy that disables all caching.
|
||||
func NoCache() CachePolicy {
|
||||
return policies.NewNoCachePolicy()
|
||||
}
|
||||
|
||||
// HTMLFallback creates a fallback strategy for SPAs that serves the given index file.
|
||||
func HTMLFallback(indexFile string) FallbackStrategy {
|
||||
return policies.NewHTMLFallbackStrategy(indexFile)
|
||||
}
|
||||
|
||||
// ExtensionFallback creates an extension-based fallback strategy.
|
||||
// staticExtensions is a list of file extensions that should NOT use fallback.
|
||||
// fallbackPath is the file to serve when fallback is triggered.
|
||||
func ExtensionFallback(staticExtensions []string, fallbackPath string) FallbackStrategy {
|
||||
return policies.NewExtensionBasedFallback(staticExtensions, fallbackPath)
|
||||
}
|
||||
|
||||
// DefaultExtensionFallback creates an extension-based fallback with common web asset extensions.
|
||||
func DefaultExtensionFallback(fallbackPath string) FallbackStrategy {
|
||||
return policies.NewDefaultExtensionBasedFallback(fallbackPath)
|
||||
}
|
||||
|
||||
// DefaultMIMEResolver creates a MIME resolver with common web file types.
|
||||
func DefaultMIMEResolver() MIMETypeResolver {
|
||||
return policies.NewDefaultMIMEResolver()
|
||||
}
|
||||
60
pkg/server/staticweb/example_reload_test.go
Normal file
60
pkg/server/staticweb/example_reload_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package staticweb_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb"
|
||||
staticwebtesting "github.com/bitechdev/ResolveSpec/pkg/server/staticweb/testing"
|
||||
)
|
||||
|
||||
// Example_reload demonstrates reloading content when files change.
|
||||
func Example_reload() {
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// Create a provider
|
||||
provider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"version.txt": []byte("v1.0.0"),
|
||||
})
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: provider,
|
||||
})
|
||||
|
||||
// Simulate updating the file
|
||||
provider.AddFile("version.txt", []byte("v2.0.0"))
|
||||
|
||||
// Reload to pick up changes (in real usage with zip files)
|
||||
err := service.Reload()
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to reload: %v\n", err)
|
||||
} else {
|
||||
fmt.Println("Successfully reloaded static files")
|
||||
}
|
||||
|
||||
// Output: Successfully reloaded static files
|
||||
}
|
||||
|
||||
// Example_reloadZip demonstrates reloading a zip file provider.
|
||||
func Example_reloadZip() {
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// In production, you would use:
|
||||
// provider, _ := staticweb.ZipProvider("./dist.zip")
|
||||
// For this example, we use a mock
|
||||
provider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"app.js": []byte("console.log('v1')"),
|
||||
})
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/app",
|
||||
Provider: provider,
|
||||
})
|
||||
|
||||
fmt.Println("Serving from zip file")
|
||||
|
||||
// When the zip file is updated, call Reload()
|
||||
// service.Reload()
|
||||
|
||||
// Output: Serving from zip file
|
||||
}
|
||||
138
pkg/server/staticweb/example_test.go
Normal file
138
pkg/server/staticweb/example_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package staticweb_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server/staticweb"
|
||||
staticwebtesting "github.com/bitechdev/ResolveSpec/pkg/server/staticweb/testing"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// Example_basic demonstrates serving files from a local directory.
|
||||
func Example_basic() {
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// Using mock provider for example purposes
|
||||
provider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"index.html": []byte("<html>test</html>"),
|
||||
})
|
||||
|
||||
_ = service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: provider,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.PathPrefix("/").Handler(service.Handler())
|
||||
|
||||
fmt.Println("Serving files from ./public at /static")
|
||||
// Output: Serving files from ./public at /static
|
||||
}
|
||||
|
||||
// Example_spa demonstrates an SPA with HTML fallback routing.
|
||||
func Example_spa() {
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// Using mock provider for example purposes
|
||||
provider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"index.html": []byte("<html>app</html>"),
|
||||
})
|
||||
|
||||
_ = service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: provider,
|
||||
FallbackStrategy: staticweb.HTMLFallback("index.html"),
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// API routes take precedence
|
||||
router.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("users"))
|
||||
})
|
||||
|
||||
// Static files handle all other routes
|
||||
router.PathPrefix("/").Handler(service.Handler())
|
||||
|
||||
fmt.Println("SPA with fallback to index.html")
|
||||
// Output: SPA with fallback to index.html
|
||||
}
|
||||
|
||||
// Example_multiple demonstrates multiple mount points with different policies.
|
||||
func Example_multiple() {
|
||||
service := staticweb.NewService(&staticweb.ServiceConfig{
|
||||
DefaultCacheTime: 3600,
|
||||
})
|
||||
|
||||
// Assets with long cache (using mock for example)
|
||||
assetsProvider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"app.js": []byte("console.log('test')"),
|
||||
})
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/assets",
|
||||
Provider: assetsProvider,
|
||||
CachePolicy: staticweb.SimpleCache(604800), // 1 week
|
||||
})
|
||||
|
||||
// HTML with short cache (using mock for example)
|
||||
htmlProvider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"index.html": []byte("<html>test</html>"),
|
||||
})
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: htmlProvider,
|
||||
CachePolicy: staticweb.SimpleCache(300), // 5 minutes
|
||||
})
|
||||
|
||||
fmt.Println("Multiple mount points configured")
|
||||
// Output: Multiple mount points configured
|
||||
}
|
||||
|
||||
// Example_zip demonstrates serving from a zip file (concept).
|
||||
func Example_zip() {
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// For actual usage, you would use:
|
||||
// provider, err := staticweb.ZipProvider("./static.zip")
|
||||
// For this example, we use a mock
|
||||
provider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"file.txt": []byte("content"),
|
||||
})
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: provider,
|
||||
})
|
||||
|
||||
fmt.Println("Serving from zip file")
|
||||
// Output: Serving from zip file
|
||||
}
|
||||
|
||||
// Example_extensionCache demonstrates extension-based caching.
|
||||
func Example_extensionCache() {
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
// Using mock provider for example purposes
|
||||
provider := staticwebtesting.NewMockProvider(map[string][]byte{
|
||||
"index.html": []byte("<html>test</html>"),
|
||||
"app.js": []byte("console.log('test')"),
|
||||
})
|
||||
|
||||
// Different cache times per file type
|
||||
cacheRules := map[string]int{
|
||||
".html": 3600, // 1 hour
|
||||
".js": 86400, // 1 day
|
||||
".css": 86400, // 1 day
|
||||
".png": 604800, // 1 week
|
||||
}
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: provider,
|
||||
CachePolicy: staticweb.ExtensionCache(cacheRules, 3600), // default 1 hour
|
||||
})
|
||||
|
||||
fmt.Println("Extension-based caching configured")
|
||||
// Output: Extension-based caching configured
|
||||
}
|
||||
130
pkg/server/staticweb/interfaces.go
Normal file
130
pkg/server/staticweb/interfaces.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package staticweb
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// FileSystemProvider abstracts the source of files (local, zip, embedded, future: http, s3)
|
||||
// Implementations must be safe for concurrent use.
|
||||
type FileSystemProvider interface {
|
||||
// Open opens the named file.
|
||||
// The name is always a slash-separated path relative to the filesystem root.
|
||||
Open(name string) (fs.File, error)
|
||||
|
||||
// Close releases any resources held by the provider.
|
||||
// After Close is called, the provider should not be used.
|
||||
Close() error
|
||||
|
||||
// Type returns the provider type (e.g., "local", "zip", "embed", "http", "s3").
|
||||
// This is primarily for debugging and logging purposes.
|
||||
Type() string
|
||||
}
|
||||
|
||||
// ReloadableProvider is an optional interface that providers can implement
|
||||
// to support reloading/refreshing their content.
|
||||
// This is useful for development workflows where the underlying files may change.
|
||||
type ReloadableProvider interface {
|
||||
FileSystemProvider
|
||||
|
||||
// Reload refreshes the provider's content from the underlying source.
|
||||
// For zip files, this reopens the zip archive.
|
||||
// For local directories, this refreshes the filesystem view.
|
||||
// Returns an error if the reload fails.
|
||||
Reload() error
|
||||
}
|
||||
|
||||
// CachePolicy defines how files should be cached by browsers and proxies.
|
||||
// Implementations must be safe for concurrent use.
|
||||
type CachePolicy interface {
|
||||
// GetCacheTime returns the cache duration in seconds for the given path.
|
||||
// A value of 0 means no caching.
|
||||
// A negative value can be used to indicate browser should revalidate.
|
||||
GetCacheTime(path string) int
|
||||
|
||||
// GetCacheHeaders returns additional cache-related HTTP headers for the given path.
|
||||
// Common headers include "Cache-Control", "Expires", "ETag", etc.
|
||||
// Returns nil if no additional headers are needed.
|
||||
GetCacheHeaders(path string) map[string]string
|
||||
}
|
||||
|
||||
// MIMETypeResolver determines the Content-Type for files.
|
||||
// Implementations must be safe for concurrent use.
|
||||
type MIMETypeResolver interface {
|
||||
// GetMIMEType returns the MIME type for the given file path.
|
||||
// Returns empty string if the MIME type cannot be determined.
|
||||
GetMIMEType(path string) string
|
||||
|
||||
// RegisterMIMEType registers a custom MIME type for the given file extension.
|
||||
// The extension should include the leading dot (e.g., ".webp").
|
||||
RegisterMIMEType(extension, mimeType string)
|
||||
}
|
||||
|
||||
// FallbackStrategy handles requests for files that don't exist.
|
||||
// This is commonly used for Single Page Applications (SPAs) that use client-side routing.
|
||||
// Implementations must be safe for concurrent use.
|
||||
type FallbackStrategy interface {
|
||||
// ShouldFallback determines if a fallback should be attempted for the given path.
|
||||
// Returns true if the request should be handled by fallback logic.
|
||||
ShouldFallback(path string) bool
|
||||
|
||||
// GetFallbackPath returns the path to serve instead of the originally requested path.
|
||||
// This is only called if ShouldFallback returns true.
|
||||
GetFallbackPath(path string) string
|
||||
}
|
||||
|
||||
// MountConfig configures a single mount point.
|
||||
// A mount point connects a URL prefix to a filesystem provider with optional policies.
|
||||
type MountConfig struct {
|
||||
// URLPrefix is the URL path prefix where the filesystem should be mounted.
|
||||
// Must start with "/" (e.g., "/static", "/", "/assets").
|
||||
// Requests starting with this prefix will be handled by this mount point.
|
||||
URLPrefix string
|
||||
|
||||
// Provider is the filesystem provider that supplies the files.
|
||||
// Required.
|
||||
Provider FileSystemProvider
|
||||
|
||||
// CachePolicy determines how files should be cached.
|
||||
// If nil, the service's default cache policy is used.
|
||||
CachePolicy CachePolicy
|
||||
|
||||
// MIMEResolver determines Content-Type headers for files.
|
||||
// If nil, the service's default MIME resolver is used.
|
||||
MIMEResolver MIMETypeResolver
|
||||
|
||||
// FallbackStrategy handles requests for missing files.
|
||||
// If nil, no fallback is performed and 404 responses are returned.
|
||||
FallbackStrategy FallbackStrategy
|
||||
}
|
||||
|
||||
// StaticFileService manages multiple mount points and serves static files.
|
||||
// The service is safe for concurrent use.
|
||||
type StaticFileService interface {
|
||||
// Mount adds a new mount point with the given configuration.
|
||||
// Returns an error if the URLPrefix is already mounted or if the config is invalid.
|
||||
Mount(config MountConfig) error
|
||||
|
||||
// Unmount removes the mount point at the given URL prefix.
|
||||
// Returns an error if no mount point exists at that prefix.
|
||||
// Automatically calls Close() on the provider to release resources.
|
||||
Unmount(urlPrefix string) error
|
||||
|
||||
// ListMounts returns a sorted list of all mounted URL prefixes.
|
||||
ListMounts() []string
|
||||
|
||||
// Reload reinitializes all filesystem providers.
|
||||
// This can be used to pick up changes in the underlying filesystems.
|
||||
// Not all providers may support reloading.
|
||||
Reload() error
|
||||
|
||||
// Close releases all resources held by the service and all mounted providers.
|
||||
// After Close is called, the service should not be used.
|
||||
Close() error
|
||||
|
||||
// Handler returns an http.Handler that serves static files from all mount points.
|
||||
// The handler performs longest-prefix matching to find the appropriate mount point.
|
||||
// If no mount point matches, the handler returns without writing a response,
|
||||
// allowing other handlers (like API routes) to process the request.
|
||||
Handler() http.Handler
|
||||
}
|
||||
235
pkg/server/staticweb/mount.go
Normal file
235
pkg/server/staticweb/mount.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package staticweb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// mountPoint represents a mounted filesystem at a specific URL prefix.
|
||||
type mountPoint struct {
|
||||
urlPrefix string
|
||||
provider FileSystemProvider
|
||||
cachePolicy CachePolicy
|
||||
mimeResolver MIMETypeResolver
|
||||
fallbackStrategy FallbackStrategy
|
||||
fileServer http.Handler
|
||||
}
|
||||
|
||||
// newMountPoint creates a new mount point with the given configuration.
|
||||
func newMountPoint(config MountConfig, defaults *ServiceConfig) (*mountPoint, error) {
|
||||
if config.URLPrefix == "" {
|
||||
return nil, fmt.Errorf("URLPrefix cannot be empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(config.URLPrefix, "/") {
|
||||
return nil, fmt.Errorf("URLPrefix must start with /")
|
||||
}
|
||||
|
||||
if config.Provider == nil {
|
||||
return nil, fmt.Errorf("provider cannot be nil")
|
||||
}
|
||||
|
||||
mp := &mountPoint{
|
||||
urlPrefix: config.URLPrefix,
|
||||
provider: config.Provider,
|
||||
cachePolicy: config.CachePolicy,
|
||||
mimeResolver: config.MIMEResolver,
|
||||
fallbackStrategy: config.FallbackStrategy,
|
||||
}
|
||||
|
||||
// Apply defaults if policies are not specified
|
||||
if mp.cachePolicy == nil && defaults != nil {
|
||||
mp.cachePolicy = defaultCachePolicy(defaults.DefaultCacheTime)
|
||||
}
|
||||
|
||||
if mp.mimeResolver == nil {
|
||||
mp.mimeResolver = defaultMIMEResolver()
|
||||
}
|
||||
|
||||
// Create an http.FileServer for serving files
|
||||
mp.fileServer = http.FileServer(http.FS(config.Provider))
|
||||
|
||||
return mp, nil
|
||||
}
|
||||
|
||||
// ServeHTTP handles HTTP requests for files in this mount point.
|
||||
func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the URL prefix to get the file path
|
||||
filePath := strings.TrimPrefix(r.URL.Path, m.urlPrefix)
|
||||
if filePath == "" {
|
||||
filePath = "/"
|
||||
}
|
||||
|
||||
// Clean the path
|
||||
filePath = path.Clean(filePath)
|
||||
|
||||
// Try to open the file
|
||||
file, err := m.provider.Open(strings.TrimPrefix(filePath, "/"))
|
||||
if err != nil {
|
||||
// File doesn't exist - check if we should use fallback
|
||||
if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) {
|
||||
fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath)
|
||||
file, err = m.provider.Open(strings.TrimPrefix(fallbackPath, "/"))
|
||||
if err == nil {
|
||||
// Successfully opened fallback file
|
||||
defer file.Close()
|
||||
m.serveFile(w, r, fallbackPath, file)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// No fallback or fallback failed - return 404
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Serve the file
|
||||
m.serveFile(w, r, filePath, file)
|
||||
}
|
||||
|
||||
// serveFile serves a single file with appropriate headers.
|
||||
func (m *mountPoint) serveFile(w http.ResponseWriter, r *http.Request, filePath string, file fs.File) {
|
||||
// Get file info
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// If it's a directory, try to serve index.html
|
||||
if stat.IsDir() {
|
||||
indexPath := path.Join(filePath, "index.html")
|
||||
indexFile, err := m.provider.Open(strings.TrimPrefix(indexPath, "/"))
|
||||
if err != nil {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
defer indexFile.Close()
|
||||
|
||||
indexStat, err := indexFile.Stat()
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
filePath = indexPath
|
||||
stat = indexStat
|
||||
file = indexFile
|
||||
}
|
||||
|
||||
// Set Content-Type header using MIME resolver
|
||||
if m.mimeResolver != nil {
|
||||
if mimeType := m.mimeResolver.GetMIMEType(filePath); mimeType != "" {
|
||||
w.Header().Set("Content-Type", mimeType)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply cache policy
|
||||
if m.cachePolicy != nil {
|
||||
headers := m.cachePolicy.GetCacheHeaders(filePath)
|
||||
for key, value := range headers {
|
||||
w.Header().Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve the content
|
||||
if seeker, ok := file.(interface {
|
||||
io.ReadSeeker
|
||||
}); ok {
|
||||
http.ServeContent(w, r, stat.Name(), stat.ModTime(), seeker)
|
||||
} else {
|
||||
// If the file doesn't support seeking, we need to read it all into memory
|
||||
data, err := fs.ReadFile(m.provider, strings.TrimPrefix(filePath, "/"))
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.ServeContent(w, r, stat.Name(), stat.ModTime(), strings.NewReader(string(data)))
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases resources held by the mount point.
|
||||
func (m *mountPoint) Close() error {
|
||||
if m.provider != nil {
|
||||
return m.provider.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// defaultCachePolicy creates a default simple cache policy.
|
||||
func defaultCachePolicy(cacheTime int) CachePolicy {
|
||||
// Import the policies package type - we'll need to use the concrete type
|
||||
// For now, create a simple inline implementation
|
||||
return &simpleCachePolicy{cacheTime: cacheTime}
|
||||
}
|
||||
|
||||
// simpleCachePolicy is a simple inline implementation of CachePolicy
|
||||
type simpleCachePolicy struct {
|
||||
cacheTime int
|
||||
}
|
||||
|
||||
func (p *simpleCachePolicy) GetCacheTime(path string) int {
|
||||
return p.cacheTime
|
||||
}
|
||||
|
||||
func (p *simpleCachePolicy) GetCacheHeaders(path string) map[string]string {
|
||||
if p.cacheTime <= 0 {
|
||||
return map[string]string{
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0",
|
||||
}
|
||||
}
|
||||
return map[string]string{
|
||||
"Cache-Control": fmt.Sprintf("public, max-age=%d", p.cacheTime),
|
||||
}
|
||||
}
|
||||
|
||||
// defaultMIMEResolver creates a default MIME resolver.
|
||||
func defaultMIMEResolver() MIMETypeResolver {
|
||||
// Import the policies package type - we'll need to use the concrete type
|
||||
// For now, create a simple inline implementation
|
||||
return &simpleMIMEResolver{
|
||||
types: map[string]string{
|
||||
".js": "application/javascript",
|
||||
".mjs": "application/javascript",
|
||||
".cjs": "application/javascript",
|
||||
".css": "text/css",
|
||||
".html": "text/html",
|
||||
".htm": "text/html",
|
||||
".json": "application/json",
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".svg": "image/svg+xml",
|
||||
".ico": "image/x-icon",
|
||||
".txt": "text/plain",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// simpleMIMEResolver is a simple inline implementation of MIMETypeResolver
|
||||
type simpleMIMEResolver struct {
|
||||
types map[string]string
|
||||
}
|
||||
|
||||
func (r *simpleMIMEResolver) GetMIMEType(filePath string) string {
|
||||
ext := strings.ToLower(path.Ext(filePath))
|
||||
if mimeType, ok := r.types[ext]; ok {
|
||||
return mimeType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *simpleMIMEResolver) RegisterMIMEType(extension, mimeType string) {
|
||||
if !strings.HasPrefix(extension, ".") {
|
||||
extension = "." + extension
|
||||
}
|
||||
r.types[strings.ToLower(extension)] = mimeType
|
||||
}
|
||||
592
pkg/server/staticweb/plan.md
Normal file
592
pkg/server/staticweb/plan.md
Normal file
@@ -0,0 +1,592 @@
|
||||
# StaticWeb Package Interface-Driven Refactoring Plan
|
||||
|
||||
## Overview
|
||||
Refactor `pkg/server/staticweb` to be interface-driven, router-agnostic, and maintainable. This is a breaking change that replaces the existing API with a cleaner design.
|
||||
|
||||
## User Requirements
|
||||
- ✅ Work with any server (not just Gorilla mux)
|
||||
- ✅ Serve static files from zip files or directories
|
||||
- ✅ Support embedded, local filesystems
|
||||
- ✅ Interface-driven and maintainable architecture
|
||||
- ✅ Struct-based configuration
|
||||
- ✅ Breaking changes acceptable
|
||||
- 🔮 Remote HTTP/HTTPS and S3 support (future feature)
|
||||
|
||||
## Design Principles
|
||||
1. **Interface-first**: Define clear interfaces for all abstractions
|
||||
2. **Composition over inheritance**: Combine small, focused components
|
||||
3. **Router-agnostic**: Return standard `http.Handler` for universal compatibility
|
||||
4. **Configurable policies**: Extract hardcoded behavior into pluggable strategies
|
||||
5. **Resource safety**: Proper lifecycle management with `Close()` methods
|
||||
6. **Testability**: Mock-friendly interfaces with clear boundaries
|
||||
7. **Extensibility**: Easy to add new providers (HTTP, S3, etc.) in future
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Core Interfaces (pkg/server/staticweb/interfaces.go)
|
||||
|
||||
```go
|
||||
// FileSystemProvider abstracts file sources (local, zip, embedded, future: http, s3)
|
||||
type FileSystemProvider interface {
|
||||
Open(name string) (fs.File, error)
|
||||
Close() error
|
||||
Type() string
|
||||
}
|
||||
|
||||
// CachePolicy defines caching behavior
|
||||
type CachePolicy interface {
|
||||
GetCacheTime(path string) int
|
||||
GetCacheHeaders(path string) map[string]string
|
||||
}
|
||||
|
||||
// MIMETypeResolver determines content types
|
||||
type MIMETypeResolver interface {
|
||||
GetMIMEType(path string) string
|
||||
RegisterMIMEType(extension, mimeType string)
|
||||
}
|
||||
|
||||
// FallbackStrategy handles missing files
|
||||
type FallbackStrategy interface {
|
||||
ShouldFallback(path string) bool
|
||||
GetFallbackPath(path string) string
|
||||
}
|
||||
|
||||
// StaticFileService manages mount points
|
||||
type StaticFileService interface {
|
||||
Mount(config MountConfig) error
|
||||
Unmount(urlPrefix string) error
|
||||
ListMounts() []string
|
||||
Reload() error
|
||||
Close() error
|
||||
Handler() http.Handler // Router-agnostic integration
|
||||
}
|
||||
```
|
||||
|
||||
### Configuration (pkg/server/staticweb/config.go)
|
||||
|
||||
```go
|
||||
// MountConfig configures a single mount point
|
||||
type MountConfig struct {
|
||||
URLPrefix string
|
||||
Provider FileSystemProvider
|
||||
CachePolicy CachePolicy // Optional, uses default if nil
|
||||
MIMEResolver MIMETypeResolver // Optional, uses default if nil
|
||||
FallbackStrategy FallbackStrategy // Optional, no fallback if nil
|
||||
}
|
||||
|
||||
// ServiceConfig configures the service
|
||||
type ServiceConfig struct {
|
||||
DefaultCacheTime int // Default: 48 hours
|
||||
DefaultMIMETypes map[string]string // Additional MIME types
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Step 1: Create Core Interfaces
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/interfaces.go` (NEW)
|
||||
|
||||
Define all interfaces listed above. This establishes the contract for all components.
|
||||
|
||||
### Step 2: Implement Default Policies
|
||||
**Directory**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/policies/` (NEW)
|
||||
|
||||
**File**: `policies/cache.go`
|
||||
- `SimpleCachePolicy` - Single TTL for all files
|
||||
- `ExtensionBasedCachePolicy` - Different TTL per file extension
|
||||
- `NoCachePolicy` - Disables caching
|
||||
|
||||
**File**: `policies/mime.go`
|
||||
- `DefaultMIMEResolver` - Standard web MIME types + stdlib
|
||||
- `ConfigurableMIMEResolver` - User-defined mappings
|
||||
- Migrate hardcoded MIME types from `InitMimeTypes()` (lines 60-76)
|
||||
|
||||
**File**: `policies/fallback.go`
|
||||
- `NoFallback` - Returns 404 for missing files
|
||||
- `HTMLFallbackStrategy` - SPA routing (serves index.html)
|
||||
- `ExtensionBasedFallback` - Current behavior (checks extensions)
|
||||
- Migrate logic from `StaticHTMLFallbackHandler()` (lines 241-285)
|
||||
|
||||
### Step 3: Implement FileSystem Providers
|
||||
**Directory**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/providers/` (NEW)
|
||||
|
||||
**File**: `providers/local.go`
|
||||
```go
|
||||
type LocalFSProvider struct {
|
||||
path string
|
||||
fs fs.FS
|
||||
}
|
||||
|
||||
func NewLocalFSProvider(path string) (*LocalFSProvider, error)
|
||||
func (l *LocalFSProvider) Open(name string) (fs.File, error)
|
||||
func (l *LocalFSProvider) Close() error
|
||||
func (l *LocalFSProvider) Type() string
|
||||
```
|
||||
|
||||
**File**: `providers/zip.go`
|
||||
```go
|
||||
type ZipFSProvider struct {
|
||||
zipPath string
|
||||
zipReader *zip.ReadCloser
|
||||
zipFS *zipfs.ZipFS
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewZipFSProvider(zipPath string) (*ZipFSProvider, error)
|
||||
func (z *ZipFSProvider) Open(name string) (fs.File, error)
|
||||
func (z *ZipFSProvider) Close() error
|
||||
func (z *ZipFSProvider) Type() string
|
||||
```
|
||||
- Integrates with existing `pkg/server/zipfs/zipfs.go`
|
||||
- Manages zip file lifecycle properly
|
||||
|
||||
**File**: `providers/embed.go`
|
||||
```go
|
||||
type EmbedFSProvider struct {
|
||||
embedFS *embed.FS
|
||||
zipFile string // Optional: path within embedded FS to zip file
|
||||
zipReader *zip.ReadCloser
|
||||
fs fs.FS
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewEmbedFSProvider(embedFS *embed.FS, zipFile string) (*EmbedFSProvider, error)
|
||||
func (e *EmbedFSProvider) Open(name string) (fs.File, error)
|
||||
func (e *EmbedFSProvider) Close() error
|
||||
func (e *EmbedFSProvider) Type() string
|
||||
```
|
||||
|
||||
### Step 4: Implement Mount Point
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/mount.go` (NEW)
|
||||
|
||||
```go
|
||||
type mountPoint struct {
|
||||
urlPrefix string
|
||||
provider FileSystemProvider
|
||||
cachePolicy CachePolicy
|
||||
mimeResolver MIMETypeResolver
|
||||
fallbackStrategy FallbackStrategy
|
||||
}
|
||||
|
||||
func newMountPoint(config MountConfig, defaults *ServiceConfig) (*mountPoint, error)
|
||||
func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
func (m *mountPoint) Close() error
|
||||
```
|
||||
|
||||
**Key behaviors**:
|
||||
- Strips URL prefix before passing to provider
|
||||
- Applies cache headers via `CachePolicy`
|
||||
- Sets Content-Type via `MIMETypeResolver`
|
||||
- Falls back via `FallbackStrategy` if file not found
|
||||
- Integrates with `http.FileServer()` for actual serving
|
||||
|
||||
### Step 5: Implement Service
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/service.go` (NEW)
|
||||
|
||||
```go
|
||||
type service struct {
|
||||
mounts map[string]*mountPoint // urlPrefix -> mount
|
||||
config *ServiceConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewService(config *ServiceConfig) StaticFileService
|
||||
func (s *service) Mount(config MountConfig) error
|
||||
func (s *service) Unmount(urlPrefix string) error
|
||||
func (s *service) ListMounts() []string
|
||||
func (s *service) Reload() error
|
||||
func (s *service) Close() error
|
||||
func (s *service) Handler() http.Handler
|
||||
```
|
||||
|
||||
**Handler Implementation**:
|
||||
- Performs longest-prefix matching to find mount point
|
||||
- Delegates to mount point's `ServeHTTP()`
|
||||
- Returns silently if no match (allows API routes to handle)
|
||||
- Thread-safe with `sync.RWMutex`
|
||||
|
||||
### Step 6: Create Configuration Helpers
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/config.go` (NEW)
|
||||
|
||||
```go
|
||||
// Default configurations
|
||||
func DefaultServiceConfig() *ServiceConfig
|
||||
func DefaultCachePolicy() CachePolicy
|
||||
func DefaultMIMEResolver() MIMETypeResolver
|
||||
|
||||
// Helper constructors
|
||||
func LocalProvider(path string) FileSystemProvider
|
||||
func ZipProvider(zipPath string) FileSystemProvider
|
||||
func EmbedProvider(embedFS *embed.FS, zipFile string) FileSystemProvider
|
||||
|
||||
// Policy constructors
|
||||
func SimpleCache(seconds int) CachePolicy
|
||||
func ExtensionCache(rules map[string]int) CachePolicy
|
||||
func HTMLFallback(indexFile string) FallbackStrategy
|
||||
func ExtensionFallback(staticExtensions []string) FallbackStrategy
|
||||
```
|
||||
|
||||
### Step 7: Update/Remove Existing Files
|
||||
|
||||
**REMOVE**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/staticweb.go`
|
||||
- All functionality migrated to new interface-based design
|
||||
- No backward compatibility needed per user request
|
||||
|
||||
**KEEP**: `/home/hein/hein/dev/ResolveSpec/pkg/server/zipfs/zipfs.go`
|
||||
- Still used by `ZipFSProvider`
|
||||
- Already implements `fs.FS` interface correctly
|
||||
|
||||
### Step 8: Create Examples and Tests
|
||||
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/example_test.go` (NEW)
|
||||
|
||||
```go
|
||||
func ExampleService_basic() { /* Serve local directory */ }
|
||||
func ExampleService_spa() { /* SPA with fallback */ }
|
||||
func ExampleService_multiple() { /* Multiple mount points */ }
|
||||
func ExampleService_zip() { /* Serve from zip file */ }
|
||||
func ExampleService_embedded() { /* Serve from embedded zip */ }
|
||||
```
|
||||
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/service_test.go` (NEW)
|
||||
- Test mount/unmount operations
|
||||
- Test longest-prefix matching
|
||||
- Test concurrent access
|
||||
- Test resource cleanup
|
||||
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/providers/providers_test.go` (NEW)
|
||||
- Test each provider implementation
|
||||
- Test resource cleanup
|
||||
- Test error handling
|
||||
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/testing/mocks.go` (NEW)
|
||||
- `MockFileSystemProvider` - In-memory file storage
|
||||
- `MockCachePolicy` - Configurable cache behavior
|
||||
- `MockMIMEResolver` - Custom MIME mappings
|
||||
- Test helpers for common scenarios
|
||||
|
||||
### Step 9: Create Documentation
|
||||
|
||||
**File**: `/home/hein/hein/dev/ResolveSpec/pkg/server/staticweb/README.md` (NEW)
|
||||
|
||||
Document:
|
||||
- Quick start examples
|
||||
- Interface overview
|
||||
- Provider implementations
|
||||
- Policy customization
|
||||
- Router integration patterns
|
||||
- Migration guide from old API
|
||||
- Future features roadmap
|
||||
|
||||
## Key Improvements Over Current Implementation
|
||||
|
||||
### 1. Router-Agnostic Design
|
||||
**Before**: Coupled to Gorilla mux via `RegisterRoutes(*mux.Router)`
|
||||
**After**: Returns `http.Handler`, works with any router
|
||||
|
||||
```go
|
||||
// Works with Gorilla Mux
|
||||
muxRouter.PathPrefix("/").Handler(service.Handler())
|
||||
|
||||
// Works with standard http.ServeMux
|
||||
http.Handle("/", service.Handler())
|
||||
|
||||
// Works with any http.Handler-compatible router
|
||||
```
|
||||
|
||||
### 2. Configurable Behaviors
|
||||
**Before**: Hardcoded MIME types, cache times, file extensions
|
||||
**After**: Pluggable policies
|
||||
|
||||
```go
|
||||
// Custom cache per file type
|
||||
cachePolicy := ExtensionCache(map[string]int{
|
||||
".html": 3600, // 1 hour
|
||||
".js": 86400, // 1 day
|
||||
".css": 86400, // 1 day
|
||||
".png": 604800, // 1 week
|
||||
})
|
||||
|
||||
// Custom fallback logic
|
||||
fallback := HTMLFallback("index.html")
|
||||
|
||||
service.Mount(MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: LocalProvider("./dist"),
|
||||
CachePolicy: cachePolicy,
|
||||
FallbackStrategy: fallback,
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Better Resource Management
|
||||
**Before**: Manual zip file cleanup, easy to leak resources
|
||||
**After**: Proper lifecycle with `Close()` on all components
|
||||
|
||||
```go
|
||||
defer service.Close() // Cleans up all providers
|
||||
```
|
||||
|
||||
### 4. Testability
|
||||
**Before**: Hard to test, coupled to filesystem
|
||||
**After**: Mock providers for testing
|
||||
|
||||
```go
|
||||
mockProvider := testing.NewInMemoryProvider(map[string][]byte{
|
||||
"index.html": []byte("<html>test</html>"),
|
||||
})
|
||||
|
||||
service.Mount(MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: mockProvider,
|
||||
})
|
||||
```
|
||||
|
||||
### 5. Extensibility
|
||||
**Before**: Need to modify code to support new file sources
|
||||
**After**: Implement `FileSystemProvider` interface
|
||||
|
||||
```go
|
||||
// Future: Add HTTP provider without changing core code
|
||||
type HTTPFSProvider struct { /* ... */ }
|
||||
func (h *HTTPFSProvider) Open(name string) (fs.File, error) { /* ... */ }
|
||||
func (h *HTTPFSProvider) Close() error { /* ... */ }
|
||||
func (h *HTTPFSProvider) Type() string { return "http" }
|
||||
```
|
||||
|
||||
## Future Features (To Implement Later)
|
||||
|
||||
### Remote HTTP/HTTPS Provider
|
||||
**File**: `providers/http.go` (FUTURE)
|
||||
|
||||
Serve static files from remote HTTP servers with local caching:
|
||||
|
||||
```go
|
||||
type HTTPFSProvider struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
cache LocalCache // Optional disk/memory cache
|
||||
cacheTTL time.Duration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Example usage
|
||||
service.Mount(MountConfig{
|
||||
URLPrefix: "/cdn",
|
||||
Provider: HTTPProvider("https://cdn.example.com/assets"),
|
||||
})
|
||||
```
|
||||
|
||||
**Features**:
|
||||
- Fetch files from remote URLs on-demand
|
||||
- Local cache to reduce remote requests
|
||||
- Configurable TTL and cache eviction
|
||||
- HEAD request support for metadata
|
||||
- Retry logic and timeout handling
|
||||
- Support for authentication headers
|
||||
|
||||
### S3-Compatible Provider
|
||||
**File**: `providers/s3.go` (FUTURE)
|
||||
|
||||
Serve static files from S3, MinIO, or S3-compatible storage:
|
||||
|
||||
```go
|
||||
type S3FSProvider struct {
|
||||
bucket string
|
||||
prefix string
|
||||
region string
|
||||
client *s3.Client
|
||||
cache LocalCache
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Example usage
|
||||
service.Mount(MountConfig{
|
||||
URLPrefix: "/media",
|
||||
Provider: S3Provider("my-bucket", "static/", "us-east-1"),
|
||||
})
|
||||
```
|
||||
|
||||
**Features**:
|
||||
- List and fetch objects from S3 buckets
|
||||
- Support for AWS S3, MinIO, DigitalOcean Spaces, etc.
|
||||
- IAM role or credential-based authentication
|
||||
- Optional local caching layer
|
||||
- Efficient metadata retrieval
|
||||
- Support for presigned URLs
|
||||
|
||||
### Other Future Providers
|
||||
- **GitProvider**: Serve files from Git repositories
|
||||
- **MemoryProvider**: In-memory filesystem for testing/temporary files
|
||||
- **ProxyProvider**: Proxy to another static file server
|
||||
- **CompositeProvider**: Fallback chain across multiple providers
|
||||
|
||||
### Advanced Cache Policies (FUTURE)
|
||||
- **TimedCachePolicy**: Different cache times by time of day
|
||||
- **UserAgentCachePolicy**: Cache based on client type
|
||||
- **ConditionalCachePolicy**: Smart cache based on file size/type
|
||||
- **DistributedCachePolicy**: Shared cache across service instances
|
||||
|
||||
### Advanced Fallback Strategies (FUTURE)
|
||||
- **RegexFallbackStrategy**: Pattern-based routing
|
||||
- **I18nFallbackStrategy**: Language-based file resolution
|
||||
- **VersionedFallbackStrategy**: A/B testing support
|
||||
|
||||
## Critical Files Summary
|
||||
|
||||
### Files to CREATE (in order):
|
||||
1. `pkg/server/staticweb/interfaces.go` - Core contracts
|
||||
2. `pkg/server/staticweb/config.go` - Configuration structs and helpers
|
||||
3. `pkg/server/staticweb/policies/cache.go` - Cache policy implementations
|
||||
4. `pkg/server/staticweb/policies/mime.go` - MIME resolver implementations
|
||||
5. `pkg/server/staticweb/policies/fallback.go` - Fallback strategy implementations
|
||||
6. `pkg/server/staticweb/providers/local.go` - Local directory provider
|
||||
7. `pkg/server/staticweb/providers/zip.go` - Zip file provider
|
||||
8. `pkg/server/staticweb/providers/embed.go` - Embedded filesystem provider
|
||||
9. `pkg/server/staticweb/mount.go` - Mount point implementation
|
||||
10. `pkg/server/staticweb/service.go` - Main service implementation
|
||||
11. `pkg/server/staticweb/testing/mocks.go` - Test helpers
|
||||
12. `pkg/server/staticweb/service_test.go` - Service tests
|
||||
13. `pkg/server/staticweb/providers/providers_test.go` - Provider tests
|
||||
14. `pkg/server/staticweb/example_test.go` - Example code
|
||||
15. `pkg/server/staticweb/README.md` - Documentation
|
||||
|
||||
### Files to REMOVE:
|
||||
1. `pkg/server/staticweb/staticweb.go` - Replaced by new design
|
||||
|
||||
### Files to KEEP:
|
||||
1. `pkg/server/zipfs/zipfs.go` - Used by ZipFSProvider
|
||||
|
||||
### Files for FUTURE (not in this refactoring):
|
||||
1. `pkg/server/staticweb/providers/http.go` - HTTP/HTTPS remote provider
|
||||
2. `pkg/server/staticweb/providers/s3.go` - S3-compatible storage provider
|
||||
3. `pkg/server/staticweb/providers/composite.go` - Fallback chain provider
|
||||
|
||||
## Example Usage After Refactoring
|
||||
|
||||
### Basic Static Site
|
||||
```go
|
||||
service := staticweb.NewService(nil) // Use defaults
|
||||
|
||||
err := service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: staticweb.LocalProvider("./public"),
|
||||
})
|
||||
|
||||
muxRouter.PathPrefix("/").Handler(service.Handler())
|
||||
```
|
||||
|
||||
### SPA with API Routes
|
||||
```go
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: staticweb.LocalProvider("./dist"),
|
||||
FallbackStrategy: staticweb.HTMLFallback("index.html"),
|
||||
})
|
||||
|
||||
// API routes take precedence (registered first)
|
||||
muxRouter.HandleFunc("/api/users", usersHandler)
|
||||
muxRouter.HandleFunc("/api/posts", postsHandler)
|
||||
|
||||
// Static files handle all other routes
|
||||
muxRouter.PathPrefix("/").Handler(service.Handler())
|
||||
```
|
||||
|
||||
### Multiple Mount Points with Different Policies
|
||||
```go
|
||||
service := staticweb.NewService(&staticweb.ServiceConfig{
|
||||
DefaultCacheTime: 3600,
|
||||
})
|
||||
|
||||
// Assets with long cache
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/assets",
|
||||
Provider: staticweb.LocalProvider("./assets"),
|
||||
CachePolicy: staticweb.SimpleCache(604800), // 1 week
|
||||
})
|
||||
|
||||
// HTML with short cache
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/",
|
||||
Provider: staticweb.LocalProvider("./public"),
|
||||
CachePolicy: staticweb.SimpleCache(300), // 5 minutes
|
||||
})
|
||||
|
||||
router.PathPrefix("/").Handler(service.Handler())
|
||||
```
|
||||
|
||||
### Embedded Files from Zip
|
||||
```go
|
||||
//go:embed assets.zip
|
||||
var assetsZip embed.FS
|
||||
|
||||
service := staticweb.NewService(nil)
|
||||
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: staticweb.EmbedProvider(&assetsZip, "assets.zip"),
|
||||
})
|
||||
|
||||
router.PathPrefix("/").Handler(service.Handler())
|
||||
```
|
||||
|
||||
### Future: CDN Fallback (when HTTP provider is implemented)
|
||||
```go
|
||||
// Primary CDN with local fallback
|
||||
service.Mount(staticweb.MountConfig{
|
||||
URLPrefix: "/static",
|
||||
Provider: staticweb.CompositeProvider(
|
||||
staticweb.HTTPProvider("https://cdn.example.com/assets"),
|
||||
staticweb.LocalProvider("./public/assets"),
|
||||
),
|
||||
})
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Unit Tests
|
||||
- Each provider implementation independently
|
||||
- Each policy implementation independently
|
||||
- Mount point request handling
|
||||
- Service mount/unmount operations
|
||||
|
||||
### Integration Tests
|
||||
- Full request flow through service
|
||||
- Multiple mount points
|
||||
- Longest-prefix matching
|
||||
- Resource cleanup
|
||||
|
||||
### Example Tests
|
||||
- Executable examples in `example_test.go`
|
||||
- Demonstrate common usage patterns
|
||||
|
||||
## Migration Impact
|
||||
|
||||
### Breaking Changes
|
||||
- Complete API redesign (acceptable per user)
|
||||
- Package not currently used in codebase (no migration needed)
|
||||
- New consumers will use new API from start
|
||||
|
||||
### Future Extensibility
|
||||
The interface-driven design allows future additions without breaking changes:
|
||||
- Add HTTPFSProvider by implementing `FileSystemProvider`
|
||||
- Add S3FSProvider by implementing `FileSystemProvider`
|
||||
- Add custom cache policies by implementing `CachePolicy`
|
||||
- Add custom fallback strategies by implementing `FallbackStrategy`
|
||||
|
||||
## Implementation Order
|
||||
1. Interfaces (foundation)
|
||||
2. Configuration (API surface)
|
||||
3. Policies (pluggable behavior)
|
||||
4. Providers (filesystem abstraction)
|
||||
5. Mount Point (request handling)
|
||||
6. Service (orchestration)
|
||||
7. Tests (validation)
|
||||
8. Documentation (usage)
|
||||
9. Remove old code (cleanup)
|
||||
|
||||
This order ensures each layer builds on tested, working components.
|
||||
|
||||
---
|
||||
|
||||
103
pkg/server/staticweb/policies/cache.go
Normal file
103
pkg/server/staticweb/policies/cache.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package policies
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SimpleCachePolicy implements a basic cache policy with a single TTL for all files.
|
||||
type SimpleCachePolicy struct {
|
||||
cacheTime int // Cache duration in seconds
|
||||
}
|
||||
|
||||
// NewSimpleCachePolicy creates a new SimpleCachePolicy with the given cache time in seconds.
|
||||
func NewSimpleCachePolicy(cacheTimeSeconds int) *SimpleCachePolicy {
|
||||
return &SimpleCachePolicy{
|
||||
cacheTime: cacheTimeSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCacheTime returns the cache duration for any file.
|
||||
func (p *SimpleCachePolicy) GetCacheTime(filePath string) int {
|
||||
return p.cacheTime
|
||||
}
|
||||
|
||||
// GetCacheHeaders returns the Cache-Control header for the given file.
|
||||
func (p *SimpleCachePolicy) GetCacheHeaders(filePath string) map[string]string {
|
||||
if p.cacheTime <= 0 {
|
||||
return map[string]string{
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0",
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
"Cache-Control": fmt.Sprintf("public, max-age=%d", p.cacheTime),
|
||||
}
|
||||
}
|
||||
|
||||
// ExtensionBasedCachePolicy implements a cache policy that varies by file extension.
|
||||
type ExtensionBasedCachePolicy struct {
|
||||
rules map[string]int // Extension -> cache time in seconds
|
||||
defaultTime int // Default cache time for unmatched extensions
|
||||
}
|
||||
|
||||
// NewExtensionBasedCachePolicy creates a new ExtensionBasedCachePolicy.
|
||||
// rules maps file extensions (with leading dot, e.g., ".js") to cache times in seconds.
|
||||
// defaultTime is used for files that don't match any rule.
|
||||
func NewExtensionBasedCachePolicy(rules map[string]int, defaultTime int) *ExtensionBasedCachePolicy {
|
||||
return &ExtensionBasedCachePolicy{
|
||||
rules: rules,
|
||||
defaultTime: defaultTime,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCacheTime returns the cache duration based on the file extension.
|
||||
func (p *ExtensionBasedCachePolicy) GetCacheTime(filePath string) int {
|
||||
ext := strings.ToLower(path.Ext(filePath))
|
||||
if cacheTime, ok := p.rules[ext]; ok {
|
||||
return cacheTime
|
||||
}
|
||||
return p.defaultTime
|
||||
}
|
||||
|
||||
// GetCacheHeaders returns cache headers based on the file extension.
|
||||
func (p *ExtensionBasedCachePolicy) GetCacheHeaders(filePath string) map[string]string {
|
||||
cacheTime := p.GetCacheTime(filePath)
|
||||
|
||||
if cacheTime <= 0 {
|
||||
return map[string]string{
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0",
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
"Cache-Control": fmt.Sprintf("public, max-age=%d", cacheTime),
|
||||
}
|
||||
}
|
||||
|
||||
// NoCachePolicy implements a cache policy that disables all caching.
|
||||
type NoCachePolicy struct{}
|
||||
|
||||
// NewNoCachePolicy creates a new NoCachePolicy.
|
||||
func NewNoCachePolicy() *NoCachePolicy {
|
||||
return &NoCachePolicy{}
|
||||
}
|
||||
|
||||
// GetCacheTime always returns 0 (no caching).
|
||||
func (p *NoCachePolicy) GetCacheTime(filePath string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetCacheHeaders returns headers that disable caching.
|
||||
func (p *NoCachePolicy) GetCacheHeaders(filePath string) map[string]string {
|
||||
return map[string]string{
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0",
|
||||
}
|
||||
}
|
||||
159
pkg/server/staticweb/policies/fallback.go
Normal file
159
pkg/server/staticweb/policies/fallback.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package policies
|
||||
|
||||
import (
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NoFallback implements a fallback strategy that never falls back.
|
||||
// All requests for missing files will result in 404 responses.
|
||||
type NoFallback struct{}
|
||||
|
||||
// NewNoFallback creates a new NoFallback strategy.
|
||||
func NewNoFallback() *NoFallback {
|
||||
return &NoFallback{}
|
||||
}
|
||||
|
||||
// ShouldFallback always returns false.
|
||||
func (f *NoFallback) ShouldFallback(filePath string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetFallbackPath returns an empty string (never called since ShouldFallback returns false).
|
||||
func (f *NoFallback) GetFallbackPath(filePath string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// HTMLFallbackStrategy implements a fallback strategy for Single Page Applications (SPAs).
|
||||
// It serves a specified HTML file (typically index.html) for non-file requests.
|
||||
type HTMLFallbackStrategy struct {
|
||||
indexFile string
|
||||
}
|
||||
|
||||
// NewHTMLFallbackStrategy creates a new HTMLFallbackStrategy.
|
||||
// indexFile is the path to the HTML file to serve (e.g., "index.html", "/index.html").
|
||||
func NewHTMLFallbackStrategy(indexFile string) *HTMLFallbackStrategy {
|
||||
return &HTMLFallbackStrategy{
|
||||
indexFile: indexFile,
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldFallback returns true for requests that don't look like static assets.
|
||||
func (f *HTMLFallbackStrategy) ShouldFallback(filePath string) bool {
|
||||
// Always fall back unless it looks like a static asset
|
||||
return !f.isStaticAsset(filePath)
|
||||
}
|
||||
|
||||
// GetFallbackPath returns the index file path.
|
||||
func (f *HTMLFallbackStrategy) GetFallbackPath(filePath string) string {
|
||||
return f.indexFile
|
||||
}
|
||||
|
||||
// isStaticAsset checks if the path looks like a static asset (has a file extension).
|
||||
func (f *HTMLFallbackStrategy) isStaticAsset(filePath string) bool {
|
||||
return path.Ext(filePath) != ""
|
||||
}
|
||||
|
||||
// ExtensionBasedFallback implements a fallback strategy that skips fallback for known static file extensions.
|
||||
// This is the behavior from the original StaticHTMLFallbackHandler.
|
||||
type ExtensionBasedFallback struct {
|
||||
staticExtensions map[string]bool
|
||||
fallbackPath string
|
||||
}
|
||||
|
||||
// NewExtensionBasedFallback creates a new ExtensionBasedFallback strategy.
|
||||
// staticExtensions is a list of file extensions (with leading dot) that should NOT use fallback.
|
||||
// fallbackPath is the file to serve when fallback is triggered.
|
||||
func NewExtensionBasedFallback(staticExtensions []string, fallbackPath string) *ExtensionBasedFallback {
|
||||
extMap := make(map[string]bool)
|
||||
for _, ext := range staticExtensions {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
ext = "." + ext
|
||||
}
|
||||
extMap[strings.ToLower(ext)] = true
|
||||
}
|
||||
|
||||
return &ExtensionBasedFallback{
|
||||
staticExtensions: extMap,
|
||||
fallbackPath: fallbackPath,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDefaultExtensionBasedFallback creates an ExtensionBasedFallback with common web asset extensions.
|
||||
// This matches the behavior of the original StaticHTMLFallbackHandler.
|
||||
func NewDefaultExtensionBasedFallback(fallbackPath string) *ExtensionBasedFallback {
|
||||
return NewExtensionBasedFallback([]string{
|
||||
".js", ".css", ".png", ".svg", ".ico", ".json",
|
||||
".jpg", ".jpeg", ".gif", ".woff", ".woff2", ".ttf", ".eot",
|
||||
}, fallbackPath)
|
||||
}
|
||||
|
||||
// ShouldFallback returns true if the file path doesn't have a static asset extension.
|
||||
func (f *ExtensionBasedFallback) ShouldFallback(filePath string) bool {
|
||||
ext := strings.ToLower(path.Ext(filePath))
|
||||
|
||||
// If it's a known static extension, don't fallback
|
||||
if f.staticExtensions[ext] {
|
||||
return false
|
||||
}
|
||||
|
||||
// Otherwise, try fallback
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFallbackPath returns the configured fallback path.
|
||||
func (f *ExtensionBasedFallback) GetFallbackPath(filePath string) string {
|
||||
return f.fallbackPath
|
||||
}
|
||||
|
||||
// HTMLExtensionFallback implements a fallback strategy that appends .html to paths.
|
||||
// This tries to serve {path}.html for missing files.
|
||||
type HTMLExtensionFallback struct {
|
||||
staticExtensions map[string]bool
|
||||
}
|
||||
|
||||
// NewHTMLExtensionFallback creates a new HTMLExtensionFallback strategy.
|
||||
func NewHTMLExtensionFallback(staticExtensions []string) *HTMLExtensionFallback {
|
||||
extMap := make(map[string]bool)
|
||||
for _, ext := range staticExtensions {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
ext = "." + ext
|
||||
}
|
||||
extMap[strings.ToLower(ext)] = true
|
||||
}
|
||||
|
||||
return &HTMLExtensionFallback{
|
||||
staticExtensions: extMap,
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldFallback returns true if the path doesn't have a static extension or .html.
|
||||
func (f *HTMLExtensionFallback) ShouldFallback(filePath string) bool {
|
||||
ext := strings.ToLower(path.Ext(filePath))
|
||||
|
||||
// If it's a known static extension, don't fallback
|
||||
if f.staticExtensions[ext] {
|
||||
return false
|
||||
}
|
||||
|
||||
// If it already has .html, don't fallback
|
||||
if ext == ".html" || ext == ".htm" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFallbackPath returns the path with .html appended.
|
||||
func (f *HTMLExtensionFallback) GetFallbackPath(filePath string) string {
|
||||
cleanPath := path.Clean(filePath)
|
||||
if !strings.HasSuffix(filePath, "/") {
|
||||
cleanPath = strings.TrimRight(cleanPath, "/")
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(strings.ToLower(cleanPath), ".html") {
|
||||
return cleanPath + ".html"
|
||||
}
|
||||
|
||||
return cleanPath
|
||||
}
|
||||
245
pkg/server/staticweb/policies/mime.go
Normal file
245
pkg/server/staticweb/policies/mime.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package policies
|
||||
|
||||
import (
|
||||
"mime"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DefaultMIMEResolver implements a MIME type resolver using Go's standard mime package
|
||||
// and a set of common web file type mappings.
|
||||
type DefaultMIMEResolver struct {
|
||||
customTypes map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewDefaultMIMEResolver creates a new DefaultMIMEResolver with common web MIME types.
|
||||
func NewDefaultMIMEResolver() *DefaultMIMEResolver {
|
||||
resolver := &DefaultMIMEResolver{
|
||||
customTypes: make(map[string]string),
|
||||
}
|
||||
|
||||
// JavaScript & TypeScript
|
||||
resolver.RegisterMIMEType(".js", "application/javascript")
|
||||
resolver.RegisterMIMEType(".mjs", "application/javascript")
|
||||
resolver.RegisterMIMEType(".cjs", "application/javascript")
|
||||
resolver.RegisterMIMEType(".ts", "text/typescript")
|
||||
resolver.RegisterMIMEType(".tsx", "text/tsx")
|
||||
resolver.RegisterMIMEType(".jsx", "text/jsx")
|
||||
|
||||
// CSS & Styling
|
||||
resolver.RegisterMIMEType(".css", "text/css")
|
||||
resolver.RegisterMIMEType(".scss", "text/x-scss")
|
||||
resolver.RegisterMIMEType(".sass", "text/x-sass")
|
||||
resolver.RegisterMIMEType(".less", "text/x-less")
|
||||
|
||||
// HTML & XML
|
||||
resolver.RegisterMIMEType(".html", "text/html")
|
||||
resolver.RegisterMIMEType(".htm", "text/html")
|
||||
resolver.RegisterMIMEType(".xml", "application/xml")
|
||||
resolver.RegisterMIMEType(".xhtml", "application/xhtml+xml")
|
||||
|
||||
// Images - Raster
|
||||
resolver.RegisterMIMEType(".png", "image/png")
|
||||
resolver.RegisterMIMEType(".jpg", "image/jpeg")
|
||||
resolver.RegisterMIMEType(".jpeg", "image/jpeg")
|
||||
resolver.RegisterMIMEType(".gif", "image/gif")
|
||||
resolver.RegisterMIMEType(".webp", "image/webp")
|
||||
resolver.RegisterMIMEType(".avif", "image/avif")
|
||||
resolver.RegisterMIMEType(".bmp", "image/bmp")
|
||||
resolver.RegisterMIMEType(".tiff", "image/tiff")
|
||||
resolver.RegisterMIMEType(".tif", "image/tiff")
|
||||
resolver.RegisterMIMEType(".ico", "image/x-icon")
|
||||
resolver.RegisterMIMEType(".cur", "image/x-icon")
|
||||
|
||||
// Images - Vector
|
||||
resolver.RegisterMIMEType(".svg", "image/svg+xml")
|
||||
resolver.RegisterMIMEType(".svgz", "image/svg+xml")
|
||||
|
||||
// Fonts
|
||||
resolver.RegisterMIMEType(".woff", "font/woff")
|
||||
resolver.RegisterMIMEType(".woff2", "font/woff2")
|
||||
resolver.RegisterMIMEType(".ttf", "font/ttf")
|
||||
resolver.RegisterMIMEType(".otf", "font/otf")
|
||||
resolver.RegisterMIMEType(".eot", "application/vnd.ms-fontobject")
|
||||
|
||||
// Audio
|
||||
resolver.RegisterMIMEType(".mp3", "audio/mpeg")
|
||||
resolver.RegisterMIMEType(".wav", "audio/wav")
|
||||
resolver.RegisterMIMEType(".ogg", "audio/ogg")
|
||||
resolver.RegisterMIMEType(".oga", "audio/ogg")
|
||||
resolver.RegisterMIMEType(".m4a", "audio/mp4")
|
||||
resolver.RegisterMIMEType(".aac", "audio/aac")
|
||||
resolver.RegisterMIMEType(".flac", "audio/flac")
|
||||
resolver.RegisterMIMEType(".opus", "audio/opus")
|
||||
resolver.RegisterMIMEType(".weba", "audio/webm")
|
||||
|
||||
// Video
|
||||
resolver.RegisterMIMEType(".mp4", "video/mp4")
|
||||
resolver.RegisterMIMEType(".webm", "video/webm")
|
||||
resolver.RegisterMIMEType(".ogv", "video/ogg")
|
||||
resolver.RegisterMIMEType(".avi", "video/x-msvideo")
|
||||
resolver.RegisterMIMEType(".mpeg", "video/mpeg")
|
||||
resolver.RegisterMIMEType(".mpg", "video/mpeg")
|
||||
resolver.RegisterMIMEType(".mov", "video/quicktime")
|
||||
resolver.RegisterMIMEType(".wmv", "video/x-ms-wmv")
|
||||
resolver.RegisterMIMEType(".flv", "video/x-flv")
|
||||
resolver.RegisterMIMEType(".mkv", "video/x-matroska")
|
||||
resolver.RegisterMIMEType(".m4v", "video/mp4")
|
||||
|
||||
// Data & Configuration
|
||||
resolver.RegisterMIMEType(".json", "application/json")
|
||||
resolver.RegisterMIMEType(".xml", "application/xml")
|
||||
resolver.RegisterMIMEType(".yml", "application/yaml")
|
||||
resolver.RegisterMIMEType(".yaml", "application/yaml")
|
||||
resolver.RegisterMIMEType(".toml", "application/toml")
|
||||
resolver.RegisterMIMEType(".ini", "text/plain")
|
||||
resolver.RegisterMIMEType(".conf", "text/plain")
|
||||
resolver.RegisterMIMEType(".config", "text/plain")
|
||||
|
||||
// Documents
|
||||
resolver.RegisterMIMEType(".pdf", "application/pdf")
|
||||
resolver.RegisterMIMEType(".doc", "application/msword")
|
||||
resolver.RegisterMIMEType(".docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document")
|
||||
resolver.RegisterMIMEType(".xls", "application/vnd.ms-excel")
|
||||
resolver.RegisterMIMEType(".xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
|
||||
resolver.RegisterMIMEType(".ppt", "application/vnd.ms-powerpoint")
|
||||
resolver.RegisterMIMEType(".pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation")
|
||||
resolver.RegisterMIMEType(".odt", "application/vnd.oasis.opendocument.text")
|
||||
resolver.RegisterMIMEType(".ods", "application/vnd.oasis.opendocument.spreadsheet")
|
||||
resolver.RegisterMIMEType(".odp", "application/vnd.oasis.opendocument.presentation")
|
||||
|
||||
// Archives
|
||||
resolver.RegisterMIMEType(".zip", "application/zip")
|
||||
resolver.RegisterMIMEType(".tar", "application/x-tar")
|
||||
resolver.RegisterMIMEType(".gz", "application/gzip")
|
||||
resolver.RegisterMIMEType(".bz2", "application/x-bzip2")
|
||||
resolver.RegisterMIMEType(".7z", "application/x-7z-compressed")
|
||||
resolver.RegisterMIMEType(".rar", "application/vnd.rar")
|
||||
|
||||
// Text files
|
||||
resolver.RegisterMIMEType(".txt", "text/plain")
|
||||
resolver.RegisterMIMEType(".md", "text/markdown")
|
||||
resolver.RegisterMIMEType(".markdown", "text/markdown")
|
||||
resolver.RegisterMIMEType(".csv", "text/csv")
|
||||
resolver.RegisterMIMEType(".log", "text/plain")
|
||||
|
||||
// Source code (for syntax highlighting in browsers)
|
||||
resolver.RegisterMIMEType(".c", "text/x-c")
|
||||
resolver.RegisterMIMEType(".cpp", "text/x-c++")
|
||||
resolver.RegisterMIMEType(".h", "text/x-c")
|
||||
resolver.RegisterMIMEType(".hpp", "text/x-c++")
|
||||
resolver.RegisterMIMEType(".go", "text/x-go")
|
||||
resolver.RegisterMIMEType(".py", "text/x-python")
|
||||
resolver.RegisterMIMEType(".java", "text/x-java")
|
||||
resolver.RegisterMIMEType(".rs", "text/x-rust")
|
||||
resolver.RegisterMIMEType(".rb", "text/x-ruby")
|
||||
resolver.RegisterMIMEType(".php", "text/x-php")
|
||||
resolver.RegisterMIMEType(".sh", "text/x-shellscript")
|
||||
resolver.RegisterMIMEType(".bash", "text/x-shellscript")
|
||||
resolver.RegisterMIMEType(".sql", "text/x-sql")
|
||||
resolver.RegisterMIMEType(".template.sql", "text/plain")
|
||||
resolver.RegisterMIMEType(".upg", "text/plain")
|
||||
|
||||
// Web Assembly
|
||||
resolver.RegisterMIMEType(".wasm", "application/wasm")
|
||||
|
||||
// Manifest & Service Worker
|
||||
resolver.RegisterMIMEType(".webmanifest", "application/manifest+json")
|
||||
resolver.RegisterMIMEType(".manifest", "text/cache-manifest")
|
||||
|
||||
// 3D Models
|
||||
resolver.RegisterMIMEType(".gltf", "model/gltf+json")
|
||||
resolver.RegisterMIMEType(".glb", "model/gltf-binary")
|
||||
resolver.RegisterMIMEType(".obj", "model/obj")
|
||||
resolver.RegisterMIMEType(".stl", "model/stl")
|
||||
|
||||
// Other common web assets
|
||||
resolver.RegisterMIMEType(".map", "application/json") // Source maps
|
||||
resolver.RegisterMIMEType(".swf", "application/x-shockwave-flash")
|
||||
resolver.RegisterMIMEType(".apk", "application/vnd.android.package-archive")
|
||||
resolver.RegisterMIMEType(".dmg", "application/x-apple-diskimage")
|
||||
resolver.RegisterMIMEType(".exe", "application/x-msdownload")
|
||||
resolver.RegisterMIMEType(".iso", "application/x-iso9660-image")
|
||||
|
||||
return resolver
|
||||
}
|
||||
|
||||
// GetMIMEType returns the MIME type for the given file path.
|
||||
// It first checks custom registered types, then falls back to Go's mime.TypeByExtension.
|
||||
func (r *DefaultMIMEResolver) GetMIMEType(filePath string) string {
|
||||
ext := strings.ToLower(path.Ext(filePath))
|
||||
|
||||
// Check custom types first
|
||||
r.mu.RLock()
|
||||
if mimeType, ok := r.customTypes[ext]; ok {
|
||||
r.mu.RUnlock()
|
||||
return mimeType
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
// Fall back to standard library
|
||||
if mimeType := mime.TypeByExtension(ext); mimeType != "" {
|
||||
return mimeType
|
||||
}
|
||||
|
||||
// Return empty string if unknown
|
||||
return ""
|
||||
}
|
||||
|
||||
// RegisterMIMEType registers a custom MIME type for the given file extension.
|
||||
func (r *DefaultMIMEResolver) RegisterMIMEType(extension, mimeType string) {
|
||||
if !strings.HasPrefix(extension, ".") {
|
||||
extension = "." + extension
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.customTypes[strings.ToLower(extension)] = mimeType
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// ConfigurableMIMEResolver implements a MIME type resolver with user-defined mappings only.
|
||||
// It does not use any default mappings.
|
||||
type ConfigurableMIMEResolver struct {
|
||||
types map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewConfigurableMIMEResolver creates a new ConfigurableMIMEResolver with the given mappings.
|
||||
func NewConfigurableMIMEResolver(types map[string]string) *ConfigurableMIMEResolver {
|
||||
resolver := &ConfigurableMIMEResolver{
|
||||
types: make(map[string]string),
|
||||
}
|
||||
|
||||
for ext, mimeType := range types {
|
||||
resolver.RegisterMIMEType(ext, mimeType)
|
||||
}
|
||||
|
||||
return resolver
|
||||
}
|
||||
|
||||
// GetMIMEType returns the MIME type for the given file path.
|
||||
func (r *ConfigurableMIMEResolver) GetMIMEType(filePath string) string {
|
||||
ext := strings.ToLower(path.Ext(filePath))
|
||||
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if mimeType, ok := r.types[ext]; ok {
|
||||
return mimeType
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// RegisterMIMEType registers a MIME type for the given file extension.
|
||||
func (r *ConfigurableMIMEResolver) RegisterMIMEType(extension, mimeType string) {
|
||||
if !strings.HasPrefix(extension, ".") {
|
||||
extension = "." + extension
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.types[strings.ToLower(extension)] = mimeType
|
||||
r.mu.Unlock()
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user