mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4a6f9c4c2 | ||
| 8f83e8fdc1 | |||
|
|
ed67caf055 | ||
| 4d1b8b6982 | |||
|
|
63ed62a9a3 | ||
|
|
0525323a47 | ||
|
|
c3443f702e | ||
|
|
45c463c117 | ||
|
|
84d673ce14 | ||
|
|
02fbdbd651 | ||
|
|
97988e3b5e | ||
|
|
c9838ad9d2 | ||
|
|
c5c0608f63 | ||
|
|
39c3f05d21 | ||
|
|
4ecd1ac17e | ||
|
|
2b1aea0338 | ||
|
|
1e749efeb3 | ||
|
|
09be676096 | ||
|
|
e8350a70be | ||
|
|
5937b9eab5 | ||
|
|
7c861c708e | ||
|
|
77f39af2f9 | ||
|
|
fbc1471581 | ||
|
|
9351093e2a | ||
|
|
932f12ab0a | ||
|
|
b22792bad6 | ||
|
|
e8111c01aa | ||
|
|
5862016031 | ||
|
|
2f18dde29c | ||
|
|
31ad217818 | ||
|
|
7ef1d6424a |
6
Makefile
6
Makefile
@@ -16,7 +16,7 @@ test: test-unit test-integration
|
|||||||
# Start PostgreSQL for integration tests
|
# Start PostgreSQL for integration tests
|
||||||
docker-up:
|
docker-up:
|
||||||
@echo "Starting PostgreSQL container..."
|
@echo "Starting PostgreSQL container..."
|
||||||
@docker-compose up -d postgres-test
|
@podman compose up -d postgres-test
|
||||||
@echo "Waiting for PostgreSQL to be ready..."
|
@echo "Waiting for PostgreSQL to be ready..."
|
||||||
@sleep 5
|
@sleep 5
|
||||||
@echo "PostgreSQL is ready!"
|
@echo "PostgreSQL is ready!"
|
||||||
@@ -24,12 +24,12 @@ docker-up:
|
|||||||
# Stop PostgreSQL container
|
# Stop PostgreSQL container
|
||||||
docker-down:
|
docker-down:
|
||||||
@echo "Stopping PostgreSQL container..."
|
@echo "Stopping PostgreSQL container..."
|
||||||
@docker-compose down
|
@podman compose down
|
||||||
|
|
||||||
# Clean up Docker volumes and test data
|
# Clean up Docker volumes and test data
|
||||||
clean:
|
clean:
|
||||||
@echo "Cleaning up..."
|
@echo "Cleaning up..."
|
||||||
@docker-compose down -v
|
@podman compose down -v
|
||||||
@echo "Cleanup complete!"
|
@echo "Cleanup complete!"
|
||||||
|
|
||||||
# Run integration tests with Docker (full workflow)
|
# Run integration tests with Docker (full workflow)
|
||||||
|
|||||||
68
go.mod
68
go.mod
@@ -11,15 +11,17 @@ require (
|
|||||||
github.com/glebarez/sqlite v1.11.0
|
github.com/glebarez/sqlite v1.11.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
|
github.com/jackc/pgx/v5 v5.6.0
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/redis/go-redis/v9 v9.17.1
|
github.com/redis/go-redis/v9 v9.17.1
|
||||||
github.com/spf13/viper v1.21.0
|
github.com/spf13/viper v1.21.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/testcontainers/testcontainers-go v0.40.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/uptrace/bun v1.2.15
|
github.com/uptrace/bun v1.2.16
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16
|
||||||
github.com/uptrace/bunrouter v1.0.23
|
github.com/uptrace/bunrouter v1.0.23
|
||||||
go.opentelemetry.io/otel v1.38.0
|
go.opentelemetry.io/otel v1.38.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
||||||
@@ -33,36 +35,68 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
dario.cat/mergo v1.0.2 // indirect
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||||
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/pgx/v5 v5.6.0 // indirect
|
|
||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
|
github.com/klauspost/compress v1.18.0 // indirect
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
|
github.com/magiconair/properties v1.8.10 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
|
github.com/moby/go-archive v0.1.0 // indirect
|
||||||
|
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||||
|
github.com/moby/sys/sequential v0.6.0 // indirect
|
||||||
|
github.com/moby/sys/user v0.4.0 // indirect
|
||||||
|
github.com/moby/sys/userns v0.1.0 // indirect
|
||||||
|
github.com/moby/term v0.5.0 // indirect
|
||||||
|
github.com/morikuni/aec v1.0.0 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/prometheus/client_model v0.6.2 // indirect
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
github.com/prometheus/common v0.66.1 // indirect
|
github.com/prometheus/common v0.66.1 // indirect
|
||||||
github.com/prometheus/procfs v0.16.1 // indirect
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||||
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||||
github.com/spf13/afero v1.15.0 // indirect
|
github.com/spf13/afero v1.15.0 // indirect
|
||||||
github.com/spf13/cast v1.10.0 // indirect
|
github.com/spf13/cast v1.10.0 // indirect
|
||||||
@@ -70,28 +104,34 @@ require (
|
|||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||||
go.uber.org/multierr v1.10.0 // indirect
|
go.uber.org/multierr v1.10.0 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
golang.org/x/crypto v0.41.0 // indirect
|
golang.org/x/crypto v0.43.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
|
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
|
||||||
golang.org/x/net v0.43.0 // indirect
|
golang.org/x/net v0.45.0 // indirect
|
||||||
golang.org/x/sync v0.16.0 // indirect
|
golang.org/x/sync v0.18.0 // indirect
|
||||||
golang.org/x/sys v0.35.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.28.0 // indirect
|
golang.org/x/text v0.30.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||||
google.golang.org/grpc v1.75.0 // indirect
|
google.golang.org/grpc v1.75.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.8 // indirect
|
google.golang.org/protobuf v1.36.8 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
modernc.org/libc v1.66.3 // indirect
|
modernc.org/libc v1.67.0 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
modernc.org/sqlite v1.38.0 // indirect
|
modernc.org/sqlite v1.40.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
||||||
|
|||||||
170
go.sum
170
go.sum
@@ -1,5 +1,13 @@
|
|||||||
|
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||||
|
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||||
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||||
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
||||||
@@ -8,17 +16,43 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
|||||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||||
|
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||||
|
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||||
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
|
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||||
|
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||||
|
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||||
|
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||||
|
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||||
|
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
|
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||||
|
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||||
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||||
|
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
@@ -36,10 +70,13 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
|||||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
@@ -50,6 +87,8 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
|||||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@@ -71,14 +110,40 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
|||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
|
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||||
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
|
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||||
|
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
|
||||||
|
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||||
|
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||||
|
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||||
|
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||||
|
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||||
|
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||||
|
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
|
||||||
|
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||||
|
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||||
|
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||||
|
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||||
|
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||||
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||||
@@ -87,6 +152,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||||
@@ -105,6 +172,10 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
|||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||||
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||||
@@ -116,12 +187,16 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A
|
|||||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
|
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
|
||||||
|
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
|
||||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
@@ -131,28 +206,38 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
|||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
|
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||||
|
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
|
||||||
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16/go.mod h1:iKdJ06P3XS+pwKcONjSIK07bbhksH3lWsw3mpfr0+bY=
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
|
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
|
|
||||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||||
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||||
|
github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
|
||||||
|
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||||
@@ -173,25 +258,34 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
|||||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||||
|
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||||
|
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||||
|
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
|
||||||
@@ -212,18 +306,22 @@ gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
|||||||
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
|
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||||
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||||
|
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||||
|
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
|
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||||
|
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
|
||||||
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
@@ -232,8 +330,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
|||||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
|||||||
20
pkg/cache/cache_manager.go
vendored
20
pkg/cache/cache_manager.go
vendored
@@ -57,11 +57,31 @@ func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time
|
|||||||
return c.provider.Set(ctx, key, value, ttl)
|
return c.provider.Set(ctx, key, value, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags serializes and stores a value in the cache with the specified TTL and tags.
|
||||||
|
func (c *Cache) SetWithTags(ctx context.Context, key string, value interface{}, ttl time.Duration, tags []string) error {
|
||||||
|
data, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to serialize: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.provider.SetWithTags(ctx, key, data, ttl, tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBytesWithTags stores raw bytes in the cache with the specified TTL and tags.
|
||||||
|
func (c *Cache) SetBytesWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
return c.provider.SetWithTags(ctx, key, value, ttl, tags)
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (c *Cache) Delete(ctx context.Context, key string) error {
|
func (c *Cache) Delete(ctx context.Context, key string) error {
|
||||||
return c.provider.Delete(ctx, key)
|
return c.provider.Delete(ctx, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (c *Cache) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
return c.provider.DeleteByTag(ctx, tag)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
return c.provider.DeleteByPattern(ctx, pattern)
|
return c.provider.DeleteByPattern(ctx, pattern)
|
||||||
|
|||||||
8
pkg/cache/provider.go
vendored
8
pkg/cache/provider.go
vendored
@@ -15,9 +15,17 @@ type Provider interface {
|
|||||||
// If ttl is 0, the item never expires.
|
// If ttl is 0, the item never expires.
|
||||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
// Tags can be used to invalidate groups of related keys.
|
||||||
|
// If ttl is 0, the item never expires.
|
||||||
|
SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
Delete(ctx context.Context, key string) error
|
Delete(ctx context.Context, key string) error
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
DeleteByTag(ctx context.Context, tag string) error
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
// Pattern syntax depends on the provider implementation.
|
// Pattern syntax depends on the provider implementation.
|
||||||
DeleteByPattern(ctx context.Context, pattern string) error
|
DeleteByPattern(ctx context.Context, pattern string) error
|
||||||
|
|||||||
140
pkg/cache/provider_memcache.go
vendored
140
pkg/cache/provider_memcache.go
vendored
@@ -2,6 +2,7 @@ package cache
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -97,8 +98,115 @@ func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, tt
|
|||||||
return m.client.Set(item)
|
return m.client.Set(item)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
// Note: Tag support in Memcache is limited and less efficient than Redis.
|
||||||
|
func (m *MemcacheProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = m.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
expiration := int32(ttl.Seconds())
|
||||||
|
|
||||||
|
// Set the main value
|
||||||
|
item := &memcache.Item{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
if err := m.client.Set(item); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store tags for this key
|
||||||
|
if len(tags) > 0 {
|
||||||
|
tagsData, err := json.Marshal(tags)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tagsItem := &memcache.Item{
|
||||||
|
Key: fmt.Sprintf("cache:tags:%s", key),
|
||||||
|
Value: tagsData,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
if err := m.client.Set(tagsItem); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add key to each tag's key list
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
|
||||||
|
// Get existing keys for this tag
|
||||||
|
var keys []string
|
||||||
|
if item, err := m.client.Get(tagKey); err == nil {
|
||||||
|
_ = json.Unmarshal(item.Value, &keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add current key if not already present
|
||||||
|
found := false
|
||||||
|
for _, k := range keys {
|
||||||
|
if k == key {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store updated key list
|
||||||
|
keysData, err := json.Marshal(keys)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tagItem := &memcache.Item{
|
||||||
|
Key: tagKey,
|
||||||
|
Value: keysData,
|
||||||
|
Expiration: expiration + 3600, // Give tag lists longer TTL
|
||||||
|
}
|
||||||
|
_ = m.client.Set(tagItem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||||
|
// Get tags for this key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
if item, err := m.client.Get(tagsKey); err == nil {
|
||||||
|
var tags []string
|
||||||
|
if err := json.Unmarshal(item.Value, &tags); err == nil {
|
||||||
|
// Remove key from each tag's key list
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
if tagItem, err := m.client.Get(tagKey); err == nil {
|
||||||
|
var keys []string
|
||||||
|
if err := json.Unmarshal(tagItem.Value, &keys); err == nil {
|
||||||
|
// Remove current key from the list
|
||||||
|
newKeys := make([]string, 0, len(keys))
|
||||||
|
for _, k := range keys {
|
||||||
|
if k != key {
|
||||||
|
newKeys = append(newKeys, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update the tag's key list
|
||||||
|
if keysData, err := json.Marshal(newKeys); err == nil {
|
||||||
|
tagItem.Value = keysData
|
||||||
|
_ = m.client.Set(tagItem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Delete the tags key
|
||||||
|
_ = m.client.Delete(tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the actual key
|
||||||
err := m.client.Delete(key)
|
err := m.client.Delete(key)
|
||||||
if err == memcache.ErrCacheMiss {
|
if err == memcache.ErrCacheMiss {
|
||||||
return nil
|
return nil
|
||||||
@@ -106,6 +214,38 @@ func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (m *MemcacheProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
|
||||||
|
// Get all keys associated with this tag
|
||||||
|
item, err := m.client.Get(tagKey)
|
||||||
|
if err == memcache.ErrCacheMiss {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
if err := json.Unmarshal(item.Value, &keys); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal tag keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete all keys
|
||||||
|
for _, key := range keys {
|
||||||
|
_ = m.client.Delete(key)
|
||||||
|
// Also delete the tags key for this cache key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
_ = m.client.Delete(tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the tag key itself
|
||||||
|
_ = m.client.Delete(tagKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
// Note: Memcache does not support pattern-based deletion natively.
|
// Note: Memcache does not support pattern-based deletion natively.
|
||||||
// This is a no-op for memcache and returns an error.
|
// This is a no-op for memcache and returns an error.
|
||||||
|
|||||||
104
pkg/cache/provider_memory.go
vendored
104
pkg/cache/provider_memory.go
vendored
@@ -15,6 +15,7 @@ type memoryItem struct {
|
|||||||
Expiration time.Time
|
Expiration time.Time
|
||||||
LastAccess time.Time
|
LastAccess time.Time
|
||||||
HitCount int64
|
HitCount int64
|
||||||
|
Tags []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// isExpired checks if the item has expired.
|
// isExpired checks if the item has expired.
|
||||||
@@ -29,6 +30,7 @@ func (m *memoryItem) isExpired() bool {
|
|||||||
type MemoryProvider struct {
|
type MemoryProvider struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
items map[string]*memoryItem
|
items map[string]*memoryItem
|
||||||
|
tagToKeys map[string]map[string]struct{} // tag -> set of keys
|
||||||
options *Options
|
options *Options
|
||||||
hits atomic.Int64
|
hits atomic.Int64
|
||||||
misses atomic.Int64
|
misses atomic.Int64
|
||||||
@@ -45,6 +47,7 @@ func NewMemoryProvider(opts *Options) *MemoryProvider {
|
|||||||
|
|
||||||
return &MemoryProvider{
|
return &MemoryProvider{
|
||||||
items: make(map[string]*memoryItem),
|
items: make(map[string]*memoryItem),
|
||||||
|
tagToKeys: make(map[string]map[string]struct{}),
|
||||||
options: opts,
|
options: opts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -114,15 +117,116 @@ func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
func (m *MemoryProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = m.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiration time.Time
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration = time.Now().Add(ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check max size and evict if necessary
|
||||||
|
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
|
||||||
|
if _, exists := m.items[key]; !exists {
|
||||||
|
m.evictOne()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old tag associations if key exists
|
||||||
|
if oldItem, exists := m.items[key]; exists {
|
||||||
|
for _, tag := range oldItem.Tags {
|
||||||
|
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||||
|
delete(keySet, key)
|
||||||
|
if len(keySet) == 0 {
|
||||||
|
delete(m.tagToKeys, tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the item
|
||||||
|
m.items[key] = &memoryItem{
|
||||||
|
Value: value,
|
||||||
|
Expiration: expiration,
|
||||||
|
LastAccess: time.Now(),
|
||||||
|
Tags: tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new tag associations
|
||||||
|
for _, tag := range tags {
|
||||||
|
if m.tagToKeys[tag] == nil {
|
||||||
|
m.tagToKeys[tag] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
m.tagToKeys[tag][key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove tag associations
|
||||||
|
if item, exists := m.items[key]; exists {
|
||||||
|
for _, tag := range item.Tags {
|
||||||
|
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||||
|
delete(keySet, key)
|
||||||
|
if len(keySet) == 0 {
|
||||||
|
delete(m.tagToKeys, tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
delete(m.items, key)
|
delete(m.items, key)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (m *MemoryProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Get all keys associated with this tag
|
||||||
|
keySet, exists := m.tagToKeys[tag]
|
||||||
|
if !exists {
|
||||||
|
return nil // No keys with this tag
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete all items with this tag
|
||||||
|
for key := range keySet {
|
||||||
|
if item, ok := m.items[key]; ok {
|
||||||
|
// Remove this tag from the item's tag list
|
||||||
|
newTags := make([]string, 0, len(item.Tags))
|
||||||
|
for _, t := range item.Tags {
|
||||||
|
if t != tag {
|
||||||
|
newTags = append(newTags, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If item has no more tags, delete it
|
||||||
|
// Otherwise update its tags
|
||||||
|
if len(newTags) == 0 {
|
||||||
|
delete(m.items, key)
|
||||||
|
} else {
|
||||||
|
item.Tags = newTags
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the tag mapping
|
||||||
|
delete(m.tagToKeys, tag)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
|
|||||||
86
pkg/cache/provider_redis.go
vendored
86
pkg/cache/provider_redis.go
vendored
@@ -103,9 +103,93 @@ func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl t
|
|||||||
return r.client.Set(ctx, key, value, ttl).Err()
|
return r.client.Set(ctx, key, value, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
func (r *RedisProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = r.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := r.client.Pipeline()
|
||||||
|
|
||||||
|
// Set the value
|
||||||
|
pipe.Set(ctx, key, value, ttl)
|
||||||
|
|
||||||
|
// Add key to each tag's set
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
pipe.SAdd(ctx, tagKey, key)
|
||||||
|
// Set expiration on tag set (longer than cache items to ensure cleanup)
|
||||||
|
if ttl > 0 {
|
||||||
|
pipe.Expire(ctx, tagKey, ttl+time.Hour)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store tags for this key for later cleanup
|
||||||
|
if len(tags) > 0 {
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
pipe.SAdd(ctx, tagsKey, tags)
|
||||||
|
if ttl > 0 {
|
||||||
|
pipe.Expire(ctx, tagsKey, ttl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
||||||
return r.client.Del(ctx, key).Err()
|
pipe := r.client.Pipeline()
|
||||||
|
|
||||||
|
// Get tags for this key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
tags, err := r.client.SMembers(ctx, tagsKey).Result()
|
||||||
|
if err == nil && len(tags) > 0 {
|
||||||
|
// Remove key from each tag set
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
pipe.SRem(ctx, tagKey, key)
|
||||||
|
}
|
||||||
|
// Delete the tags key
|
||||||
|
pipe.Del(ctx, tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the actual key
|
||||||
|
pipe.Del(ctx, key)
|
||||||
|
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (r *RedisProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
|
||||||
|
// Get all keys associated with this tag
|
||||||
|
keys, err := r.client.SMembers(ctx, tagKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := r.client.Pipeline()
|
||||||
|
|
||||||
|
// Delete all keys and their tag associations
|
||||||
|
for _, key := range keys {
|
||||||
|
pipe.Del(ctx, key)
|
||||||
|
// Also delete the tags key for this cache key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
pipe.Del(ctx, tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the tag set itself
|
||||||
|
pipe.Del(ctx, tagKey)
|
||||||
|
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
|
|||||||
151
pkg/cache/query_cache_test.go
vendored
151
pkg/cache/query_cache_test.go
vendored
@@ -1,151 +0,0 @@
|
|||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBuildQueryCacheKey(t *testing.T) {
|
|
||||||
filters := []common.FilterOption{
|
|
||||||
{Column: "name", Operator: "eq", Value: "test"},
|
|
||||||
{Column: "age", Operator: "gt", Value: 25},
|
|
||||||
}
|
|
||||||
sorts := []common.SortOption{
|
|
||||||
{Column: "name", Direction: "asc"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate cache key
|
|
||||||
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
|
||||||
|
|
||||||
// Same parameters should generate same key
|
|
||||||
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
|
||||||
|
|
||||||
if key1 != key2 {
|
|
||||||
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Different parameters should generate different key
|
|
||||||
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
|
|
||||||
|
|
||||||
if key1 == key3 {
|
|
||||||
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildExtendedQueryCacheKey(t *testing.T) {
|
|
||||||
filters := []common.FilterOption{
|
|
||||||
{Column: "name", Operator: "eq", Value: "test"},
|
|
||||||
}
|
|
||||||
sorts := []common.SortOption{
|
|
||||||
{Column: "name", Direction: "asc"},
|
|
||||||
}
|
|
||||||
expandOpts := []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"relation": "posts",
|
|
||||||
"where": "status = 'published'",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate cache key
|
|
||||||
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
|
||||||
|
|
||||||
// Same parameters should generate same key
|
|
||||||
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
|
||||||
|
|
||||||
if key1 != key2 {
|
|
||||||
t.Errorf("Expected same cache keys for identical parameters")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Different distinct value should generate different key
|
|
||||||
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
|
|
||||||
|
|
||||||
if key1 == key3 {
|
|
||||||
t.Errorf("Expected different cache keys for different distinct values")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetQueryTotalCacheKey(t *testing.T) {
|
|
||||||
hash := "abc123"
|
|
||||||
key := GetQueryTotalCacheKey(hash)
|
|
||||||
|
|
||||||
expected := "query_total:abc123"
|
|
||||||
if key != expected {
|
|
||||||
t.Errorf("Expected %s, got %s", expected, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCachedTotalIntegration(t *testing.T) {
|
|
||||||
// Initialize cache with memory provider for testing
|
|
||||||
UseMemory(&Options{
|
|
||||||
DefaultTTL: 1 * time.Minute,
|
|
||||||
MaxSize: 100,
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Create test data
|
|
||||||
filters := []common.FilterOption{
|
|
||||||
{Column: "status", Operator: "eq", Value: "active"},
|
|
||||||
}
|
|
||||||
sorts := []common.SortOption{
|
|
||||||
{Column: "created_at", Direction: "desc"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build cache key
|
|
||||||
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
|
|
||||||
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
|
|
||||||
|
|
||||||
// Store a total count in cache
|
|
||||||
totalToCache := CachedTotal{Total: 42}
|
|
||||||
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve from cache
|
|
||||||
var cachedTotal CachedTotal
|
|
||||||
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get from cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cachedTotal.Total != 42 {
|
|
||||||
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test cache miss
|
|
||||||
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
|
|
||||||
var missedTotal CachedTotal
|
|
||||||
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error for cache miss, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHashString(t *testing.T) {
|
|
||||||
input1 := "test string"
|
|
||||||
input2 := "test string"
|
|
||||||
input3 := "different string"
|
|
||||||
|
|
||||||
hash1 := hashString(input1)
|
|
||||||
hash2 := hashString(input2)
|
|
||||||
hash3 := hashString(input3)
|
|
||||||
|
|
||||||
// Same input should produce same hash
|
|
||||||
if hash1 != hash2 {
|
|
||||||
t.Errorf("Expected same hash for identical inputs")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Different input should produce different hash
|
|
||||||
if hash1 == hash3 {
|
|
||||||
t.Errorf("Expected different hash for different inputs")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hash should be hex encoded SHA256 (64 characters)
|
|
||||||
if len(hash1) != 64 {
|
|
||||||
t.Errorf("Expected hash length of 64, got %d", len(hash1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -196,6 +196,10 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return b.db
|
||||||
|
}
|
||||||
|
|
||||||
// BunSelectQuery implements SelectQuery for Bun
|
// BunSelectQuery implements SelectQuery for Bun
|
||||||
type BunSelectQuery struct {
|
type BunSelectQuery struct {
|
||||||
query *bun.SelectQuery
|
query *bun.SelectQuery
|
||||||
@@ -687,6 +691,11 @@ func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||||
|
b.query = b.query.OrderExpr(order, args...)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Limit(n int) common.SelectQuery {
|
func (b *BunSelectQuery) Limit(n int) common.SelectQuery {
|
||||||
b.query = b.query.Limit(n)
|
b.query = b.query.Limit(n)
|
||||||
return b
|
return b
|
||||||
@@ -1208,3 +1217,7 @@ func (b *BunTxAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||||
return fn(b) // Already in transaction
|
return fn(b) // Already in transaction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunTxAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return b.tx
|
||||||
|
}
|
||||||
|
|||||||
@@ -102,6 +102,10 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return g.db
|
||||||
|
}
|
||||||
|
|
||||||
// GormSelectQuery implements SelectQuery for GORM
|
// GormSelectQuery implements SelectQuery for GORM
|
||||||
type GormSelectQuery struct {
|
type GormSelectQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
@@ -382,6 +386,12 @@ func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *GormSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||||
|
// GORM's Order can handle expressions directly
|
||||||
|
g.db = g.db.Order(gorm.Expr(order, args...))
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Limit(n int) common.SelectQuery {
|
func (g *GormSelectQuery) Limit(n int) common.SelectQuery {
|
||||||
g.db = g.db.Limit(n)
|
g.db = g.db.Limit(n)
|
||||||
return g
|
return g
|
||||||
|
|||||||
@@ -137,6 +137,10 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
|||||||
return fn(adapter)
|
return fn(adapter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
// preloadConfig represents a relationship to be preloaded
|
// preloadConfig represents a relationship to be preloaded
|
||||||
type preloadConfig struct {
|
type preloadConfig struct {
|
||||||
relation string
|
relation string
|
||||||
@@ -277,6 +281,13 @@ func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery {
|
|||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||||
|
// For PgSQL, expressions are passed directly without quoting
|
||||||
|
// If there are args, we would need to format them, but for now just append the expression
|
||||||
|
p.orderBy = append(p.orderBy, order)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery {
|
func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery {
|
||||||
p.limit = n
|
p.limit = n
|
||||||
return p
|
return p
|
||||||
@@ -897,6 +908,10 @@ func (p *PgSQLTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Da
|
|||||||
return fn(p)
|
return fn(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return p.tx
|
||||||
|
}
|
||||||
|
|
||||||
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
|
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
|
||||||
func (p *PgSQLSelectQuery) applyJoinPreloads() {
|
func (p *PgSQLSelectQuery) applyJoinPreloads() {
|
||||||
for _, preload := range p.preloads {
|
for _, preload := range p.preloads {
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ type Database interface {
|
|||||||
CommitTx(ctx context.Context) error
|
CommitTx(ctx context.Context) error
|
||||||
RollbackTx(ctx context.Context) error
|
RollbackTx(ctx context.Context) error
|
||||||
RunInTransaction(ctx context.Context, fn func(Database) error) error
|
RunInTransaction(ctx context.Context, fn func(Database) error) error
|
||||||
|
|
||||||
|
// GetUnderlyingDB returns the underlying database connection
|
||||||
|
// For GORM, this returns *gorm.DB
|
||||||
|
// For Bun, this returns *bun.DB
|
||||||
|
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
|
||||||
|
GetUnderlyingDB() interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
||||||
@@ -40,6 +46,7 @@ type SelectQuery interface {
|
|||||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||||
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||||
Order(order string) SelectQuery
|
Order(order string) SelectQuery
|
||||||
|
OrderExpr(order string, args ...interface{}) SelectQuery
|
||||||
Limit(n int) SelectQuery
|
Limit(n int) SelectQuery
|
||||||
Offset(n int) SelectQuery
|
Offset(n int) SelectQuery
|
||||||
Group(group string) SelectQuery
|
Group(group string) SelectQuery
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -208,6 +209,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Note: We no longer add prefixes to unqualified columns here.
|
||||||
|
// Use AddTablePrefixToColumns() separately if you need to add prefixes.
|
||||||
|
|
||||||
validConditions = append(validConditions, cond)
|
validConditions = append(validConditions, cond)
|
||||||
}
|
}
|
||||||
@@ -483,6 +486,86 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
|||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractUnqualifiedColumnName extracts the column name from an unqualified condition
|
||||||
|
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
|
||||||
|
// "status = 'active'" returns "status"
|
||||||
|
func extractUnqualifiedColumnName(cond string) string {
|
||||||
|
// Common SQL operators
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||||
|
|
||||||
|
// Find the column reference (left side of the operator)
|
||||||
|
minIdx := -1
|
||||||
|
for _, op := range operators {
|
||||||
|
idx := strings.Index(cond, op)
|
||||||
|
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||||
|
minIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var columnRef string
|
||||||
|
if minIdx > 0 {
|
||||||
|
columnRef = strings.TrimSpace(cond[:minIdx])
|
||||||
|
} else {
|
||||||
|
// No operator found, might be a single column reference
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnRef = parts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if columnRef == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any quotes
|
||||||
|
columnRef = strings.Trim(columnRef, "`\"'")
|
||||||
|
|
||||||
|
// Return empty if it contains a dot (already qualified) or function call
|
||||||
|
if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return columnRef
|
||||||
|
}
|
||||||
|
|
||||||
|
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
|
||||||
|
// Uses word boundaries to avoid partial matches
|
||||||
|
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
|
||||||
|
// returns "table.rid_item is null"
|
||||||
|
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
|
||||||
|
// Use word boundary matching with Go's supported regex syntax
|
||||||
|
// \b matches word boundaries
|
||||||
|
escapedOld := regexp.QuoteMeta(oldRef)
|
||||||
|
pattern := `\b` + escapedOld + `\b`
|
||||||
|
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
// If regex fails, fall back to simple string replacement
|
||||||
|
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
|
||||||
|
return strings.Replace(cond, oldRef, newRef, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
|
||||||
|
result := cond
|
||||||
|
matches := re.FindAllStringIndex(cond, -1)
|
||||||
|
|
||||||
|
// Process matches in reverse order to maintain correct indices
|
||||||
|
for i := len(matches) - 1; i >= 0; i-- {
|
||||||
|
match := matches[i]
|
||||||
|
start := match[0]
|
||||||
|
|
||||||
|
// Check if preceded by a dot (already qualified)
|
||||||
|
if start > 0 && cond[start-1] == '.' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace this occurrence
|
||||||
|
result = result[:start] + newRef + result[match[1]:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
||||||
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
||||||
func findOperatorOutsideParentheses(s string, operator string) int {
|
func findOperatorOutsideParentheses(s string, operator string) int {
|
||||||
@@ -538,3 +621,145 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
|||||||
}
|
}
|
||||||
return validColumns[strings.ToLower(columnName)]
|
return validColumns[strings.ToLower(columnName)]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause.
|
||||||
|
// This function only prefixes simple column references and skips:
|
||||||
|
// - Columns already having a table prefix (containing a dot)
|
||||||
|
// - Columns inside function calls or expressions (inside parentheses)
|
||||||
|
// - Columns inside subqueries
|
||||||
|
// - Columns that don't exist in the table (validation via model registry)
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table)
|
||||||
|
// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function)
|
||||||
|
// - "users.status = 'active'" -> unchanged (already has prefix)
|
||||||
|
// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK)
|
||||||
|
// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table)
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - where: The WHERE clause to process
|
||||||
|
// - tableName: The table name to use as prefix
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - The WHERE clause with table prefixes added to appropriate and valid columns
|
||||||
|
func AddTablePrefixToColumns(where string, tableName string) string {
|
||||||
|
if where == "" || tableName == "" {
|
||||||
|
return where
|
||||||
|
}
|
||||||
|
|
||||||
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
|
// Get valid columns from the model registry for validation
|
||||||
|
validColumns := getValidColumnsForTable(tableName)
|
||||||
|
|
||||||
|
// Split by AND to handle multiple conditions (parenthesis-aware)
|
||||||
|
conditions := splitByAND(where)
|
||||||
|
prefixedConditions := make([]string, 0, len(conditions))
|
||||||
|
|
||||||
|
for _, cond := range conditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process this condition to add table prefix if appropriate
|
||||||
|
processedCond := addPrefixToSingleCondition(cond, tableName, validColumns)
|
||||||
|
prefixedConditions = append(prefixedConditions, processedCond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(prefixedConditions) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(prefixedConditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPrefixToSingleCondition adds table prefix to a single condition if appropriate
|
||||||
|
// Returns the condition unchanged if:
|
||||||
|
// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.)
|
||||||
|
// - The column reference is inside a function call
|
||||||
|
// - The column already has a table prefix
|
||||||
|
// - No valid column reference is found
|
||||||
|
// - The column doesn't exist in the table (when validColumns is provided)
|
||||||
|
func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string {
|
||||||
|
// Strip outer grouping parentheses to get to the actual condition
|
||||||
|
strippedCond := stripOuterParentheses(cond)
|
||||||
|
|
||||||
|
// Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.)
|
||||||
|
if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) {
|
||||||
|
logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the left side of the comparison (before the operator)
|
||||||
|
columnRef := extractLeftSideOfComparison(strippedCond)
|
||||||
|
if columnRef == "" {
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if it already has a prefix (contains a dot)
|
||||||
|
if strings.Contains(columnRef, ".") {
|
||||||
|
logger.Debug("Skipping column '%s' - already has table prefix", columnRef)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if it's a function call or expression (contains parentheses)
|
||||||
|
if strings.Contains(columnRef, "(") {
|
||||||
|
logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that the column exists in the table (if we have column info)
|
||||||
|
if !isValidColumn(columnRef, validColumns) {
|
||||||
|
logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's a simple unqualified column reference that exists in the table - add the table prefix
|
||||||
|
newRef := tableName + "." + columnRef
|
||||||
|
result := qualifyColumnInCondition(cond, columnRef, newRef)
|
||||||
|
logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition.
|
||||||
|
// This is used to identify the column reference that may need a table prefix.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - "status = 'active'" returns "status"
|
||||||
|
// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')"
|
||||||
|
// - "priority > 5" returns "priority"
|
||||||
|
//
|
||||||
|
// Returns empty string if no operator is found.
|
||||||
|
func extractLeftSideOfComparison(cond string) string {
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||||
|
|
||||||
|
// Find the first operator outside of parentheses and quotes
|
||||||
|
minIdx := -1
|
||||||
|
for _, op := range operators {
|
||||||
|
idx := findOperatorOutsideParentheses(cond, op)
|
||||||
|
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||||
|
minIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if minIdx > 0 {
|
||||||
|
leftSide := strings.TrimSpace(cond[:minIdx])
|
||||||
|
// Remove any surrounding quotes
|
||||||
|
leftSide = strings.Trim(leftSide, "`\"'")
|
||||||
|
return leftSide
|
||||||
|
}
|
||||||
|
|
||||||
|
// No operator found - might be a boolean column
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnRef := strings.Trim(parts[0], "`\"'")
|
||||||
|
// Make sure it's not a SQL keyword
|
||||||
|
if !IsSQLKeyword(strings.ToLower(columnRef)) {
|
||||||
|
return columnRef
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,16 +33,16 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid condition with parentheses - no prefix added",
|
name: "valid condition with parentheses - prefix added to prevent ambiguity",
|
||||||
where: "(status = 'active')",
|
where: "(status = 'active')",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active'",
|
expected: "users.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed trivial and valid conditions - no prefix added",
|
name: "mixed trivial and valid conditions - prefix added",
|
||||||
where: "true AND status = 'active' AND 1=1",
|
where: "true AND status = 'active' AND 1=1",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active'",
|
expected: "users.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "condition with correct table prefix - unchanged",
|
name: "condition with correct table prefix - unchanged",
|
||||||
@@ -63,10 +63,10 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
expected: "users.status = 'active' AND users.age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple valid conditions without prefix - no prefix added",
|
name: "multiple valid conditions without prefix - prefixes added",
|
||||||
where: "status = 'active' AND age > 18",
|
where: "status = 'active' AND age > 18",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active' AND age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no table name provided",
|
name: "no table name provided",
|
||||||
@@ -90,13 +90,13 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
name: "mixed case AND operators",
|
name: "mixed case AND operators",
|
||||||
where: "status = 'active' AND age > 18 and name = 'John'",
|
where: "status = 'active' AND age > 18 and name = 'John'",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active' AND age > 18 AND name = 'John'",
|
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "subquery with ORDER BY and LIMIT - allowed",
|
name: "subquery with ORDER BY and LIMIT - allowed",
|
||||||
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "dangerous DELETE keyword - blocked",
|
name: "dangerous DELETE keyword - blocked",
|
||||||
|
|||||||
@@ -237,6 +237,13 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
for _, sort := range options.Sort {
|
for _, sort := range options.Sort {
|
||||||
if v.IsValidColumn(sort.Column) {
|
if v.IsValidColumn(sort.Column) {
|
||||||
validSorts = append(validSorts, sort)
|
validSorts = append(validSorts, sort)
|
||||||
|
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
|
// Allow sort by expression/subquery, but validate for security
|
||||||
|
if IsSafeSortExpression(sort.Column) {
|
||||||
|
validSorts = append(validSorts, sort)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||||
}
|
}
|
||||||
@@ -262,6 +269,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
}
|
}
|
||||||
filteredPreload.Filters = validPreloadFilters
|
filteredPreload.Filters = validPreloadFilters
|
||||||
|
|
||||||
|
// Filter preload sort columns
|
||||||
|
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
||||||
|
for _, sort := range preload.Sort {
|
||||||
|
if v.IsValidColumn(sort.Column) {
|
||||||
|
validPreloadSorts = append(validPreloadSorts, sort)
|
||||||
|
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
|
// Allow sort by expression/subquery, but validate for security
|
||||||
|
if IsSafeSortExpression(sort.Column) {
|
||||||
|
validPreloadSorts = append(validPreloadSorts, sort)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Unsafe sort expression in preload '%s' removed: '%s'", preload.Relation, sort.Column)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in preload '%s' sort '%s' removed", preload.Relation, sort.Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filteredPreload.Sort = validPreloadSorts
|
||||||
|
|
||||||
validPreloads = append(validPreloads, filteredPreload)
|
validPreloads = append(validPreloads, filteredPreload)
|
||||||
}
|
}
|
||||||
filtered.Preload = validPreloads
|
filtered.Preload = validPreloads
|
||||||
@@ -269,6 +294,56 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSafeSortExpression validates that a sort expression (enclosed in brackets) is safe
|
||||||
|
// and doesn't contain SQL injection attempts or dangerous commands
|
||||||
|
func IsSafeSortExpression(expr string) bool {
|
||||||
|
if expr == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expression must be enclosed in brackets
|
||||||
|
expr = strings.TrimSpace(expr)
|
||||||
|
if !strings.HasPrefix(expr, "(") || !strings.HasSuffix(expr, ")") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove outer brackets for content validation
|
||||||
|
expr = expr[1 : len(expr)-1]
|
||||||
|
expr = strings.TrimSpace(expr)
|
||||||
|
|
||||||
|
// Convert to lowercase for checking dangerous keywords
|
||||||
|
exprLower := strings.ToLower(expr)
|
||||||
|
|
||||||
|
// Check for dangerous SQL commands that should never be in a sort expression
|
||||||
|
dangerousKeywords := []string{
|
||||||
|
"drop ", "delete ", "insert ", "update ", "alter ", "create ",
|
||||||
|
"truncate ", "exec ", "execute ", "grant ", "revoke ",
|
||||||
|
"into ", "values ", "set ", "shutdown", "xp_",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, keyword := range dangerousKeywords {
|
||||||
|
if strings.Contains(exprLower, keyword) {
|
||||||
|
logger.Warn("Dangerous SQL keyword '%s' detected in sort expression: %s", keyword, expr)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for SQL comment attempts
|
||||||
|
if strings.Contains(expr, "--") || strings.Contains(expr, "/*") || strings.Contains(expr, "*/") {
|
||||||
|
logger.Warn("SQL comment detected in sort expression: %s", expr)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for semicolon (command separator)
|
||||||
|
if strings.Contains(expr, ";") {
|
||||||
|
logger.Warn("Command separator (;) detected in sort expression: %s", expr)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expression appears safe
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// GetValidColumns returns a list of all valid column names for debugging purposes
|
// GetValidColumns returns a list of all valid column names for debugging purposes
|
||||||
func (v *ColumnValidator) GetValidColumns() []string {
|
func (v *ColumnValidator) GetValidColumns() []string {
|
||||||
columns := make([]string, 0, len(v.validColumns))
|
columns := make([]string, 0, len(v.validColumns))
|
||||||
|
|||||||
@@ -361,3 +361,83 @@ func TestFilterRequestOptions(t *testing.T) {
|
|||||||
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsSafeSortExpression(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expression string
|
||||||
|
shouldPass bool
|
||||||
|
}{
|
||||||
|
// Safe expressions
|
||||||
|
{"Valid subquery", "(SELECT MAX(price) FROM products)", true},
|
||||||
|
{"Valid CASE expression", "(CASE WHEN status = 'active' THEN 1 ELSE 0 END)", true},
|
||||||
|
{"Valid aggregate", "(COUNT(*) OVER (PARTITION BY category))", true},
|
||||||
|
{"Valid function", "(COALESCE(discount, 0))", true},
|
||||||
|
|
||||||
|
// Dangerous expressions - SQL injection attempts
|
||||||
|
{"DROP TABLE attempt", "(id); DROP TABLE users; --", false},
|
||||||
|
{"DELETE attempt", "(id WHERE 1=1); DELETE FROM users; --", false},
|
||||||
|
{"INSERT attempt", "(id); INSERT INTO admin VALUES ('hacker'); --", false},
|
||||||
|
{"UPDATE attempt", "(id); UPDATE users SET role='admin'; --", false},
|
||||||
|
{"EXEC attempt", "(id); EXEC sp_executesql 'DROP TABLE users'; --", false},
|
||||||
|
{"XP_ stored proc", "(id); xp_cmdshell 'dir'; --", false},
|
||||||
|
|
||||||
|
// Comment injection
|
||||||
|
{"SQL comment dash", "(id) -- malicious comment", false},
|
||||||
|
{"SQL comment block start", "(id) /* comment", false},
|
||||||
|
{"SQL comment block end", "(id) comment */", false},
|
||||||
|
|
||||||
|
// Semicolon attempts
|
||||||
|
{"Semicolon separator", "(id); SELECT * FROM passwords", false},
|
||||||
|
|
||||||
|
// Empty/invalid
|
||||||
|
{"Empty string", "", false},
|
||||||
|
{"Just brackets", "()", true}, // Empty but technically valid structure
|
||||||
|
{"No brackets", "id", false}, // Must have brackets for expressions
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := IsSafeSortExpression(tt.expression)
|
||||||
|
if result != tt.shouldPass {
|
||||||
|
t.Errorf("IsSafeSortExpression(%q) = %v, want %v", tt.expression, result, tt.shouldPass)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Sort: []SortOption{
|
||||||
|
{Column: "id", Direction: "ASC"}, // Valid column
|
||||||
|
{Column: "(SELECT MAX(age) FROM users)", Direction: "DESC"}, // Safe expression
|
||||||
|
{Column: "name", Direction: "ASC"}, // Valid column
|
||||||
|
{Column: "(id); DROP TABLE users; --", Direction: "DESC"}, // Dangerous expression
|
||||||
|
{Column: "invalid_col", Direction: "ASC"}, // Invalid column
|
||||||
|
{Column: "(CASE WHEN age > 18 THEN 1 ELSE 0 END)", Direction: "ASC"}, // Safe expression
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
// Should keep: id, safe expression, name, another safe expression
|
||||||
|
// Should remove: dangerous expression, invalid column
|
||||||
|
expectedCount := 4
|
||||||
|
if len(filtered.Sort) != expectedCount {
|
||||||
|
t.Errorf("Expected %d sort options, got %d", expectedCount, len(filtered.Sort))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the kept options
|
||||||
|
if filtered.Sort[0].Column != "id" {
|
||||||
|
t.Errorf("Expected first sort to be 'id', got '%s'", filtered.Sort[0].Column)
|
||||||
|
}
|
||||||
|
if filtered.Sort[1].Column != "(SELECT MAX(age) FROM users)" {
|
||||||
|
t.Errorf("Expected second sort to be safe expression, got '%s'", filtered.Sort[1].Column)
|
||||||
|
}
|
||||||
|
if filtered.Sort[2].Column != "name" {
|
||||||
|
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type Config struct {
|
|||||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
|
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig holds server-related configuration
|
// ServerConfig holds server-related configuration
|
||||||
@@ -91,3 +92,52 @@ type ErrorTrackingConfig struct {
|
|||||||
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
|
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
|
||||||
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
|
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EventBrokerConfig contains configuration for the event broker
|
||||||
|
type EventBrokerConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Provider string `mapstructure:"provider"` // memory, redis, nats, database
|
||||||
|
Mode string `mapstructure:"mode"` // sync, async
|
||||||
|
WorkerCount int `mapstructure:"worker_count"`
|
||||||
|
BufferSize int `mapstructure:"buffer_size"`
|
||||||
|
InstanceID string `mapstructure:"instance_id"`
|
||||||
|
Redis EventBrokerRedisConfig `mapstructure:"redis"`
|
||||||
|
NATS EventBrokerNATSConfig `mapstructure:"nats"`
|
||||||
|
Database EventBrokerDatabaseConfig `mapstructure:"database"`
|
||||||
|
RetryPolicy EventBrokerRetryPolicyConfig `mapstructure:"retry_policy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerRedisConfig contains Redis-specific configuration
|
||||||
|
type EventBrokerRedisConfig struct {
|
||||||
|
StreamName string `mapstructure:"stream_name"`
|
||||||
|
ConsumerGroup string `mapstructure:"consumer_group"`
|
||||||
|
MaxLen int64 `mapstructure:"max_len"`
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
Password string `mapstructure:"password"`
|
||||||
|
DB int `mapstructure:"db"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerNATSConfig contains NATS-specific configuration
|
||||||
|
type EventBrokerNATSConfig struct {
|
||||||
|
URL string `mapstructure:"url"`
|
||||||
|
StreamName string `mapstructure:"stream_name"`
|
||||||
|
Subjects []string `mapstructure:"subjects"`
|
||||||
|
Storage string `mapstructure:"storage"` // file, memory
|
||||||
|
MaxAge time.Duration `mapstructure:"max_age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerDatabaseConfig contains database provider configuration
|
||||||
|
type EventBrokerDatabaseConfig struct {
|
||||||
|
TableName string `mapstructure:"table_name"`
|
||||||
|
Channel string `mapstructure:"channel"` // PostgreSQL NOTIFY channel name
|
||||||
|
PollInterval time.Duration `mapstructure:"poll_interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerRetryPolicyConfig contains retry policy configuration
|
||||||
|
type EventBrokerRetryPolicyConfig struct {
|
||||||
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
|
InitialDelay time.Duration `mapstructure:"initial_delay"`
|
||||||
|
MaxDelay time.Duration `mapstructure:"max_delay"`
|
||||||
|
BackoffFactor float64 `mapstructure:"backoff_factor"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -165,4 +165,39 @@ func setDefaults(v *viper.Viper) {
|
|||||||
|
|
||||||
// Database defaults
|
// Database defaults
|
||||||
v.SetDefault("database.url", "")
|
v.SetDefault("database.url", "")
|
||||||
|
|
||||||
|
// Event Broker defaults
|
||||||
|
v.SetDefault("event_broker.enabled", false)
|
||||||
|
v.SetDefault("event_broker.provider", "memory")
|
||||||
|
v.SetDefault("event_broker.mode", "async")
|
||||||
|
v.SetDefault("event_broker.worker_count", 10)
|
||||||
|
v.SetDefault("event_broker.buffer_size", 1000)
|
||||||
|
v.SetDefault("event_broker.instance_id", "")
|
||||||
|
|
||||||
|
// Event Broker - Redis defaults
|
||||||
|
v.SetDefault("event_broker.redis.stream_name", "resolvespec:events")
|
||||||
|
v.SetDefault("event_broker.redis.consumer_group", "resolvespec-workers")
|
||||||
|
v.SetDefault("event_broker.redis.max_len", 10000)
|
||||||
|
v.SetDefault("event_broker.redis.host", "localhost")
|
||||||
|
v.SetDefault("event_broker.redis.port", 6379)
|
||||||
|
v.SetDefault("event_broker.redis.password", "")
|
||||||
|
v.SetDefault("event_broker.redis.db", 0)
|
||||||
|
|
||||||
|
// Event Broker - NATS defaults
|
||||||
|
v.SetDefault("event_broker.nats.url", "nats://localhost:4222")
|
||||||
|
v.SetDefault("event_broker.nats.stream_name", "RESOLVESPEC_EVENTS")
|
||||||
|
v.SetDefault("event_broker.nats.subjects", []string{"events.>"})
|
||||||
|
v.SetDefault("event_broker.nats.storage", "file")
|
||||||
|
v.SetDefault("event_broker.nats.max_age", "24h")
|
||||||
|
|
||||||
|
// Event Broker - Database defaults
|
||||||
|
v.SetDefault("event_broker.database.table_name", "events")
|
||||||
|
v.SetDefault("event_broker.database.channel", "resolvespec_events")
|
||||||
|
v.SetDefault("event_broker.database.poll_interval", "1s")
|
||||||
|
|
||||||
|
// Event Broker - Retry Policy defaults
|
||||||
|
v.SetDefault("event_broker.retry_policy.max_retries", 3)
|
||||||
|
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
|
||||||
|
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
|
||||||
|
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
|
||||||
}
|
}
|
||||||
|
|||||||
327
pkg/eventbroker/README.md
Normal file
327
pkg/eventbroker/README.md
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
# Event Broker System
|
||||||
|
|
||||||
|
A comprehensive event handler/broker system for ResolveSpec that provides real-time event publishing, subscription, and cross-instance communication.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Multiple Sources**: Events from database, websockets, frontend, system, and internal sources
|
||||||
|
- **Event Status Tracking**: Pending, processing, completed, failed states with timestamps
|
||||||
|
- **Rich Metadata**: User IDs, session IDs, instance IDs, JSON payloads, and custom metadata
|
||||||
|
- **Sync & Async Modes**: Choose between synchronous or asynchronous event processing
|
||||||
|
- **Pattern Matching**: Subscribe to events using glob-style patterns
|
||||||
|
- **Multiple Providers**: In-memory, Redis Streams, NATS JetStream, PostgreSQL with NOTIFY
|
||||||
|
- **Hook Integration**: Automatic CRUD event capture via restheadspec hooks
|
||||||
|
- **Retry Logic**: Configurable retry policy with exponential backoff
|
||||||
|
- **Metrics**: Prometheus-compatible metrics for monitoring
|
||||||
|
- **Graceful Shutdown**: Proper cleanup and event flushing on shutdown
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Configuration
|
||||||
|
|
||||||
|
Add to your `config.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
enabled: true
|
||||||
|
provider: memory # memory, redis, nats, database
|
||||||
|
mode: async # sync, async
|
||||||
|
worker_count: 10
|
||||||
|
buffer_size: 1000
|
||||||
|
instance_id: "${HOSTNAME}"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Initialize
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration
|
||||||
|
cfgMgr := config.NewManager()
|
||||||
|
cfg, _ := cfgMgr.GetConfig()
|
||||||
|
|
||||||
|
// Initialize event broker
|
||||||
|
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Subscribe to Events
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Subscribe to specific events
|
||||||
|
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *eventbroker.Event) error {
|
||||||
|
log.Printf("New user created: %s", event.Payload)
|
||||||
|
// Send welcome email, update cache, etc.
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Subscribe with patterns
|
||||||
|
eventbroker.Subscribe("*.*.delete", eventbroker.EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *eventbroker.Event) error {
|
||||||
|
log.Printf("Deleted: %s.%s", event.Schema, event.Entity)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Publish Events
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Create and publish an event
|
||||||
|
event := eventbroker.NewEvent(eventbroker.EventSourceDatabase, "public.users.update")
|
||||||
|
event.InstanceID = eventbroker.GetDefaultBroker().InstanceID()
|
||||||
|
event.UserID = 123
|
||||||
|
event.SessionID = "session-456"
|
||||||
|
event.Schema = "public"
|
||||||
|
event.Entity = "users"
|
||||||
|
event.Operation = "update"
|
||||||
|
|
||||||
|
event.SetPayload(map[string]interface{}{
|
||||||
|
"id": 123,
|
||||||
|
"name": "John Doe",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Async (non-blocking)
|
||||||
|
eventbroker.PublishAsync(ctx, event)
|
||||||
|
|
||||||
|
// Sync (blocking)
|
||||||
|
eventbroker.PublishSync(ctx, event)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Automatic CRUD Event Capture
|
||||||
|
|
||||||
|
Automatically capture database CRUD operations:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupHooks(handler *restheadspec.Handler) {
|
||||||
|
broker := eventbroker.GetDefaultBroker()
|
||||||
|
|
||||||
|
// Configure which operations to capture
|
||||||
|
config := eventbroker.DefaultCRUDHookConfig()
|
||||||
|
config.EnableRead = false // Disable read events for performance
|
||||||
|
|
||||||
|
// Register hooks
|
||||||
|
eventbroker.RegisterCRUDHooks(broker, handler.Hooks(), config)
|
||||||
|
|
||||||
|
// Now all create/update/delete operations automatically publish events!
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Event Structure
|
||||||
|
|
||||||
|
Every event contains:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Event struct {
|
||||||
|
ID string // UUID
|
||||||
|
Source EventSource // database, websocket, system, frontend, internal
|
||||||
|
Type string // Pattern: schema.entity.operation
|
||||||
|
Status EventStatus // pending, processing, completed, failed
|
||||||
|
Payload json.RawMessage // JSON payload
|
||||||
|
UserID int // User who triggered the event
|
||||||
|
SessionID string // Session identifier
|
||||||
|
InstanceID string // Server instance identifier
|
||||||
|
Schema string // Database schema
|
||||||
|
Entity string // Database entity/table
|
||||||
|
Operation string // create, update, delete, read
|
||||||
|
CreatedAt time.Time // When event was created
|
||||||
|
ProcessedAt *time.Time // When processing started
|
||||||
|
CompletedAt *time.Time // When processing completed
|
||||||
|
Error string // Error message if failed
|
||||||
|
Metadata map[string]interface{} // Additional context
|
||||||
|
RetryCount int // Number of retry attempts
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern Matching
|
||||||
|
|
||||||
|
Subscribe to events using glob-style patterns:
|
||||||
|
|
||||||
|
| Pattern | Matches | Example |
|
||||||
|
|---------|---------|---------|
|
||||||
|
| `*` | All events | Any event |
|
||||||
|
| `public.users.*` | All user operations | `public.users.create`, `public.users.update` |
|
||||||
|
| `*.*.create` | All create operations | `public.users.create`, `auth.sessions.create` |
|
||||||
|
| `public.*.*` | All events in public schema | `public.users.create`, `public.posts.delete` |
|
||||||
|
| `public.users.create` | Exact match | Only `public.users.create` |
|
||||||
|
|
||||||
|
## Providers
|
||||||
|
|
||||||
|
### Memory Provider (Default)
|
||||||
|
|
||||||
|
Best for: Development, single-instance deployments
|
||||||
|
|
||||||
|
- **Pros**: Fast, no dependencies, simple
|
||||||
|
- **Cons**: Events lost on restart, single-instance only
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: memory
|
||||||
|
```
|
||||||
|
|
||||||
|
### Redis Provider (Future)
|
||||||
|
|
||||||
|
Best for: Production, multi-instance deployments
|
||||||
|
|
||||||
|
- **Pros**: Persistent, cross-instance pub/sub, reliable
|
||||||
|
- **Cons**: Requires Redis
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: redis
|
||||||
|
redis:
|
||||||
|
stream_name: "resolvespec:events"
|
||||||
|
consumer_group: "resolvespec-workers"
|
||||||
|
host: "localhost"
|
||||||
|
port: 6379
|
||||||
|
```
|
||||||
|
|
||||||
|
### NATS Provider (Future)
|
||||||
|
|
||||||
|
Best for: High-performance, low-latency requirements
|
||||||
|
|
||||||
|
- **Pros**: Very fast, built-in clustering, durable
|
||||||
|
- **Cons**: Requires NATS server
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: nats
|
||||||
|
nats:
|
||||||
|
url: "nats://localhost:4222"
|
||||||
|
stream_name: "RESOLVESPEC_EVENTS"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Provider (Future)
|
||||||
|
|
||||||
|
Best for: Audit trails, event replay, SQL queries
|
||||||
|
|
||||||
|
- **Pros**: No additional infrastructure, full SQL query support, PostgreSQL NOTIFY for real-time
|
||||||
|
- **Cons**: Slower than Redis/NATS
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: database
|
||||||
|
database:
|
||||||
|
table_name: "events"
|
||||||
|
channel: "resolvespec_events"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Processing Modes
|
||||||
|
|
||||||
|
### Async Mode (Recommended)
|
||||||
|
|
||||||
|
Events are queued and processed by worker pool:
|
||||||
|
|
||||||
|
- Non-blocking event publishing
|
||||||
|
- Configurable worker count
|
||||||
|
- Better throughput
|
||||||
|
- Events may be processed out of order
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
mode: async
|
||||||
|
worker_count: 10
|
||||||
|
buffer_size: 1000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sync Mode
|
||||||
|
|
||||||
|
Events are processed immediately:
|
||||||
|
|
||||||
|
- Blocking event publishing
|
||||||
|
- Guaranteed ordering
|
||||||
|
- Immediate error feedback
|
||||||
|
- Lower throughput
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
mode: sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## Retry Policy
|
||||||
|
|
||||||
|
Configure automatic retries for failed handlers:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
retry_policy:
|
||||||
|
max_retries: 3
|
||||||
|
initial_delay: 1s
|
||||||
|
max_delay: 30s
|
||||||
|
backoff_factor: 2.0 # Exponential backoff
|
||||||
|
```
|
||||||
|
|
||||||
|
## Metrics
|
||||||
|
|
||||||
|
The event broker exposes Prometheus metrics:
|
||||||
|
|
||||||
|
- `eventbroker_events_published_total{source, type}` - Total events published
|
||||||
|
- `eventbroker_events_processed_total{source, type, status}` - Total events processed
|
||||||
|
- `eventbroker_event_processing_duration_seconds{source, type}` - Event processing duration
|
||||||
|
- `eventbroker_queue_size` - Current queue size (async mode)
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use Async Mode**: For better performance, use async mode in production
|
||||||
|
2. **Disable Read Events**: Read events can be high volume; disable if not needed
|
||||||
|
3. **Pattern Matching**: Use specific patterns to avoid processing unnecessary events
|
||||||
|
4. **Error Handling**: Always handle errors in event handlers; they won't fail the original operation
|
||||||
|
5. **Idempotency**: Make handlers idempotent as events may be retried
|
||||||
|
6. **Payload Size**: Keep payloads reasonable; avoid large objects
|
||||||
|
7. **Monitoring**: Monitor metrics to detect issues early
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
See `example_usage.go` for comprehensive examples including:
|
||||||
|
- Basic event publishing and subscription
|
||||||
|
- Hook integration
|
||||||
|
- Error handling
|
||||||
|
- Configuration
|
||||||
|
- Pattern matching
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────┐
|
||||||
|
│ Application │
|
||||||
|
└────────┬────────┘
|
||||||
|
│
|
||||||
|
├─ Publish Events
|
||||||
|
│
|
||||||
|
┌────────▼────────┐ ┌──────────────┐
|
||||||
|
│ Event Broker │◄────►│ Subscribers │
|
||||||
|
└────────┬────────┘ └──────────────┘
|
||||||
|
│
|
||||||
|
├─ Store Events
|
||||||
|
│
|
||||||
|
┌────────▼────────┐
|
||||||
|
│ Provider │
|
||||||
|
│ (Memory/Redis │
|
||||||
|
│ /NATS/DB) │
|
||||||
|
└─────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
- [ ] Database Provider with PostgreSQL NOTIFY
|
||||||
|
- [ ] Redis Streams Provider
|
||||||
|
- [ ] NATS JetStream Provider
|
||||||
|
- [ ] Event replay functionality
|
||||||
|
- [ ] Dead letter queue
|
||||||
|
- [ ] Event filtering at provider level
|
||||||
|
- [ ] Batch publishing
|
||||||
|
- [ ] Event compression
|
||||||
|
- [ ] Schema versioning
|
||||||
453
pkg/eventbroker/broker.go
Normal file
453
pkg/eventbroker/broker.go
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Broker is the main interface for event publishing and subscription
|
||||||
|
type Broker interface {
|
||||||
|
// Publish publishes an event (mode-dependent: sync or async)
|
||||||
|
Publish(ctx context.Context, event *Event) error
|
||||||
|
|
||||||
|
// PublishSync publishes an event synchronously (blocks until all handlers complete)
|
||||||
|
PublishSync(ctx context.Context, event *Event) error
|
||||||
|
|
||||||
|
// PublishAsync publishes an event asynchronously (returns immediately)
|
||||||
|
PublishAsync(ctx context.Context, event *Event) error
|
||||||
|
|
||||||
|
// Subscribe registers a handler for events matching the pattern
|
||||||
|
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
|
||||||
|
|
||||||
|
// Unsubscribe removes a subscription
|
||||||
|
Unsubscribe(id SubscriptionID) error
|
||||||
|
|
||||||
|
// Start starts the broker (begins processing events)
|
||||||
|
Start(ctx context.Context) error
|
||||||
|
|
||||||
|
// Stop stops the broker gracefully (flushes pending events)
|
||||||
|
Stop(ctx context.Context) error
|
||||||
|
|
||||||
|
// Stats returns broker statistics
|
||||||
|
Stats(ctx context.Context) (*BrokerStats, error)
|
||||||
|
|
||||||
|
// InstanceID returns the instance ID of this broker
|
||||||
|
InstanceID() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessingMode determines how events are processed
|
||||||
|
type ProcessingMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProcessingModeSync ProcessingMode = "sync"
|
||||||
|
ProcessingModeAsync ProcessingMode = "async"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BrokerStats contains broker statistics
|
||||||
|
type BrokerStats struct {
|
||||||
|
InstanceID string `json:"instance_id"`
|
||||||
|
Mode ProcessingMode `json:"mode"`
|
||||||
|
IsRunning bool `json:"is_running"`
|
||||||
|
TotalPublished int64 `json:"total_published"`
|
||||||
|
TotalProcessed int64 `json:"total_processed"`
|
||||||
|
TotalFailed int64 `json:"total_failed"`
|
||||||
|
ActiveSubscribers int `json:"active_subscribers"`
|
||||||
|
QueueSize int `json:"queue_size,omitempty"` // For async mode
|
||||||
|
ActiveWorkers int `json:"active_workers,omitempty"` // For async mode
|
||||||
|
ProviderStats *ProviderStats `json:"provider_stats,omitempty"`
|
||||||
|
AdditionalStats map[string]interface{} `json:"additional_stats,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBroker implements the Broker interface
|
||||||
|
type EventBroker struct {
|
||||||
|
provider Provider
|
||||||
|
subscriptions *subscriptionManager
|
||||||
|
mode ProcessingMode
|
||||||
|
instanceID string
|
||||||
|
retryPolicy *RetryPolicy
|
||||||
|
|
||||||
|
// Async mode fields (initialized in Phase 4)
|
||||||
|
workerPool *workerPool
|
||||||
|
|
||||||
|
// Runtime state
|
||||||
|
isRunning atomic.Bool
|
||||||
|
stopOnce sync.Once
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Statistics
|
||||||
|
statsPublished atomic.Int64
|
||||||
|
statsProcessed atomic.Int64
|
||||||
|
statsFailed atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryPolicy defines how failed events should be retried
|
||||||
|
type RetryPolicy struct {
|
||||||
|
MaxRetries int
|
||||||
|
InitialDelay time.Duration
|
||||||
|
MaxDelay time.Duration
|
||||||
|
BackoffFactor float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRetryPolicy returns a sensible default retry policy
|
||||||
|
func DefaultRetryPolicy() *RetryPolicy {
|
||||||
|
return &RetryPolicy{
|
||||||
|
MaxRetries: 3,
|
||||||
|
InitialDelay: 1 * time.Second,
|
||||||
|
MaxDelay: 30 * time.Second,
|
||||||
|
BackoffFactor: 2.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options for creating a new broker
|
||||||
|
type Options struct {
|
||||||
|
Provider Provider
|
||||||
|
Mode ProcessingMode
|
||||||
|
WorkerCount int // For async mode
|
||||||
|
BufferSize int // For async mode
|
||||||
|
RetryPolicy *RetryPolicy
|
||||||
|
InstanceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBroker creates a new event broker with the given options
|
||||||
|
func NewBroker(opts Options) (*EventBroker, error) {
|
||||||
|
if opts.Provider == nil {
|
||||||
|
return nil, fmt.Errorf("provider is required")
|
||||||
|
}
|
||||||
|
if opts.InstanceID == "" {
|
||||||
|
return nil, fmt.Errorf("instance ID is required")
|
||||||
|
}
|
||||||
|
if opts.Mode == "" {
|
||||||
|
opts.Mode = ProcessingModeAsync // Default to async
|
||||||
|
}
|
||||||
|
if opts.RetryPolicy == nil {
|
||||||
|
opts.RetryPolicy = DefaultRetryPolicy()
|
||||||
|
}
|
||||||
|
|
||||||
|
broker := &EventBroker{
|
||||||
|
provider: opts.Provider,
|
||||||
|
subscriptions: newSubscriptionManager(),
|
||||||
|
mode: opts.Mode,
|
||||||
|
instanceID: opts.InstanceID,
|
||||||
|
retryPolicy: opts.RetryPolicy,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Worker pool will be initialized in Phase 4 for async mode
|
||||||
|
if opts.Mode == ProcessingModeAsync {
|
||||||
|
if opts.WorkerCount == 0 {
|
||||||
|
opts.WorkerCount = 10 // Default
|
||||||
|
}
|
||||||
|
if opts.BufferSize == 0 {
|
||||||
|
opts.BufferSize = 1000 // Default
|
||||||
|
}
|
||||||
|
broker.workerPool = newWorkerPool(opts.WorkerCount, opts.BufferSize, broker.processEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
return broker, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Functional option pattern helpers
|
||||||
|
func WithProvider(p Provider) func(*Options) {
|
||||||
|
return func(o *Options) { o.Provider = p }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithMode(m ProcessingMode) func(*Options) {
|
||||||
|
return func(o *Options) { o.Mode = m }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithWorkerCount(count int) func(*Options) {
|
||||||
|
return func(o *Options) { o.WorkerCount = count }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBufferSize(size int) func(*Options) {
|
||||||
|
return func(o *Options) { o.BufferSize = size }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryPolicy(policy *RetryPolicy) func(*Options) {
|
||||||
|
return func(o *Options) { o.RetryPolicy = policy }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithInstanceID(id string) func(*Options) {
|
||||||
|
return func(o *Options) { o.InstanceID = id }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the broker
|
||||||
|
func (b *EventBroker) Start(ctx context.Context) error {
|
||||||
|
if b.isRunning.Load() {
|
||||||
|
return fmt.Errorf("broker already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.isRunning.Store(true)
|
||||||
|
|
||||||
|
// Start worker pool for async mode
|
||||||
|
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||||
|
b.workerPool.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Event broker started (mode: %s, instance: %s)", b.mode, b.instanceID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the broker gracefully
|
||||||
|
func (b *EventBroker) Stop(ctx context.Context) error {
|
||||||
|
var stopErr error
|
||||||
|
|
||||||
|
b.stopOnce.Do(func() {
|
||||||
|
logger.Info("Stopping event broker...")
|
||||||
|
|
||||||
|
// Mark as not running
|
||||||
|
b.isRunning.Store(false)
|
||||||
|
|
||||||
|
// Close the stop channel
|
||||||
|
close(b.stopCh)
|
||||||
|
|
||||||
|
// Stop worker pool for async mode
|
||||||
|
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||||
|
if err := b.workerPool.Stop(ctx); err != nil {
|
||||||
|
logger.Error("Error stopping worker pool: %v", err)
|
||||||
|
stopErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
b.wg.Wait()
|
||||||
|
|
||||||
|
// Close provider
|
||||||
|
if err := b.provider.Close(); err != nil {
|
||||||
|
logger.Error("Error closing provider: %v", err)
|
||||||
|
if stopErr == nil {
|
||||||
|
stopErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Event broker stopped")
|
||||||
|
})
|
||||||
|
|
||||||
|
return stopErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish publishes an event based on the broker's mode
|
||||||
|
func (b *EventBroker) Publish(ctx context.Context, event *Event) error {
|
||||||
|
if b.mode == ProcessingModeSync {
|
||||||
|
return b.PublishSync(ctx, event)
|
||||||
|
}
|
||||||
|
return b.PublishAsync(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishSync publishes an event synchronously
|
||||||
|
func (b *EventBroker) PublishSync(ctx context.Context, event *Event) error {
|
||||||
|
if !b.isRunning.Load() {
|
||||||
|
return fmt.Errorf("broker is not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate event
|
||||||
|
if err := event.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("invalid event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store event in provider
|
||||||
|
if err := b.provider.Publish(ctx, event); err != nil {
|
||||||
|
return fmt.Errorf("failed to publish event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.statsPublished.Add(1)
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
recordEventPublished(event)
|
||||||
|
|
||||||
|
// Process event synchronously
|
||||||
|
if err := b.processEvent(ctx, event); err != nil {
|
||||||
|
logger.Error("Failed to process event %s: %v", event.ID, err)
|
||||||
|
b.statsFailed.Add(1)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.statsProcessed.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishAsync publishes an event asynchronously
|
||||||
|
func (b *EventBroker) PublishAsync(ctx context.Context, event *Event) error {
|
||||||
|
if !b.isRunning.Load() {
|
||||||
|
return fmt.Errorf("broker is not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate event
|
||||||
|
if err := event.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("invalid event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store event in provider
|
||||||
|
if err := b.provider.Publish(ctx, event); err != nil {
|
||||||
|
return fmt.Errorf("failed to publish event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.statsPublished.Add(1)
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
recordEventPublished(event)
|
||||||
|
|
||||||
|
// Queue for async processing
|
||||||
|
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||||
|
// Update queue size metrics
|
||||||
|
updateQueueSize(int64(b.workerPool.QueueSize()))
|
||||||
|
return b.workerPool.Submit(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to sync if async not configured
|
||||||
|
return b.processEvent(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe adds a subscription for events matching the pattern
|
||||||
|
func (b *EventBroker) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||||
|
return b.subscriptions.Subscribe(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes a subscription
|
||||||
|
func (b *EventBroker) Unsubscribe(id SubscriptionID) error {
|
||||||
|
return b.subscriptions.Unsubscribe(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processEvent processes an event by calling all matching handlers
|
||||||
|
func (b *EventBroker) processEvent(ctx context.Context, event *Event) error {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Get all handlers matching this event type
|
||||||
|
handlers := b.subscriptions.GetMatching(event.Type)
|
||||||
|
|
||||||
|
if len(handlers) == 0 {
|
||||||
|
logger.Debug("No handlers for event type: %s", event.Type)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Processing event %s with %d handler(s)", event.ID, len(handlers))
|
||||||
|
|
||||||
|
// Mark event as processing
|
||||||
|
event.MarkProcessing()
|
||||||
|
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusProcessing, ""); err != nil {
|
||||||
|
logger.Warn("Failed to update event status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute all handlers
|
||||||
|
var lastErr error
|
||||||
|
for i, handler := range handlers {
|
||||||
|
if err := b.executeHandlerWithRetry(ctx, handler, event); err != nil {
|
||||||
|
logger.Error("Handler %d failed for event %s: %v", i+1, event.ID, err)
|
||||||
|
lastErr = err
|
||||||
|
// Continue processing other handlers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update final status
|
||||||
|
if lastErr != nil {
|
||||||
|
event.MarkFailed(lastErr)
|
||||||
|
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusFailed, lastErr.Error()); err != nil {
|
||||||
|
logger.Warn("Failed to update event status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
recordEventProcessed(event, time.Since(startTime))
|
||||||
|
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
event.MarkCompleted()
|
||||||
|
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusCompleted, ""); err != nil {
|
||||||
|
logger.Warn("Failed to update event status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
recordEventProcessed(event, time.Since(startTime))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeHandlerWithRetry executes a handler with retry logic
|
||||||
|
func (b *EventBroker) executeHandlerWithRetry(ctx context.Context, handler EventHandler, event *Event) error {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 0; attempt <= b.retryPolicy.MaxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
// Calculate backoff delay
|
||||||
|
delay := b.calculateBackoff(attempt)
|
||||||
|
logger.Debug("Retrying event %s (attempt %d/%d) after %v",
|
||||||
|
event.ID, attempt, b.retryPolicy.MaxRetries, delay)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(delay):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
event.IncrementRetry()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute handler
|
||||||
|
if err := handler.Handle(ctx, event); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
logger.Warn("Handler failed for event %s (attempt %d): %v", event.ID, attempt+1, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("handler failed after %d attempts: %w", b.retryPolicy.MaxRetries+1, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateBackoff calculates the backoff delay for a retry attempt
|
||||||
|
func (b *EventBroker) calculateBackoff(attempt int) time.Duration {
|
||||||
|
delay := float64(b.retryPolicy.InitialDelay) * pow(b.retryPolicy.BackoffFactor, float64(attempt-1))
|
||||||
|
if delay > float64(b.retryPolicy.MaxDelay) {
|
||||||
|
delay = float64(b.retryPolicy.MaxDelay)
|
||||||
|
}
|
||||||
|
return time.Duration(delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pow is a simple integer power function
|
||||||
|
func pow(base float64, exp float64) float64 {
|
||||||
|
result := 1.0
|
||||||
|
for i := 0.0; i < exp; i++ {
|
||||||
|
result *= base
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns broker statistics
|
||||||
|
func (b *EventBroker) Stats(ctx context.Context) (*BrokerStats, error) {
|
||||||
|
providerStats, err := b.provider.Stats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to get provider stats: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := &BrokerStats{
|
||||||
|
InstanceID: b.instanceID,
|
||||||
|
Mode: b.mode,
|
||||||
|
IsRunning: b.isRunning.Load(),
|
||||||
|
TotalPublished: b.statsPublished.Load(),
|
||||||
|
TotalProcessed: b.statsProcessed.Load(),
|
||||||
|
TotalFailed: b.statsFailed.Load(),
|
||||||
|
ActiveSubscribers: b.subscriptions.Count(),
|
||||||
|
ProviderStats: providerStats,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add async-specific stats
|
||||||
|
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||||
|
stats.QueueSize = b.workerPool.QueueSize()
|
||||||
|
stats.ActiveWorkers = b.workerPool.ActiveWorkers()
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstanceID returns the instance ID
|
||||||
|
func (b *EventBroker) InstanceID() string {
|
||||||
|
return b.instanceID
|
||||||
|
}
|
||||||
524
pkg/eventbroker/broker_test.go
Normal file
524
pkg/eventbroker/broker_test.go
Normal file
@@ -0,0 +1,524 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewBroker(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
MaxEvents: 1000,
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
opts Options
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid options",
|
||||||
|
opts: Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing provider",
|
||||||
|
opts: Options{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing instance ID",
|
||||||
|
opts: Options{
|
||||||
|
Provider: provider,
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "async mode with defaults",
|
||||||
|
opts: Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
broker, err := NewBroker(tt.opts)
|
||||||
|
if (err != nil) != tt.wantError {
|
||||||
|
t.Errorf("NewBroker() error = %v, wantError %v", err, tt.wantError)
|
||||||
|
}
|
||||||
|
if err == nil && broker == nil {
|
||||||
|
t.Error("Expected non-nil broker")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerStartStop(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, err := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create broker: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Start
|
||||||
|
if err := broker.Start(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Failed to start broker: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test double start (should fail)
|
||||||
|
if err := broker.Start(context.Background()); err == nil {
|
||||||
|
t.Error("Expected error on double start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Stop
|
||||||
|
if err := broker.Stop(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Failed to stop broker: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test double stop (should not fail)
|
||||||
|
if err := broker.Stop(context.Background()); err != nil {
|
||||||
|
t.Error("Double stop should not fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerPublishSync(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe to events
|
||||||
|
called := false
|
||||||
|
var receivedEvent *Event
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
receivedEvent = event
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
err := broker.PublishSync(context.Background(), event)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PublishSync failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify handler was called
|
||||||
|
if !called {
|
||||||
|
t.Error("Expected handler to be called")
|
||||||
|
}
|
||||||
|
if receivedEvent == nil || receivedEvent.ID != event.ID {
|
||||||
|
t.Error("Expected to receive the published event")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify event status
|
||||||
|
if event.Status != EventStatusCompleted {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerPublishAsync(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
WorkerCount: 2,
|
||||||
|
BufferSize: 10,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe to events
|
||||||
|
var callCount atomic.Int32
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
callCount.Add(1)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish multiple events
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
if err := broker.PublishAsync(context.Background(), event); err != nil {
|
||||||
|
t.Fatalf("PublishAsync failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for events to be processed
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
if callCount.Load() != 5 {
|
||||||
|
t.Errorf("Expected 5 handler calls, got %d", callCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerPublishBeforeStart(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
err := broker.Publish(context.Background(), event)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when publishing before start")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerHandlerError(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
RetryPolicy: &RetryPolicy{
|
||||||
|
MaxRetries: 2,
|
||||||
|
InitialDelay: 10 * time.Millisecond,
|
||||||
|
MaxDelay: 100 * time.Millisecond,
|
||||||
|
BackoffFactor: 2.0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe with failing handler
|
||||||
|
var callCount atomic.Int32
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
callCount.Add(1)
|
||||||
|
return errors.New("handler error")
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
err := broker.PublishSync(context.Background(), event)
|
||||||
|
|
||||||
|
// Should fail after retries
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error from handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have been called MaxRetries+1 times (initial + retries)
|
||||||
|
if callCount.Load() != 3 {
|
||||||
|
t.Errorf("Expected 3 calls (1 initial + 2 retries), got %d", callCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Event should be marked as failed
|
||||||
|
if event.Status != EventStatusFailed {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerMultipleHandlers(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe multiple handlers
|
||||||
|
var called1, called2, called3 bool
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called1 = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
broker.Subscribe("test.event", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called2 = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
broker.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called3 = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishSync(context.Background(), event)
|
||||||
|
|
||||||
|
// All handlers should be called
|
||||||
|
if !called1 || !called2 || !called3 {
|
||||||
|
t.Error("Expected all handlers to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerUnsubscribe(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
called := false
|
||||||
|
id, _ := broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Unsubscribe
|
||||||
|
if err := broker.Unsubscribe(id); err != nil {
|
||||||
|
t.Fatalf("Unsubscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishSync(context.Background(), event)
|
||||||
|
|
||||||
|
// Handler should not be called
|
||||||
|
if called {
|
||||||
|
t.Error("Expected handler not to be called after unsubscribe")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerStats(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish events
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishSync(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get stats
|
||||||
|
stats, err := broker.Stats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stats failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.InstanceID != "test-instance" {
|
||||||
|
t.Errorf("Expected instance ID 'test-instance', got %s", stats.InstanceID)
|
||||||
|
}
|
||||||
|
if stats.TotalPublished != 3 {
|
||||||
|
t.Errorf("Expected 3 published events, got %d", stats.TotalPublished)
|
||||||
|
}
|
||||||
|
if stats.TotalProcessed != 3 {
|
||||||
|
t.Errorf("Expected 3 processed events, got %d", stats.TotalProcessed)
|
||||||
|
}
|
||||||
|
if stats.ActiveSubscribers != 1 {
|
||||||
|
t.Errorf("Expected 1 active subscriber, got %d", stats.ActiveSubscribers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerInstanceID(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "my-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
if broker.InstanceID() != "my-instance" {
|
||||||
|
t.Errorf("Expected instance ID 'my-instance', got %s", broker.InstanceID())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerConcurrentPublish(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
WorkerCount: 5,
|
||||||
|
BufferSize: 100,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
var callCount atomic.Int32
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
callCount.Add(1)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish concurrently
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishAsync(context.Background(), event)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
time.Sleep(200 * time.Millisecond) // Wait for async processing
|
||||||
|
|
||||||
|
if callCount.Load() != 50 {
|
||||||
|
t.Errorf("Expected 50 handler calls, got %d", callCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerGracefulShutdown(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
WorkerCount: 2,
|
||||||
|
BufferSize: 10,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
|
||||||
|
var processedCount atomic.Int32
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||||
|
processedCount.Add(1)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish events
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishAsync(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop broker (should wait for events to be processed)
|
||||||
|
if err := broker.Stop(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All events should be processed
|
||||||
|
if processedCount.Load() != 5 {
|
||||||
|
t.Errorf("Expected 5 processed events, got %d", processedCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerDefaultRetryPolicy(t *testing.T) {
|
||||||
|
policy := DefaultRetryPolicy()
|
||||||
|
|
||||||
|
if policy.MaxRetries != 3 {
|
||||||
|
t.Errorf("Expected MaxRetries 3, got %d", policy.MaxRetries)
|
||||||
|
}
|
||||||
|
if policy.InitialDelay != 1*time.Second {
|
||||||
|
t.Errorf("Expected InitialDelay 1s, got %v", policy.InitialDelay)
|
||||||
|
}
|
||||||
|
if policy.BackoffFactor != 2.0 {
|
||||||
|
t.Errorf("Expected BackoffFactor 2.0, got %f", policy.BackoffFactor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerProcessingModes(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mode ProcessingMode
|
||||||
|
}{
|
||||||
|
{"sync mode", ProcessingModeSync},
|
||||||
|
{"async mode", ProcessingModeAsync},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: tt.mode,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
called := false
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
if tt.mode == ProcessingModeAsync {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("Expected handler to be called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
175
pkg/eventbroker/event.go
Normal file
175
pkg/eventbroker/event.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EventSource represents where an event originated from
|
||||||
|
type EventSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EventSourceDatabase EventSource = "database"
|
||||||
|
EventSourceWebSocket EventSource = "websocket"
|
||||||
|
EventSourceFrontend EventSource = "frontend"
|
||||||
|
EventSourceSystem EventSource = "system"
|
||||||
|
EventSourceInternal EventSource = "internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EventStatus represents the current state of an event
|
||||||
|
type EventStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EventStatusPending EventStatus = "pending"
|
||||||
|
EventStatusProcessing EventStatus = "processing"
|
||||||
|
EventStatusCompleted EventStatus = "completed"
|
||||||
|
EventStatusFailed EventStatus = "failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event represents a single event in the system with complete metadata
|
||||||
|
type Event struct {
|
||||||
|
// Identification
|
||||||
|
ID string `json:"id" db:"id"`
|
||||||
|
|
||||||
|
// Source & Classification
|
||||||
|
Source EventSource `json:"source" db:"source"`
|
||||||
|
Type string `json:"type" db:"type"` // Pattern: schema.entity.operation
|
||||||
|
|
||||||
|
// Status Tracking
|
||||||
|
Status EventStatus `json:"status" db:"status"`
|
||||||
|
RetryCount int `json:"retry_count" db:"retry_count"`
|
||||||
|
Error string `json:"error,omitempty" db:"error"`
|
||||||
|
|
||||||
|
// Payload
|
||||||
|
Payload json.RawMessage `json:"payload" db:"payload"`
|
||||||
|
|
||||||
|
// Context Information
|
||||||
|
UserID int `json:"user_id" db:"user_id"`
|
||||||
|
SessionID string `json:"session_id" db:"session_id"`
|
||||||
|
InstanceID string `json:"instance_id" db:"instance_id"`
|
||||||
|
|
||||||
|
// Database Context
|
||||||
|
Schema string `json:"schema" db:"schema"`
|
||||||
|
Entity string `json:"entity" db:"entity"`
|
||||||
|
Operation string `json:"operation" db:"operation"` // create, update, delete, read
|
||||||
|
|
||||||
|
// Timestamps
|
||||||
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||||
|
ProcessedAt *time.Time `json:"processed_at,omitempty" db:"processed_at"`
|
||||||
|
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
|
||||||
|
|
||||||
|
// Extensibility
|
||||||
|
Metadata map[string]interface{} `json:"metadata" db:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEvent creates a new event with defaults
|
||||||
|
func NewEvent(source EventSource, eventType string) *Event {
|
||||||
|
return &Event{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
Source: source,
|
||||||
|
Type: eventType,
|
||||||
|
Status: EventStatusPending,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Metadata: make(map[string]interface{}),
|
||||||
|
RetryCount: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventType generates a type string from schema, entity, and operation
|
||||||
|
// Pattern: schema.entity.operation (e.g., "public.users.create")
|
||||||
|
func EventType(schema, entity, operation string) string {
|
||||||
|
return fmt.Sprintf("%s.%s.%s", schema, entity, operation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkProcessing marks the event as being processed
|
||||||
|
func (e *Event) MarkProcessing() {
|
||||||
|
e.Status = EventStatusProcessing
|
||||||
|
now := time.Now()
|
||||||
|
e.ProcessedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkCompleted marks the event as successfully completed
|
||||||
|
func (e *Event) MarkCompleted() {
|
||||||
|
e.Status = EventStatusCompleted
|
||||||
|
now := time.Now()
|
||||||
|
e.CompletedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkFailed marks the event as failed with an error message
|
||||||
|
func (e *Event) MarkFailed(err error) {
|
||||||
|
e.Status = EventStatusFailed
|
||||||
|
e.Error = err.Error()
|
||||||
|
now := time.Now()
|
||||||
|
e.CompletedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementRetry increments the retry counter
|
||||||
|
func (e *Event) IncrementRetry() {
|
||||||
|
e.RetryCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPayload sets the event payload from any value by marshaling to JSON
|
||||||
|
func (e *Event) SetPayload(v interface{}) error {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||||
|
}
|
||||||
|
e.Payload = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPayload unmarshals the payload into the provided value
|
||||||
|
func (e *Event) GetPayload(v interface{}) error {
|
||||||
|
if len(e.Payload) == 0 {
|
||||||
|
return fmt.Errorf("payload is empty")
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(e.Payload, v); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal payload: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone creates a deep copy of the event
|
||||||
|
func (e *Event) Clone() *Event {
|
||||||
|
clone := *e
|
||||||
|
|
||||||
|
// Deep copy metadata
|
||||||
|
if e.Metadata != nil {
|
||||||
|
clone.Metadata = make(map[string]interface{})
|
||||||
|
for k, v := range e.Metadata {
|
||||||
|
clone.Metadata[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deep copy timestamps
|
||||||
|
if e.ProcessedAt != nil {
|
||||||
|
t := *e.ProcessedAt
|
||||||
|
clone.ProcessedAt = &t
|
||||||
|
}
|
||||||
|
if e.CompletedAt != nil {
|
||||||
|
t := *e.CompletedAt
|
||||||
|
clone.CompletedAt = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate performs basic validation on the event
|
||||||
|
func (e *Event) Validate() error {
|
||||||
|
if e.ID == "" {
|
||||||
|
return fmt.Errorf("event ID is required")
|
||||||
|
}
|
||||||
|
if e.Source == "" {
|
||||||
|
return fmt.Errorf("event source is required")
|
||||||
|
}
|
||||||
|
if e.Type == "" {
|
||||||
|
return fmt.Errorf("event type is required")
|
||||||
|
}
|
||||||
|
if e.InstanceID == "" {
|
||||||
|
return fmt.Errorf("instance ID is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
314
pkg/eventbroker/event_test.go
Normal file
314
pkg/eventbroker/event_test.go
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewEvent(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
if event.ID == "" {
|
||||||
|
t.Error("Expected event ID to be generated")
|
||||||
|
}
|
||||||
|
if event.Source != EventSourceDatabase {
|
||||||
|
t.Errorf("Expected source %s, got %s", EventSourceDatabase, event.Source)
|
||||||
|
}
|
||||||
|
if event.Type != "public.users.create" {
|
||||||
|
t.Errorf("Expected type 'public.users.create', got %s", event.Type)
|
||||||
|
}
|
||||||
|
if event.Status != EventStatusPending {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusPending, event.Status)
|
||||||
|
}
|
||||||
|
if event.CreatedAt.IsZero() {
|
||||||
|
t.Error("Expected CreatedAt to be set")
|
||||||
|
}
|
||||||
|
if event.Metadata == nil {
|
||||||
|
t.Error("Expected Metadata to be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
schema string
|
||||||
|
entity string
|
||||||
|
operation string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"public", "users", "create", "public.users.create"},
|
||||||
|
{"admin", "roles", "update", "admin.roles.update"},
|
||||||
|
{"", "system", "start", ".system.start"}, // Empty schema results in leading dot
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := EventType(tt.schema, tt.entity, tt.operation)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("EventType(%q, %q, %q) = %q, expected %q",
|
||||||
|
tt.schema, tt.entity, tt.operation, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventValidate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
event *Event
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid event",
|
||||||
|
event: func() *Event {
|
||||||
|
e := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
e.InstanceID = "test-instance"
|
||||||
|
return e
|
||||||
|
}(),
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing ID",
|
||||||
|
event: &Event{
|
||||||
|
Source: EventSourceDatabase,
|
||||||
|
Type: "public.users.create",
|
||||||
|
Status: EventStatusPending,
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing source",
|
||||||
|
event: &Event{
|
||||||
|
ID: "test-id",
|
||||||
|
Type: "public.users.create",
|
||||||
|
Status: EventStatusPending,
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing type",
|
||||||
|
event: &Event{
|
||||||
|
ID: "test-id",
|
||||||
|
Source: EventSourceDatabase,
|
||||||
|
Status: EventStatusPending,
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.event.Validate()
|
||||||
|
if (err != nil) != tt.wantError {
|
||||||
|
t.Errorf("Event.Validate() error = %v, wantError %v", err, tt.wantError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventSetPayload(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"id": 1,
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := event.SetPayload(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SetPayload failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Payload == nil {
|
||||||
|
t.Fatal("Expected payload to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify payload can be unmarshaled
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal(event.Payload, &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["name"] != "John Doe" {
|
||||||
|
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventGetPayload(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"id": float64(1), // JSON unmarshals numbers as float64
|
||||||
|
"name": "John Doe",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := event.SetPayload(payload); err != nil {
|
||||||
|
t.Fatalf("SetPayload failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := event.GetPayload(&result); err != nil {
|
||||||
|
t.Fatalf("GetPayload failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["name"] != "John Doe" {
|
||||||
|
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMarkProcessing(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.MarkProcessing()
|
||||||
|
|
||||||
|
if event.Status != EventStatusProcessing {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusProcessing, event.Status)
|
||||||
|
}
|
||||||
|
if event.ProcessedAt == nil {
|
||||||
|
t.Error("Expected ProcessedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMarkCompleted(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.MarkCompleted()
|
||||||
|
|
||||||
|
if event.Status != EventStatusCompleted {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
|
||||||
|
}
|
||||||
|
if event.CompletedAt == nil {
|
||||||
|
t.Error("Expected CompletedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMarkFailed(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
event.MarkFailed(testErr)
|
||||||
|
|
||||||
|
if event.Status != EventStatusFailed {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
|
||||||
|
}
|
||||||
|
if event.Error != "test error" {
|
||||||
|
t.Errorf("Expected error %q, got %q", "test error", event.Error)
|
||||||
|
}
|
||||||
|
if event.CompletedAt == nil {
|
||||||
|
t.Error("Expected CompletedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventIncrementRetry(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
initialCount := event.RetryCount
|
||||||
|
event.IncrementRetry()
|
||||||
|
|
||||||
|
if event.RetryCount != initialCount+1 {
|
||||||
|
t.Errorf("Expected retry count %d, got %d", initialCount+1, event.RetryCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventJSONMarshaling(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.UserID = 123
|
||||||
|
event.SessionID = "session-123"
|
||||||
|
event.InstanceID = "instance-1"
|
||||||
|
event.Schema = "public"
|
||||||
|
event.Entity = "users"
|
||||||
|
event.Operation = "create"
|
||||||
|
event.SetPayload(map[string]interface{}{"name": "Test"})
|
||||||
|
|
||||||
|
// Marshal to JSON
|
||||||
|
data, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal back
|
||||||
|
var decoded Event
|
||||||
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields
|
||||||
|
if decoded.ID != event.ID {
|
||||||
|
t.Errorf("Expected ID %s, got %s", event.ID, decoded.ID)
|
||||||
|
}
|
||||||
|
if decoded.Source != event.Source {
|
||||||
|
t.Errorf("Expected source %s, got %s", event.Source, decoded.Source)
|
||||||
|
}
|
||||||
|
if decoded.UserID != event.UserID {
|
||||||
|
t.Errorf("Expected UserID %d, got %d", event.UserID, decoded.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventStatusString(t *testing.T) {
|
||||||
|
statuses := []EventStatus{
|
||||||
|
EventStatusPending,
|
||||||
|
EventStatusProcessing,
|
||||||
|
EventStatusCompleted,
|
||||||
|
EventStatusFailed,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, status := range statuses {
|
||||||
|
if string(status) == "" {
|
||||||
|
t.Errorf("EventStatus %v has empty string representation", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventSourceString(t *testing.T) {
|
||||||
|
sources := []EventSource{
|
||||||
|
EventSourceDatabase,
|
||||||
|
EventSourceWebSocket,
|
||||||
|
EventSourceFrontend,
|
||||||
|
EventSourceSystem,
|
||||||
|
EventSourceInternal,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, source := range sources {
|
||||||
|
if string(source) == "" {
|
||||||
|
t.Errorf("EventSource %v has empty string representation", source)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMetadata(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
// Test setting metadata
|
||||||
|
event.Metadata["key1"] = "value1"
|
||||||
|
event.Metadata["key2"] = 123
|
||||||
|
|
||||||
|
if event.Metadata["key1"] != "value1" {
|
||||||
|
t.Errorf("Expected metadata key1 to be 'value1', got %v", event.Metadata["key1"])
|
||||||
|
}
|
||||||
|
if event.Metadata["key2"] != 123 {
|
||||||
|
t.Errorf("Expected metadata key2 to be 123, got %v", event.Metadata["key2"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventTimestamps(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
createdAt := event.CreatedAt
|
||||||
|
|
||||||
|
// Wait a tiny bit to ensure timestamps differ
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
event.MarkProcessing()
|
||||||
|
if event.ProcessedAt == nil {
|
||||||
|
t.Fatal("ProcessedAt should be set")
|
||||||
|
}
|
||||||
|
if !event.ProcessedAt.After(createdAt) {
|
||||||
|
t.Error("ProcessedAt should be after CreatedAt")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
event.MarkCompleted()
|
||||||
|
if event.CompletedAt == nil {
|
||||||
|
t.Fatal("CompletedAt should be set")
|
||||||
|
}
|
||||||
|
if !event.CompletedAt.After(*event.ProcessedAt) {
|
||||||
|
t.Error("CompletedAt should be after ProcessedAt")
|
||||||
|
}
|
||||||
|
}
|
||||||
160
pkg/eventbroker/eventbroker.go
Normal file
160
pkg/eventbroker/eventbroker.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultBroker Broker
|
||||||
|
brokerMu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initialize initializes the global event broker from configuration
|
||||||
|
func Initialize(cfg config.EventBrokerConfig) error {
|
||||||
|
if !cfg.Enabled {
|
||||||
|
logger.Info("Event broker is disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create provider
|
||||||
|
provider, err := NewProviderFromConfig(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create provider: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse mode
|
||||||
|
mode := ProcessingModeAsync
|
||||||
|
if cfg.Mode == "sync" {
|
||||||
|
mode = ProcessingModeSync
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert retry policy
|
||||||
|
retryPolicy := &RetryPolicy{
|
||||||
|
MaxRetries: cfg.RetryPolicy.MaxRetries,
|
||||||
|
InitialDelay: cfg.RetryPolicy.InitialDelay,
|
||||||
|
MaxDelay: cfg.RetryPolicy.MaxDelay,
|
||||||
|
BackoffFactor: cfg.RetryPolicy.BackoffFactor,
|
||||||
|
}
|
||||||
|
if retryPolicy.MaxRetries == 0 {
|
||||||
|
retryPolicy = DefaultRetryPolicy()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create broker options
|
||||||
|
opts := Options{
|
||||||
|
Provider: provider,
|
||||||
|
Mode: mode,
|
||||||
|
WorkerCount: cfg.WorkerCount,
|
||||||
|
BufferSize: cfg.BufferSize,
|
||||||
|
RetryPolicy: retryPolicy,
|
||||||
|
InstanceID: getInstanceID(cfg.InstanceID),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create broker
|
||||||
|
broker, err := NewBroker(opts)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create broker: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start broker
|
||||||
|
if err := broker.Start(context.Background()); err != nil {
|
||||||
|
return fmt.Errorf("failed to start broker: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set as default
|
||||||
|
SetDefaultBroker(broker)
|
||||||
|
|
||||||
|
// Register shutdown callback
|
||||||
|
RegisterShutdown(broker)
|
||||||
|
|
||||||
|
logger.Info("Event broker initialized successfully (provider: %s, mode: %s, instance: %s)",
|
||||||
|
cfg.Provider, cfg.Mode, opts.InstanceID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultBroker sets the default global broker
|
||||||
|
func SetDefaultBroker(broker Broker) {
|
||||||
|
brokerMu.Lock()
|
||||||
|
defer brokerMu.Unlock()
|
||||||
|
defaultBroker = broker
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultBroker returns the default global broker
|
||||||
|
func GetDefaultBroker() Broker {
|
||||||
|
brokerMu.RLock()
|
||||||
|
defer brokerMu.RUnlock()
|
||||||
|
return defaultBroker
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInitialized returns true if the default broker is initialized
|
||||||
|
func IsInitialized() bool {
|
||||||
|
return GetDefaultBroker() != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish publishes an event using the default broker
|
||||||
|
func Publish(ctx context.Context, event *Event) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Publish(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishSync publishes an event synchronously using the default broker
|
||||||
|
func PublishSync(ctx context.Context, event *Event) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.PublishSync(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishAsync publishes an event asynchronously using the default broker
|
||||||
|
func PublishAsync(ctx context.Context, event *Event) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.PublishAsync(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes to events using the default broker
|
||||||
|
func Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return "", fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Subscribe(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe unsubscribes from events using the default broker
|
||||||
|
func Unsubscribe(id SubscriptionID) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Unsubscribe(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics from the default broker
|
||||||
|
func Stats(ctx context.Context) (*BrokerStats, error) {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return nil, fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Stats(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterShutdown registers the broker's shutdown with the server shutdown callbacks
|
||||||
|
func RegisterShutdown(broker Broker) {
|
||||||
|
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
logger.Info("Shutting down event broker...")
|
||||||
|
return broker.Stop(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
266
pkg/eventbroker/example_usage.go
Normal file
266
pkg/eventbroker/example_usage.go
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
// nolint
|
||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Example demonstrates basic usage of the event broker
|
||||||
|
func Example() {
|
||||||
|
// 1. Create a memory provider
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "example-instance",
|
||||||
|
MaxEvents: 1000,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
MaxAge: 1 * time.Hour,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. Create a broker
|
||||||
|
broker, err := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
WorkerCount: 5,
|
||||||
|
BufferSize: 100,
|
||||||
|
RetryPolicy: DefaultRetryPolicy(),
|
||||||
|
InstanceID: "example-instance",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create broker: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Start the broker
|
||||||
|
if err := broker.Start(context.Background()); err != nil {
|
||||||
|
logger.Error("Failed to start broker: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := broker.Stop(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to stop broker: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 4. Subscribe to events
|
||||||
|
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
logger.Info("User event: %s (operation: %s)", event.Type, event.Operation)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
logger.Info("Create event: %s.%s", event.Schema, event.Entity)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// 5. Publish events
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Database event
|
||||||
|
dbEvent := NewEvent(EventSourceDatabase, EventType("public", "users", "create"))
|
||||||
|
dbEvent.InstanceID = "example-instance"
|
||||||
|
dbEvent.UserID = 123
|
||||||
|
dbEvent.SessionID = "session-456"
|
||||||
|
dbEvent.Schema = "public"
|
||||||
|
dbEvent.Entity = "users"
|
||||||
|
dbEvent.Operation = "create"
|
||||||
|
dbEvent.SetPayload(map[string]interface{}{
|
||||||
|
"id": 123,
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := broker.PublishAsync(ctx, dbEvent); err != nil {
|
||||||
|
logger.Error("Failed to publish event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket event
|
||||||
|
wsEvent := NewEvent(EventSourceWebSocket, "chat.message")
|
||||||
|
wsEvent.InstanceID = "example-instance"
|
||||||
|
wsEvent.UserID = 123
|
||||||
|
wsEvent.SessionID = "session-456"
|
||||||
|
wsEvent.SetPayload(map[string]interface{}{
|
||||||
|
"room": "general",
|
||||||
|
"message": "Hello, World!",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := broker.PublishAsync(ctx, wsEvent); err != nil {
|
||||||
|
logger.Error("Failed to publish event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Get statistics
|
||||||
|
time.Sleep(1 * time.Second) // Wait for processing
|
||||||
|
stats, _ := broker.Stats(ctx)
|
||||||
|
logger.Info("Broker stats: %d published, %d processed", stats.TotalPublished, stats.TotalProcessed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleWithHooks demonstrates integration with the hook system
|
||||||
|
func ExampleWithHooks() {
|
||||||
|
// This would typically be called in your main.go or initialization code
|
||||||
|
// after setting up your restheadspec.Handler
|
||||||
|
|
||||||
|
// Pseudo-code (actual implementation would use real handler):
|
||||||
|
/*
|
||||||
|
broker := eventbroker.GetDefaultBroker()
|
||||||
|
hookRegistry := handler.Hooks()
|
||||||
|
|
||||||
|
// Register CRUD hooks
|
||||||
|
config := eventbroker.DefaultCRUDHookConfig()
|
||||||
|
config.EnableRead = false // Disable read events for performance
|
||||||
|
|
||||||
|
if err := eventbroker.RegisterCRUDHooks(broker, hookRegistry, config); err != nil {
|
||||||
|
logger.Error("Failed to register CRUD hooks: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now all CRUD operations will automatically publish events
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleSubscriptionPatterns demonstrates different subscription patterns
|
||||||
|
func ExampleSubscriptionPatterns() {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern 1: Subscribe to all events from a specific entity
|
||||||
|
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("User event: %s\n", event.Operation)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Pattern 2: Subscribe to a specific operation across all entities
|
||||||
|
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("Create event: %s.%s\n", event.Schema, event.Entity)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Pattern 3: Subscribe to all events in a schema
|
||||||
|
broker.Subscribe("public.*.*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("Public schema event: %s.%s\n", event.Entity, event.Operation)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Pattern 4: Subscribe to everything (use with caution)
|
||||||
|
broker.Subscribe("*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("Any event: %s\n", event.Type)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleErrorHandling demonstrates error handling in event handlers
|
||||||
|
func ExampleErrorHandling() {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler that may fail
|
||||||
|
broker.Subscribe("public.users.create", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
// Simulate processing
|
||||||
|
var user struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := event.GetPayload(&user); err != nil {
|
||||||
|
return fmt.Errorf("invalid payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate
|
||||||
|
if user.Email == "" {
|
||||||
|
return fmt.Errorf("email is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process (e.g., send email)
|
||||||
|
logger.Info("Sending welcome email to %s", user.Email)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleConfiguration demonstrates initializing from configuration
|
||||||
|
func ExampleConfiguration() {
|
||||||
|
// This would typically be in your main.go
|
||||||
|
|
||||||
|
// Pseudo-code:
|
||||||
|
/*
|
||||||
|
// Load configuration
|
||||||
|
cfgMgr := config.NewManager()
|
||||||
|
if err := cfgMgr.Load(); err != nil {
|
||||||
|
logger.Fatal("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := cfgMgr.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to get config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize event broker
|
||||||
|
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
|
||||||
|
logger.Fatal("Failed to initialize event broker: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the default broker
|
||||||
|
eventbroker.Subscribe("*.*.create", eventbroker.EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *eventbroker.Event) error {
|
||||||
|
logger.Info("Created: %s.%s", event.Schema, event.Entity)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleYAMLConfiguration shows example YAML configuration
|
||||||
|
const ExampleYAMLConfiguration = `
|
||||||
|
event_broker:
|
||||||
|
enabled: true
|
||||||
|
provider: memory # memory, redis, nats, database
|
||||||
|
mode: async # sync, async
|
||||||
|
worker_count: 10
|
||||||
|
buffer_size: 1000
|
||||||
|
instance_id: "${HOSTNAME}"
|
||||||
|
|
||||||
|
# Memory provider is default, no additional config needed
|
||||||
|
|
||||||
|
# Redis provider (when provider: redis)
|
||||||
|
redis:
|
||||||
|
stream_name: "resolvespec:events"
|
||||||
|
consumer_group: "resolvespec-workers"
|
||||||
|
host: "localhost"
|
||||||
|
port: 6379
|
||||||
|
|
||||||
|
# NATS provider (when provider: nats)
|
||||||
|
nats:
|
||||||
|
url: "nats://localhost:4222"
|
||||||
|
stream_name: "RESOLVESPEC_EVENTS"
|
||||||
|
|
||||||
|
# Database provider (when provider: database)
|
||||||
|
database:
|
||||||
|
table_name: "events"
|
||||||
|
channel: "resolvespec_events"
|
||||||
|
|
||||||
|
# Retry policy
|
||||||
|
retry_policy:
|
||||||
|
max_retries: 3
|
||||||
|
initial_delay: 1s
|
||||||
|
max_delay: 30s
|
||||||
|
backoff_factor: 2.0
|
||||||
|
`
|
||||||
56
pkg/eventbroker/factory.go
Normal file
56
pkg/eventbroker/factory.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewProviderFromConfig creates a provider based on configuration
|
||||||
|
func NewProviderFromConfig(cfg config.EventBrokerConfig) (Provider, error) {
|
||||||
|
switch cfg.Provider {
|
||||||
|
case "memory":
|
||||||
|
cleanupInterval := 5 * time.Minute
|
||||||
|
if cfg.Database.PollInterval > 0 {
|
||||||
|
cleanupInterval = cfg.Database.PollInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: getInstanceID(cfg.InstanceID),
|
||||||
|
MaxEvents: 10000,
|
||||||
|
CleanupInterval: cleanupInterval,
|
||||||
|
}), nil
|
||||||
|
|
||||||
|
case "redis":
|
||||||
|
// Redis provider will be implemented in Phase 8
|
||||||
|
return nil, fmt.Errorf("redis provider not yet implemented")
|
||||||
|
|
||||||
|
case "nats":
|
||||||
|
// NATS provider will be implemented in Phase 9
|
||||||
|
return nil, fmt.Errorf("nats provider not yet implemented")
|
||||||
|
|
||||||
|
case "database":
|
||||||
|
// Database provider will be implemented in Phase 7
|
||||||
|
return nil, fmt.Errorf("database provider not yet implemented")
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getInstanceID returns the instance ID, defaulting to hostname if not specified
|
||||||
|
func getInstanceID(configID string) string {
|
||||||
|
if configID != "" {
|
||||||
|
return configID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get hostname
|
||||||
|
if hostname, err := os.Hostname(); err == nil {
|
||||||
|
return hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to a default
|
||||||
|
return "resolvespec-instance"
|
||||||
|
}
|
||||||
17
pkg/eventbroker/handler.go
Normal file
17
pkg/eventbroker/handler.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// EventHandler processes an event
|
||||||
|
type EventHandler interface {
|
||||||
|
Handle(ctx context.Context, event *Event) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventHandlerFunc is a function adapter for EventHandler
|
||||||
|
// This allows using regular functions as event handlers
|
||||||
|
type EventHandlerFunc func(ctx context.Context, event *Event) error
|
||||||
|
|
||||||
|
// Handle implements EventHandler
|
||||||
|
func (f EventHandlerFunc) Handle(ctx context.Context, event *Event) error {
|
||||||
|
return f(ctx, event)
|
||||||
|
}
|
||||||
137
pkg/eventbroker/hooks.go
Normal file
137
pkg/eventbroker/hooks.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CRUDHookConfig configures which CRUD operations should trigger events
|
||||||
|
type CRUDHookConfig struct {
|
||||||
|
EnableCreate bool
|
||||||
|
EnableRead bool
|
||||||
|
EnableUpdate bool
|
||||||
|
EnableDelete bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCRUDHookConfig returns default configuration (all enabled)
|
||||||
|
func DefaultCRUDHookConfig() *CRUDHookConfig {
|
||||||
|
return &CRUDHookConfig{
|
||||||
|
EnableCreate: true,
|
||||||
|
EnableRead: false, // Typically disabled for performance
|
||||||
|
EnableUpdate: true,
|
||||||
|
EnableDelete: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterCRUDHooks registers event hooks for CRUD operations
|
||||||
|
// This integrates with the restheadspec.HookRegistry to automatically
|
||||||
|
// capture database events
|
||||||
|
func RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry, config *CRUDHookConfig) error {
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("broker cannot be nil")
|
||||||
|
}
|
||||||
|
if hookRegistry == nil {
|
||||||
|
return fmt.Errorf("hookRegistry cannot be nil")
|
||||||
|
}
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultCRUDHookConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create hook handler factory
|
||||||
|
createHookHandler := func(operation string) restheadspec.HookFunc {
|
||||||
|
return func(hookCtx *restheadspec.HookContext) error {
|
||||||
|
// Get user context from Go context
|
||||||
|
userCtx, ok := security.GetUserContext(hookCtx.Context)
|
||||||
|
if !ok || userCtx == nil {
|
||||||
|
logger.Debug("No user context found in hook")
|
||||||
|
userCtx = &security.UserContext{} // Empty user context
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create event
|
||||||
|
event := NewEvent(EventSourceDatabase, EventType(hookCtx.Schema, hookCtx.Entity, operation))
|
||||||
|
event.InstanceID = broker.InstanceID()
|
||||||
|
event.UserID = userCtx.UserID
|
||||||
|
event.SessionID = userCtx.SessionID
|
||||||
|
event.Schema = hookCtx.Schema
|
||||||
|
event.Entity = hookCtx.Entity
|
||||||
|
event.Operation = operation
|
||||||
|
|
||||||
|
// Set payload based on operation
|
||||||
|
var payload interface{}
|
||||||
|
switch operation {
|
||||||
|
case "create":
|
||||||
|
payload = hookCtx.Result
|
||||||
|
case "read":
|
||||||
|
payload = hookCtx.Result
|
||||||
|
case "update":
|
||||||
|
payload = map[string]interface{}{
|
||||||
|
"id": hookCtx.ID,
|
||||||
|
"data": hookCtx.Data,
|
||||||
|
}
|
||||||
|
case "delete":
|
||||||
|
payload = map[string]interface{}{
|
||||||
|
"id": hookCtx.ID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload != nil {
|
||||||
|
if err := event.SetPayload(payload); err != nil {
|
||||||
|
logger.Error("Failed to set event payload: %v", err)
|
||||||
|
payload = map[string]interface{}{"error": "failed to serialize payload"}
|
||||||
|
event.Payload, _ = json.Marshal(payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add metadata
|
||||||
|
if userCtx.UserName != "" {
|
||||||
|
event.Metadata["user_name"] = userCtx.UserName
|
||||||
|
}
|
||||||
|
if userCtx.Email != "" {
|
||||||
|
event.Metadata["user_email"] = userCtx.Email
|
||||||
|
}
|
||||||
|
if len(userCtx.Roles) > 0 {
|
||||||
|
event.Metadata["user_roles"] = userCtx.Roles
|
||||||
|
}
|
||||||
|
event.Metadata["table_name"] = hookCtx.TableName
|
||||||
|
|
||||||
|
// Publish asynchronously to not block CRUD operation
|
||||||
|
if err := broker.PublishAsync(hookCtx.Context, event); err != nil {
|
||||||
|
logger.Error("Failed to publish %s event for %s.%s: %v",
|
||||||
|
operation, hookCtx.Schema, hookCtx.Entity, err)
|
||||||
|
// Don't fail the CRUD operation if event publishing fails
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Published %s event for %s.%s (ID: %s)",
|
||||||
|
operation, hookCtx.Schema, hookCtx.Entity, event.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register hooks based on configuration
|
||||||
|
if config.EnableCreate {
|
||||||
|
hookRegistry.Register(restheadspec.AfterCreate, createHookHandler("create"))
|
||||||
|
logger.Info("Registered event hook for CREATE operations")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.EnableRead {
|
||||||
|
hookRegistry.Register(restheadspec.AfterRead, createHookHandler("read"))
|
||||||
|
logger.Info("Registered event hook for READ operations")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.EnableUpdate {
|
||||||
|
hookRegistry.Register(restheadspec.AfterUpdate, createHookHandler("update"))
|
||||||
|
logger.Info("Registered event hook for UPDATE operations")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.EnableDelete {
|
||||||
|
hookRegistry.Register(restheadspec.AfterDelete, createHookHandler("delete"))
|
||||||
|
logger.Info("Registered event hook for DELETE operations")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
28
pkg/eventbroker/metrics.go
Normal file
28
pkg/eventbroker/metrics.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
// recordEventPublished records an event publication metric
|
||||||
|
func recordEventPublished(event *Event) {
|
||||||
|
if mp := metrics.GetProvider(); mp != nil {
|
||||||
|
mp.RecordEventPublished(string(event.Source), event.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordEventProcessed records an event processing metric
|
||||||
|
func recordEventProcessed(event *Event, duration time.Duration) {
|
||||||
|
if mp := metrics.GetProvider(); mp != nil {
|
||||||
|
mp.RecordEventProcessed(string(event.Source), event.Type, string(event.Status), duration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateQueueSize updates the event queue size metric
|
||||||
|
func updateQueueSize(size int64) {
|
||||||
|
if mp := metrics.GetProvider(); mp != nil {
|
||||||
|
mp.UpdateEventQueueSize(size)
|
||||||
|
}
|
||||||
|
}
|
||||||
70
pkg/eventbroker/provider.go
Normal file
70
pkg/eventbroker/provider.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider defines the storage backend interface for events
|
||||||
|
// Implementations: MemoryProvider, RedisProvider, NATSProvider, DatabaseProvider
|
||||||
|
type Provider interface {
|
||||||
|
// Store stores an event
|
||||||
|
Store(ctx context.Context, event *Event) error
|
||||||
|
|
||||||
|
// Get retrieves an event by ID
|
||||||
|
Get(ctx context.Context, id string) (*Event, error)
|
||||||
|
|
||||||
|
// List lists events with optional filters
|
||||||
|
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
|
||||||
|
|
||||||
|
// UpdateStatus updates the status of an event
|
||||||
|
UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error
|
||||||
|
|
||||||
|
// Delete deletes an event by ID
|
||||||
|
Delete(ctx context.Context, id string) error
|
||||||
|
|
||||||
|
// Stream returns a channel of events for real-time consumption
|
||||||
|
// Used for cross-instance pub/sub
|
||||||
|
// The channel is closed when the context is canceled or an error occurs
|
||||||
|
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
|
||||||
|
|
||||||
|
// Publish publishes an event to all subscribers (for distributed providers)
|
||||||
|
// For in-memory provider, this is the same as Store
|
||||||
|
// For Redis/NATS/Database, this triggers cross-instance delivery
|
||||||
|
Publish(ctx context.Context, event *Event) error
|
||||||
|
|
||||||
|
// Close closes the provider and releases resources
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// Stats returns provider statistics
|
||||||
|
Stats(ctx context.Context) (*ProviderStats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventFilter defines filter criteria for listing events
|
||||||
|
type EventFilter struct {
|
||||||
|
Source *EventSource
|
||||||
|
Status *EventStatus
|
||||||
|
UserID *int
|
||||||
|
Schema string
|
||||||
|
Entity string
|
||||||
|
Operation string
|
||||||
|
InstanceID string
|
||||||
|
StartTime *time.Time
|
||||||
|
EndTime *time.Time
|
||||||
|
Limit int
|
||||||
|
Offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderStats contains statistics about the provider
|
||||||
|
type ProviderStats struct {
|
||||||
|
ProviderType string `json:"provider_type"`
|
||||||
|
TotalEvents int64 `json:"total_events"`
|
||||||
|
PendingEvents int64 `json:"pending_events"`
|
||||||
|
ProcessingEvents int64 `json:"processing_events"`
|
||||||
|
CompletedEvents int64 `json:"completed_events"`
|
||||||
|
FailedEvents int64 `json:"failed_events"`
|
||||||
|
EventsPublished int64 `json:"events_published"`
|
||||||
|
EventsConsumed int64 `json:"events_consumed"`
|
||||||
|
ActiveSubscribers int `json:"active_subscribers"`
|
||||||
|
ProviderSpecific map[string]interface{} `json:"provider_specific,omitempty"`
|
||||||
|
}
|
||||||
446
pkg/eventbroker/provider_memory.go
Normal file
446
pkg/eventbroker/provider_memory.go
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryProvider implements Provider interface using in-memory storage
|
||||||
|
// Features:
|
||||||
|
// - Thread-safe event storage with RW mutex
|
||||||
|
// - LRU eviction when max events reached
|
||||||
|
// - In-process pub/sub (not cross-instance)
|
||||||
|
// - Automatic cleanup of old completed events
|
||||||
|
type MemoryProvider struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
events map[string]*Event
|
||||||
|
eventOrder []string // For LRU tracking
|
||||||
|
subscribers map[string][]chan *Event
|
||||||
|
instanceID string
|
||||||
|
maxEvents int
|
||||||
|
cleanupInterval time.Duration
|
||||||
|
maxAge time.Duration
|
||||||
|
|
||||||
|
// Statistics
|
||||||
|
stats MemoryProviderStats
|
||||||
|
|
||||||
|
// Lifecycle
|
||||||
|
stopCleanup chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
isRunning atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryProviderStats contains statistics for the memory provider
|
||||||
|
type MemoryProviderStats struct {
|
||||||
|
TotalEvents atomic.Int64
|
||||||
|
PendingEvents atomic.Int64
|
||||||
|
ProcessingEvents atomic.Int64
|
||||||
|
CompletedEvents atomic.Int64
|
||||||
|
FailedEvents atomic.Int64
|
||||||
|
EventsPublished atomic.Int64
|
||||||
|
EventsConsumed atomic.Int64
|
||||||
|
ActiveSubscribers atomic.Int32
|
||||||
|
Evictions atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryProviderOptions configures the memory provider
|
||||||
|
type MemoryProviderOptions struct {
|
||||||
|
InstanceID string
|
||||||
|
MaxEvents int
|
||||||
|
CleanupInterval time.Duration
|
||||||
|
MaxAge time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryProvider creates a new in-memory event provider
|
||||||
|
func NewMemoryProvider(opts MemoryProviderOptions) *MemoryProvider {
|
||||||
|
if opts.MaxEvents == 0 {
|
||||||
|
opts.MaxEvents = 10000 // Default
|
||||||
|
}
|
||||||
|
if opts.CleanupInterval == 0 {
|
||||||
|
opts.CleanupInterval = 5 * time.Minute // Default
|
||||||
|
}
|
||||||
|
if opts.MaxAge == 0 {
|
||||||
|
opts.MaxAge = 24 * time.Hour // Default: keep events for 24 hours
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := &MemoryProvider{
|
||||||
|
events: make(map[string]*Event),
|
||||||
|
eventOrder: make([]string, 0),
|
||||||
|
subscribers: make(map[string][]chan *Event),
|
||||||
|
instanceID: opts.InstanceID,
|
||||||
|
maxEvents: opts.MaxEvents,
|
||||||
|
cleanupInterval: opts.CleanupInterval,
|
||||||
|
maxAge: opts.MaxAge,
|
||||||
|
stopCleanup: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.isRunning.Store(true)
|
||||||
|
|
||||||
|
// Start cleanup goroutine
|
||||||
|
mp.wg.Add(1)
|
||||||
|
go mp.cleanupLoop()
|
||||||
|
|
||||||
|
logger.Info("Memory provider initialized (max_events: %d, cleanup: %v, max_age: %v)",
|
||||||
|
opts.MaxEvents, opts.CleanupInterval, opts.MaxAge)
|
||||||
|
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store stores an event
|
||||||
|
func (mp *MemoryProvider) Store(ctx context.Context, event *Event) error {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if we need to evict oldest events
|
||||||
|
if len(mp.events) >= mp.maxEvents {
|
||||||
|
mp.evictOldestLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store event
|
||||||
|
mp.events[event.ID] = event.Clone()
|
||||||
|
mp.eventOrder = append(mp.eventOrder, event.ID)
|
||||||
|
|
||||||
|
// Update statistics
|
||||||
|
mp.stats.TotalEvents.Add(1)
|
||||||
|
mp.updateStatusCountsLocked(event.Status, 1)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves an event by ID
|
||||||
|
func (mp *MemoryProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
event, exists := mp.events[id]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("event not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return event.Clone(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List lists events with optional filters
|
||||||
|
func (mp *MemoryProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
var results []*Event
|
||||||
|
|
||||||
|
for _, event := range mp.events {
|
||||||
|
if mp.matchesFilter(event, filter) {
|
||||||
|
results = append(results, event.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply limit and offset
|
||||||
|
if filter != nil {
|
||||||
|
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||||
|
results = results[filter.Offset:]
|
||||||
|
}
|
||||||
|
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||||
|
results = results[:filter.Limit]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus updates the status of an event
|
||||||
|
func (mp *MemoryProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
event, exists := mp.events[id]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("event not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update status counts
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
mp.updateStatusCountsLocked(status, 1)
|
||||||
|
|
||||||
|
// Update event
|
||||||
|
event.Status = status
|
||||||
|
if errorMsg != "" {
|
||||||
|
event.Error = errorMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes an event by ID
|
||||||
|
func (mp *MemoryProvider) Delete(ctx context.Context, id string) error {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
event, exists := mp.events[id]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("event not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update counts
|
||||||
|
mp.stats.TotalEvents.Add(-1)
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
|
||||||
|
// Delete event
|
||||||
|
delete(mp.events, id)
|
||||||
|
|
||||||
|
// Remove from order tracking
|
||||||
|
for i, eid := range mp.eventOrder {
|
||||||
|
if eid == id {
|
||||||
|
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream returns a channel of events for real-time consumption
|
||||||
|
// Note: This is in-process only, not cross-instance
|
||||||
|
func (mp *MemoryProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Create buffered channel for events
|
||||||
|
ch := make(chan *Event, 100)
|
||||||
|
|
||||||
|
// Store subscriber
|
||||||
|
mp.subscribers[pattern] = append(mp.subscribers[pattern], ch)
|
||||||
|
mp.stats.ActiveSubscribers.Add(1)
|
||||||
|
|
||||||
|
// Goroutine to clean up on context cancellation
|
||||||
|
mp.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer mp.wg.Done()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove subscriber
|
||||||
|
subs := mp.subscribers[pattern]
|
||||||
|
for i, subCh := range subs {
|
||||||
|
if subCh == ch {
|
||||||
|
mp.subscribers[pattern] = append(subs[:i], subs[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.stats.ActiveSubscribers.Add(-1)
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Debug("Stream created for pattern: %s", pattern)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish publishes an event to all subscribers
|
||||||
|
func (mp *MemoryProvider) Publish(ctx context.Context, event *Event) error {
|
||||||
|
// Store the event first
|
||||||
|
if err := mp.Store(ctx, event); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.stats.EventsPublished.Add(1)
|
||||||
|
|
||||||
|
// Notify subscribers
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
for pattern, channels := range mp.subscribers {
|
||||||
|
if matchPattern(pattern, event.Type) {
|
||||||
|
for _, ch := range channels {
|
||||||
|
select {
|
||||||
|
case ch <- event.Clone():
|
||||||
|
mp.stats.EventsConsumed.Add(1)
|
||||||
|
default:
|
||||||
|
// Channel full, skip
|
||||||
|
logger.Warn("Subscriber channel full for pattern: %s", pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the provider and releases resources
|
||||||
|
func (mp *MemoryProvider) Close() error {
|
||||||
|
if !mp.isRunning.Load() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.isRunning.Store(false)
|
||||||
|
|
||||||
|
// Stop cleanup loop
|
||||||
|
close(mp.stopCleanup)
|
||||||
|
|
||||||
|
// Wait for goroutines
|
||||||
|
mp.wg.Wait()
|
||||||
|
|
||||||
|
// Close all subscriber channels
|
||||||
|
mp.mu.Lock()
|
||||||
|
for _, channels := range mp.subscribers {
|
||||||
|
for _, ch := range channels {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mp.subscribers = make(map[string][]chan *Event)
|
||||||
|
mp.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Memory provider closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns provider statistics
|
||||||
|
func (mp *MemoryProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||||
|
return &ProviderStats{
|
||||||
|
ProviderType: "memory",
|
||||||
|
TotalEvents: mp.stats.TotalEvents.Load(),
|
||||||
|
PendingEvents: mp.stats.PendingEvents.Load(),
|
||||||
|
ProcessingEvents: mp.stats.ProcessingEvents.Load(),
|
||||||
|
CompletedEvents: mp.stats.CompletedEvents.Load(),
|
||||||
|
FailedEvents: mp.stats.FailedEvents.Load(),
|
||||||
|
EventsPublished: mp.stats.EventsPublished.Load(),
|
||||||
|
EventsConsumed: mp.stats.EventsConsumed.Load(),
|
||||||
|
ActiveSubscribers: int(mp.stats.ActiveSubscribers.Load()),
|
||||||
|
ProviderSpecific: map[string]interface{}{
|
||||||
|
"max_events": mp.maxEvents,
|
||||||
|
"cleanup_interval": mp.cleanupInterval.String(),
|
||||||
|
"max_age": mp.maxAge.String(),
|
||||||
|
"evictions": mp.stats.Evictions.Load(),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupLoop periodically cleans up old completed events
|
||||||
|
func (mp *MemoryProvider) cleanupLoop() {
|
||||||
|
defer mp.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(mp.cleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
mp.cleanup()
|
||||||
|
case <-mp.stopCleanup:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup removes old completed/failed events
|
||||||
|
func (mp *MemoryProvider) cleanup() {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
cutoff := time.Now().Add(-mp.maxAge)
|
||||||
|
removed := 0
|
||||||
|
|
||||||
|
for id, event := range mp.events {
|
||||||
|
// Only clean up completed or failed events that are old
|
||||||
|
if (event.Status == EventStatusCompleted || event.Status == EventStatusFailed) &&
|
||||||
|
event.CreatedAt.Before(cutoff) {
|
||||||
|
|
||||||
|
delete(mp.events, id)
|
||||||
|
mp.stats.TotalEvents.Add(-1)
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
|
||||||
|
// Remove from order tracking
|
||||||
|
for i, eid := range mp.eventOrder {
|
||||||
|
if eid == id {
|
||||||
|
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if removed > 0 {
|
||||||
|
logger.Debug("Cleanup removed %d old events", removed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictOldestLocked evicts the oldest event (LRU)
|
||||||
|
// Caller must hold write lock
|
||||||
|
func (mp *MemoryProvider) evictOldestLocked() {
|
||||||
|
if len(mp.eventOrder) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get oldest event ID
|
||||||
|
oldestID := mp.eventOrder[0]
|
||||||
|
mp.eventOrder = mp.eventOrder[1:]
|
||||||
|
|
||||||
|
// Remove event
|
||||||
|
if event, exists := mp.events[oldestID]; exists {
|
||||||
|
delete(mp.events, oldestID)
|
||||||
|
mp.stats.TotalEvents.Add(-1)
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
mp.stats.Evictions.Add(1)
|
||||||
|
|
||||||
|
logger.Debug("Evicted oldest event: %s", oldestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesFilter checks if an event matches the filter criteria
|
||||||
|
func (mp *MemoryProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||||
|
if filter == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Source != nil && event.Source != *filter.Source {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Status != nil && event.Status != *filter.Status {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateStatusCountsLocked updates status statistics
|
||||||
|
// Caller must hold write lock
|
||||||
|
func (mp *MemoryProvider) updateStatusCountsLocked(status EventStatus, delta int64) {
|
||||||
|
switch status {
|
||||||
|
case EventStatusPending:
|
||||||
|
mp.stats.PendingEvents.Add(delta)
|
||||||
|
case EventStatusProcessing:
|
||||||
|
mp.stats.ProcessingEvents.Add(delta)
|
||||||
|
case EventStatusCompleted:
|
||||||
|
mp.stats.CompletedEvents.Add(delta)
|
||||||
|
case EventStatusFailed:
|
||||||
|
mp.stats.FailedEvents.Add(delta)
|
||||||
|
}
|
||||||
|
}
|
||||||
419
pkg/eventbroker/provider_memory_test.go
Normal file
419
pkg/eventbroker/provider_memory_test.go
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewMemoryProvider(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
MaxEvents: 100,
|
||||||
|
CleanupInterval: 1 * time.Minute,
|
||||||
|
})
|
||||||
|
|
||||||
|
if provider == nil {
|
||||||
|
t.Fatal("Expected non-nil provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := provider.Stats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stats failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.ProviderType != "memory" {
|
||||||
|
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderPublishAndGet(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.UserID = 123
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
if err := provider.Publish(context.Background(), event); err != nil {
|
||||||
|
t.Fatalf("Publish failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get event
|
||||||
|
retrieved, err := provider.Get(context.Background(), event.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieved.ID != event.ID {
|
||||||
|
t.Errorf("Expected event ID %s, got %s", event.ID, retrieved.ID)
|
||||||
|
}
|
||||||
|
if retrieved.UserID != 123 {
|
||||||
|
t.Errorf("Expected user ID 123, got %d", retrieved.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderGetNonExistent(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := provider.Get(context.Background(), "non-existent-id")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when getting non-existent event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderUpdateStatus(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
// Update status to processing
|
||||||
|
err := provider.UpdateStatus(context.Background(), event.ID, EventStatusProcessing, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retrieved, _ := provider.Get(context.Background(), event.ID)
|
||||||
|
if retrieved.Status != EventStatusProcessing {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusProcessing, retrieved.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update status to failed with error
|
||||||
|
err = provider.UpdateStatus(context.Background(), event.ID, EventStatusFailed, "test error")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retrieved, _ = provider.Get(context.Background(), event.ID)
|
||||||
|
if retrieved.Status != EventStatusFailed {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusFailed, retrieved.Status)
|
||||||
|
}
|
||||||
|
if retrieved.Error != "test error" {
|
||||||
|
t.Errorf("Expected error 'test error', got %s", retrieved.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderList(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish multiple events
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List all events
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 5 {
|
||||||
|
t.Errorf("Expected 5 events, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderListWithFilter(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events with different types
|
||||||
|
event1 := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event1)
|
||||||
|
|
||||||
|
event2 := NewEvent(EventSourceDatabase, "public.roles.create")
|
||||||
|
provider.Publish(context.Background(), event2)
|
||||||
|
|
||||||
|
event3 := NewEvent(EventSourceWebSocket, "chat.message")
|
||||||
|
provider.Publish(context.Background(), event3)
|
||||||
|
|
||||||
|
// Filter by source
|
||||||
|
source := EventSourceDatabase
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
Source: &source,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Errorf("Expected 2 events with database source, got %d", len(events))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by status
|
||||||
|
status := EventStatusPending
|
||||||
|
events, err = provider.List(context.Background(), &EventFilter{
|
||||||
|
Status: &status,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 3 {
|
||||||
|
t.Errorf("Expected 3 events with pending status, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderListWithLimit(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish multiple events
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List with limit
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
Limit: 5,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 5 {
|
||||||
|
t.Errorf("Expected 5 events (limited), got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderDelete(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
// Delete event
|
||||||
|
err := provider.Delete(context.Background(), event.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify deleted
|
||||||
|
_, err = provider.Get(context.Background(), event.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when getting deleted event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderLRUEviction(t *testing.T) {
|
||||||
|
// Create provider with small max events
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
MaxEvents: 3,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish 5 events
|
||||||
|
events := make([]*Event, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
events[i] = NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), events[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// First 2 events should be evicted
|
||||||
|
_, err := provider.Get(context.Background(), events[0].ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected first event to be evicted")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = provider.Get(context.Background(), events[1].ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected second event to be evicted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last 3 events should still exist
|
||||||
|
for i := 2; i < 5; i++ {
|
||||||
|
_, err := provider.Get(context.Background(), events[i].ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected event %d to still exist", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderCleanup(t *testing.T) {
|
||||||
|
// Create provider with short cleanup interval
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
MaxAge: 200 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish and complete an event
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
provider.UpdateStatus(context.Background(), event.ID, EventStatusCompleted, "")
|
||||||
|
|
||||||
|
// Wait for cleanup to run
|
||||||
|
time.Sleep(400 * time.Millisecond)
|
||||||
|
|
||||||
|
// Event should be cleaned up
|
||||||
|
_, err := provider.Get(context.Background(), event.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected event to be cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
provider.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderStats(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
MaxEvents: 100,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := provider.Stats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stats failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.ProviderType != "memory" {
|
||||||
|
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||||
|
}
|
||||||
|
if stats.TotalEvents != 5 {
|
||||||
|
t.Errorf("Expected 5 total events, got %d", stats.TotalEvents)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderClose(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
// Close provider
|
||||||
|
err := provider.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Close failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup goroutine should be stopped
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderConcurrency(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Concurrent publish
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func() {
|
||||||
|
defer func() { done <- true }()
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all events were stored
|
||||||
|
events, _ := provider.List(context.Background(), &EventFilter{})
|
||||||
|
if len(events) != 10 {
|
||||||
|
t.Errorf("Expected 10 events, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderStream(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Stream is implemented for memory provider (in-process pub/sub)
|
||||||
|
ch, err := provider.Stream(context.Background(), "test.*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stream failed: %v", err)
|
||||||
|
}
|
||||||
|
if ch == nil {
|
||||||
|
t.Error("Expected non-nil channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderTimeRangeFilter(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events at different times
|
||||||
|
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event1)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event2)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
event3 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event3)
|
||||||
|
|
||||||
|
// Filter by time range
|
||||||
|
startTime := event2.CreatedAt.Add(-1 * time.Millisecond)
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
StartTime: &startTime,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should get events 2 and 3
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Errorf("Expected 2 events after start time, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderInstanceIDFilter(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events with different instance IDs
|
||||||
|
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
event1.InstanceID = "instance-1"
|
||||||
|
provider.Publish(context.Background(), event1)
|
||||||
|
|
||||||
|
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
event2.InstanceID = "instance-2"
|
||||||
|
provider.Publish(context.Background(), event2)
|
||||||
|
|
||||||
|
// Filter by instance ID
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
InstanceID: "instance-1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Errorf("Expected 1 event with instance-1, got %d", len(events))
|
||||||
|
}
|
||||||
|
if events[0].InstanceID != "instance-1" {
|
||||||
|
t.Errorf("Expected instance ID 'instance-1', got %s", events[0].InstanceID)
|
||||||
|
}
|
||||||
|
}
|
||||||
140
pkg/eventbroker/subscription.go
Normal file
140
pkg/eventbroker/subscription.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SubscriptionID uniquely identifies a subscription
|
||||||
|
type SubscriptionID string
|
||||||
|
|
||||||
|
// subscription represents a single subscription with its handler and pattern
|
||||||
|
type subscription struct {
|
||||||
|
id SubscriptionID
|
||||||
|
pattern string
|
||||||
|
handler EventHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// subscriptionManager manages event subscriptions and pattern matching
|
||||||
|
type subscriptionManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
subscriptions map[SubscriptionID]*subscription
|
||||||
|
nextID atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSubscriptionManager creates a new subscription manager
|
||||||
|
func newSubscriptionManager() *subscriptionManager {
|
||||||
|
return &subscriptionManager{
|
||||||
|
subscriptions: make(map[SubscriptionID]*subscription),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe adds a new subscription
|
||||||
|
func (sm *subscriptionManager) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||||
|
if pattern == "" {
|
||||||
|
return "", fmt.Errorf("pattern cannot be empty")
|
||||||
|
}
|
||||||
|
if handler == nil {
|
||||||
|
return "", fmt.Errorf("handler cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
id := SubscriptionID(fmt.Sprintf("sub-%d", sm.nextID.Add(1)))
|
||||||
|
|
||||||
|
sm.mu.Lock()
|
||||||
|
sm.subscriptions[id] = &subscription{
|
||||||
|
id: id,
|
||||||
|
pattern: pattern,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
sm.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Subscribed to pattern '%s' with ID: %s", pattern, id)
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes a subscription
|
||||||
|
func (sm *subscriptionManager) Unsubscribe(id SubscriptionID) error {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := sm.subscriptions[id]; !exists {
|
||||||
|
return fmt.Errorf("subscription not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(sm.subscriptions, id)
|
||||||
|
logger.Info("Unsubscribed: %s", id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMatching returns all handlers that match the event type
|
||||||
|
func (sm *subscriptionManager) GetMatching(eventType string) []EventHandler {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
var handlers []EventHandler
|
||||||
|
for _, sub := range sm.subscriptions {
|
||||||
|
if matchPattern(sub.pattern, eventType) {
|
||||||
|
handlers = append(handlers, sub.handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return handlers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of active subscriptions
|
||||||
|
func (sm *subscriptionManager) Count() int {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
return len(sm.subscriptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all subscriptions
|
||||||
|
func (sm *subscriptionManager) Clear() {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
sm.subscriptions = make(map[SubscriptionID]*subscription)
|
||||||
|
logger.Info("Cleared all subscriptions")
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchPattern implements glob-style pattern matching for event types
|
||||||
|
// Patterns:
|
||||||
|
// - "*" matches any single segment
|
||||||
|
// - "a.b.c" matches exactly "a.b.c"
|
||||||
|
// - "a.*.c" matches "a.anything.c"
|
||||||
|
// - "a.b.*" matches any operation on a.b
|
||||||
|
// - "*" matches everything
|
||||||
|
//
|
||||||
|
// Event type format: schema.entity.operation (e.g., "public.users.create")
|
||||||
|
func matchPattern(pattern, eventType string) bool {
|
||||||
|
// Wildcard matches everything
|
||||||
|
if pattern == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact match
|
||||||
|
if pattern == eventType {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split pattern and event type by dots
|
||||||
|
patternParts := strings.Split(pattern, ".")
|
||||||
|
eventParts := strings.Split(eventType, ".")
|
||||||
|
|
||||||
|
// Different number of parts can only match if pattern has wildcards
|
||||||
|
if len(patternParts) != len(eventParts) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match each part
|
||||||
|
for i := range patternParts {
|
||||||
|
if patternParts[i] != "*" && patternParts[i] != eventParts[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
270
pkg/eventbroker/subscription_test.go
Normal file
270
pkg/eventbroker/subscription_test.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMatchPattern(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
pattern string
|
||||||
|
eventType string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
// Exact matches
|
||||||
|
{"public.users.create", "public.users.create", true},
|
||||||
|
{"public.users.create", "public.users.update", false},
|
||||||
|
|
||||||
|
// Wildcard matches
|
||||||
|
{"*", "public.users.create", true},
|
||||||
|
{"*", "anything", true},
|
||||||
|
{"public.*", "public.users", true},
|
||||||
|
{"public.*", "public.users.create", false}, // Different number of parts
|
||||||
|
{"public.*", "admin.users", false},
|
||||||
|
{"*.users.create", "public.users.create", true},
|
||||||
|
{"*.users.create", "admin.users.create", true},
|
||||||
|
{"*.users.create", "public.roles.create", false},
|
||||||
|
{"public.*.create", "public.users.create", true},
|
||||||
|
{"public.*.create", "public.roles.create", true},
|
||||||
|
{"public.*.create", "public.users.update", false},
|
||||||
|
|
||||||
|
// Multiple wildcards
|
||||||
|
{"*.*", "public.users", true},
|
||||||
|
{"*.*", "public.users.create", false}, // Different number of parts
|
||||||
|
{"*.*.create", "public.users.create", true},
|
||||||
|
{"*.*.create", "admin.roles.create", true},
|
||||||
|
{"*.*.create", "public.users.update", false},
|
||||||
|
|
||||||
|
// Edge cases
|
||||||
|
{"", "", true},
|
||||||
|
{"", "something", false},
|
||||||
|
{"something", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.pattern+"_vs_"+tt.eventType, func(t *testing.T) {
|
||||||
|
result := matchPattern(tt.pattern, tt.eventType)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("matchPattern(%q, %q) = %v, expected %v",
|
||||||
|
tt.pattern, tt.eventType, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManager(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
// Create test handler
|
||||||
|
called := false
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test Subscribe
|
||||||
|
id, err := manager.Subscribe("public.users.*", handler)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Subscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
if id == "" {
|
||||||
|
t.Fatal("Expected non-empty subscription ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetMatching
|
||||||
|
handlers := manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 1 {
|
||||||
|
t.Fatalf("Expected 1 handler, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test handler execution
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
if err := handlers[0].Handle(context.Background(), event); err != nil {
|
||||||
|
t.Fatalf("Handler execution failed: %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Error("Expected handler to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Count
|
||||||
|
if manager.Count() != 1 {
|
||||||
|
t.Errorf("Expected count 1, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Unsubscribe
|
||||||
|
if err := manager.Unsubscribe(id); err != nil {
|
||||||
|
t.Fatalf("Unsubscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify unsubscribed
|
||||||
|
handlers = manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 0 {
|
||||||
|
t.Errorf("Expected 0 handlers after unsubscribe, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
if manager.Count() != 0 {
|
||||||
|
t.Errorf("Expected count 0 after unsubscribe, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerMultipleHandlers(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
called1 := false
|
||||||
|
handler1 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called1 = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
called2 := false
|
||||||
|
handler2 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called2 = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subscribe multiple handlers
|
||||||
|
id1, _ := manager.Subscribe("public.users.*", handler1)
|
||||||
|
id2, _ := manager.Subscribe("*.users.*", handler2)
|
||||||
|
|
||||||
|
// Both should match
|
||||||
|
handlers := manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 2 {
|
||||||
|
t.Fatalf("Expected 2 handlers, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute all handlers
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
for _, h := range handlers {
|
||||||
|
h.Handle(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called1 || !called2 {
|
||||||
|
t.Error("Expected both handlers to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe one
|
||||||
|
manager.Unsubscribe(id1)
|
||||||
|
handlers = manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 1 {
|
||||||
|
t.Errorf("Expected 1 handler after unsubscribe, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe remaining
|
||||||
|
manager.Unsubscribe(id2)
|
||||||
|
if manager.Count() != 0 {
|
||||||
|
t.Errorf("Expected count 0 after all unsubscribe, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerConcurrency(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subscribe and unsubscribe concurrently
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func() {
|
||||||
|
defer func() { done <- true }()
|
||||||
|
id, _ := manager.Subscribe("test.*", handler)
|
||||||
|
manager.GetMatching("test.event")
|
||||||
|
manager.Unsubscribe(id)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have no subscriptions left
|
||||||
|
if manager.Count() != 0 {
|
||||||
|
t.Errorf("Expected count 0 after concurrent operations, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerUnsubscribeNonExistent(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
// Try to unsubscribe a non-existent ID
|
||||||
|
err := manager.Unsubscribe("non-existent-id")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when unsubscribing non-existent ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionIDGeneration(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subscribe multiple times and ensure unique IDs
|
||||||
|
ids := make(map[SubscriptionID]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
id, _ := manager.Subscribe("test.*", handler)
|
||||||
|
if ids[id] {
|
||||||
|
t.Fatalf("Duplicate subscription ID: %s", id)
|
||||||
|
}
|
||||||
|
ids[id] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventHandlerFunc(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
var receivedEvent *Event
|
||||||
|
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
receivedEvent = event
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
err := handler.Handle(context.Background(), event)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Error("Expected handler to be called")
|
||||||
|
}
|
||||||
|
if receivedEvent != event {
|
||||||
|
t.Error("Expected to receive the same event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerPatternPriority(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
// More specific patterns should still match
|
||||||
|
specificCalled := false
|
||||||
|
genericCalled := false
|
||||||
|
|
||||||
|
manager.Subscribe("public.users.create", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
specificCalled = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
manager.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
genericCalled = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
handlers := manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 2 {
|
||||||
|
t.Fatalf("Expected 2 matching handlers, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute all handlers
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
for _, h := range handlers {
|
||||||
|
h.Handle(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !specificCalled || !genericCalled {
|
||||||
|
t.Error("Expected both specific and generic handlers to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
141
pkg/eventbroker/worker_pool.go
Normal file
141
pkg/eventbroker/worker_pool.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// workerPool manages a pool of workers for async event processing
|
||||||
|
type workerPool struct {
|
||||||
|
workerCount int
|
||||||
|
bufferSize int
|
||||||
|
eventQueue chan *Event
|
||||||
|
processor func(context.Context, *Event) error
|
||||||
|
|
||||||
|
activeWorkers atomic.Int32
|
||||||
|
isRunning atomic.Bool
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWorkerPool creates a new worker pool
|
||||||
|
func newWorkerPool(workerCount, bufferSize int, processor func(context.Context, *Event) error) *workerPool {
|
||||||
|
return &workerPool{
|
||||||
|
workerCount: workerCount,
|
||||||
|
bufferSize: bufferSize,
|
||||||
|
eventQueue: make(chan *Event, bufferSize),
|
||||||
|
processor: processor,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the worker pool
|
||||||
|
func (wp *workerPool) Start() {
|
||||||
|
if wp.isRunning.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wp.isRunning.Store(true)
|
||||||
|
|
||||||
|
// Start workers
|
||||||
|
for i := 0; i < wp.workerCount; i++ {
|
||||||
|
wp.wg.Add(1)
|
||||||
|
go wp.worker(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Worker pool started with %d workers", wp.workerCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the worker pool gracefully
|
||||||
|
func (wp *workerPool) Stop(ctx context.Context) error {
|
||||||
|
if !wp.isRunning.Load() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
wp.isRunning.Store(false)
|
||||||
|
|
||||||
|
// Close event queue to signal workers
|
||||||
|
close(wp.eventQueue)
|
||||||
|
|
||||||
|
// Wait for workers to finish with context timeout
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wp.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
logger.Info("Worker pool stopped gracefully")
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Warn("Worker pool stop timed out, some events may be lost")
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit submits an event to the queue
|
||||||
|
func (wp *workerPool) Submit(ctx context.Context, event *Event) error {
|
||||||
|
if !wp.isRunning.Load() {
|
||||||
|
return ErrWorkerPoolStopped
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case wp.eventQueue <- event:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
return ErrQueueFull
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is a worker goroutine that processes events from the queue
|
||||||
|
func (wp *workerPool) worker(id int) {
|
||||||
|
defer wp.wg.Done()
|
||||||
|
|
||||||
|
logger.Debug("Worker %d started", id)
|
||||||
|
|
||||||
|
for event := range wp.eventQueue {
|
||||||
|
wp.activeWorkers.Add(1)
|
||||||
|
|
||||||
|
// Process event with background context (detached from original request)
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := wp.processor(ctx, event); err != nil {
|
||||||
|
logger.Error("Worker %d failed to process event %s: %v", id, event.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wp.activeWorkers.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Worker %d stopped", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueueSize returns the current queue size
|
||||||
|
func (wp *workerPool) QueueSize() int {
|
||||||
|
return len(wp.eventQueue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActiveWorkers returns the number of currently active workers
|
||||||
|
func (wp *workerPool) ActiveWorkers() int {
|
||||||
|
return int(wp.activeWorkers.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error definitions
|
||||||
|
var (
|
||||||
|
ErrWorkerPoolStopped = &BrokerError{Code: "worker_pool_stopped", Message: "worker pool is stopped"}
|
||||||
|
ErrQueueFull = &BrokerError{Code: "queue_full", Message: "event queue is full"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// BrokerError represents an error from the event broker
|
||||||
|
type BrokerError struct {
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *BrokerError) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
@@ -22,15 +22,23 @@ import (
|
|||||||
type Handler struct {
|
type Handler struct {
|
||||||
db common.Database
|
db common.Database
|
||||||
hooks *HookRegistry
|
hooks *HookRegistry
|
||||||
|
variablesCallback func(r *http.Request) map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type SqlQueryOptions struct {
|
type SqlQueryOptions struct {
|
||||||
GetVariablesCallback func(w http.ResponseWriter, r *http.Request) map[string]interface{}
|
|
||||||
NoCount bool
|
NoCount bool
|
||||||
BlankParams bool
|
BlankParams bool
|
||||||
AllowFilter bool
|
AllowFilter bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewSqlQueryOptions() SqlQueryOptions {
|
||||||
|
return SqlQueryOptions{
|
||||||
|
NoCount: false,
|
||||||
|
BlankParams: true,
|
||||||
|
AllowFilter: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewHandler creates a new function API handler
|
// NewHandler creates a new function API handler
|
||||||
func NewHandler(db common.Database) *Handler {
|
func NewHandler(db common.Database) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
@@ -45,6 +53,14 @@ func (h *Handler) GetDatabase() common.Database {
|
|||||||
return h.db
|
return h.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) SetVariablesCallback(callback func(r *http.Request) map[string]interface{}) {
|
||||||
|
h.variablesCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) GetVariablesCallback() func(r *http.Request) map[string]interface{} {
|
||||||
|
return h.variablesCallback
|
||||||
|
}
|
||||||
|
|
||||||
// Hooks returns the hook registry for this handler
|
// Hooks returns the hook registry for this handler
|
||||||
// Use this to register custom hooks for operations
|
// Use this to register custom hooks for operations
|
||||||
func (h *Handler) Hooks() *HookRegistry {
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
@@ -77,9 +93,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
inputvars := make([]string, 0)
|
inputvars := make([]string, 0)
|
||||||
metainfo := make(map[string]interface{})
|
metainfo := make(map[string]interface{})
|
||||||
variables := make(map[string]interface{})
|
variables := make(map[string]interface{})
|
||||||
if options.GetVariablesCallback != nil {
|
|
||||||
variables = options.GetVariablesCallback(w, r)
|
|
||||||
}
|
|
||||||
complexAPI := false
|
complexAPI := false
|
||||||
|
|
||||||
// Get user context from security package
|
// Get user context from security package
|
||||||
@@ -416,9 +430,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
inputvars := make([]string, 0)
|
inputvars := make([]string, 0)
|
||||||
metainfo := make(map[string]interface{})
|
metainfo := make(map[string]interface{})
|
||||||
variables := make(map[string]interface{})
|
variables := make(map[string]interface{})
|
||||||
if options.GetVariablesCallback != nil {
|
|
||||||
variables = options.GetVariablesCallback(w, r)
|
|
||||||
}
|
|
||||||
dbobj := make(map[string]interface{})
|
dbobj := make(map[string]interface{})
|
||||||
complexAPI := false
|
complexAPI := false
|
||||||
|
|
||||||
@@ -644,8 +656,18 @@ func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) st
|
|||||||
|
|
||||||
// mergePathParams merges URL path parameters into the SQL query
|
// mergePathParams merges URL path parameters into the SQL query
|
||||||
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
|
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
|
||||||
// Note: Path parameters would typically come from a router like gorilla/mux
|
|
||||||
// For now, this is a placeholder for path parameter extraction
|
if h.GetVariablesCallback() != nil {
|
||||||
|
pathVars := h.GetVariablesCallback()(r)
|
||||||
|
for k, v := range pathVars {
|
||||||
|
kword := fmt.Sprintf("[%s]", k)
|
||||||
|
if strings.Contains(sqlquery, kword) {
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, kword, fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
variables[k] = v
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
return sqlquery
|
return sqlquery
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,10 @@ func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Data
|
|||||||
return fn(m)
|
return fn(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) GetUnderlyingDB() interface{} {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// MockResult implements common.Result interface for testing
|
// MockResult implements common.Result interface for testing
|
||||||
type MockResult struct {
|
type MockResult struct {
|
||||||
rows int64
|
rows int64
|
||||||
|
|||||||
@@ -75,6 +75,25 @@ func CloseErrorTracking() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractContext attempts to find a context.Context in the given arguments.
|
||||||
|
// It returns the found context (or context.Background() if not found) and
|
||||||
|
// the remaining arguments without the context.
|
||||||
|
func extractContext(args ...interface{}) (context.Context, []interface{}) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var newArgs []interface{}
|
||||||
|
found := false
|
||||||
|
|
||||||
|
for _, arg := range args {
|
||||||
|
if c, ok := arg.(context.Context); ok && !found {
|
||||||
|
ctx = c
|
||||||
|
found = true
|
||||||
|
} else {
|
||||||
|
newArgs = append(newArgs, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, newArgs
|
||||||
|
}
|
||||||
|
|
||||||
func Info(template string, args ...interface{}) {
|
func Info(template string, args ...interface{}) {
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf(template, args...)
|
log.Printf(template, args...)
|
||||||
@@ -84,7 +103,8 @@ func Info(template string, args ...interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Warn(template string, args ...interface{}) {
|
func Warn(template string, args ...interface{}) {
|
||||||
message := fmt.Sprintf(template, args...)
|
ctx, remainingArgs := extractContext(args...)
|
||||||
|
message := fmt.Sprintf(template, remainingArgs...)
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf("%s", message)
|
log.Printf("%s", message)
|
||||||
} else {
|
} else {
|
||||||
@@ -93,14 +113,15 @@ func Warn(template string, args ...interface{}) {
|
|||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
|
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityWarning, map[string]interface{}{
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Error(template string, args ...interface{}) {
|
func Error(template string, args ...interface{}) {
|
||||||
message := fmt.Sprintf(template, args...)
|
ctx, remainingArgs := extractContext(args...)
|
||||||
|
message := fmt.Sprintf(template, remainingArgs...)
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf("%s", message)
|
log.Printf("%s", message)
|
||||||
} else {
|
} else {
|
||||||
@@ -109,7 +130,7 @@ func Error(template string, args ...interface{}) {
|
|||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
|
errorTracker.CaptureMessage(ctx, message, errortracking.SeverityError, map[string]interface{}{
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -124,12 +145,13 @@ func Debug(template string, args ...interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CatchPanic - Handle panic
|
// CatchPanic - Handle panic
|
||||||
func CatchPanicCallback(location string, cb func(err any)) {
|
func CatchPanicCallback(location string, cb func(err any), args ...interface{}) {
|
||||||
|
ctx, _ := extractContext(args...)
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
callstack := debug.Stack()
|
callstack := debug.Stack()
|
||||||
|
|
||||||
if Logger != nil {
|
if Logger != nil {
|
||||||
Error("Panic in %s : %v", location, err)
|
Error("Panic in %s : %v", location, err, ctx) // Pass context implicitly
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||||
debug.PrintStack()
|
debug.PrintStack()
|
||||||
@@ -137,7 +159,7 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
|||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
|
errorTracker.CapturePanic(ctx, err, callstack, map[string]interface{}{
|
||||||
"location": location,
|
"location": location,
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
@@ -150,8 +172,8 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CatchPanic - Handle panic
|
// CatchPanic - Handle panic
|
||||||
func CatchPanic(location string) {
|
func CatchPanic(location string, args ...interface{}) {
|
||||||
CatchPanicCallback(location, nil)
|
CatchPanicCallback(location, nil, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlePanic logs a panic and returns it as an error
|
// HandlePanic logs a panic and returns it as an error
|
||||||
@@ -163,13 +185,14 @@ func CatchPanic(location string) {
|
|||||||
// err = logger.HandlePanic("MethodName", r)
|
// err = logger.HandlePanic("MethodName", r)
|
||||||
// }
|
// }
|
||||||
// }()
|
// }()
|
||||||
func HandlePanic(methodName string, r any) error {
|
func HandlePanic(methodName string, r any, args ...interface{}) error {
|
||||||
|
ctx, _ := extractContext(args...)
|
||||||
stack := debug.Stack()
|
stack := debug.Stack()
|
||||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack), ctx) // Pass context implicitly
|
||||||
|
|
||||||
// Send to error tracker
|
// Send to error tracker
|
||||||
if errorTracker != nil {
|
if errorTracker != nil {
|
||||||
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
|
errorTracker.CapturePanic(ctx, r, stack, map[string]interface{}{
|
||||||
"method": methodName,
|
"method": methodName,
|
||||||
"process_id": os.Getpid(),
|
"process_id": os.Getpid(),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -30,6 +30,18 @@ type Provider interface {
|
|||||||
// UpdateCacheSize updates the cache size metric
|
// UpdateCacheSize updates the cache size metric
|
||||||
UpdateCacheSize(provider string, size int64)
|
UpdateCacheSize(provider string, size int64)
|
||||||
|
|
||||||
|
// RecordEventPublished records an event publication
|
||||||
|
RecordEventPublished(source, eventType string)
|
||||||
|
|
||||||
|
// RecordEventProcessed records an event processing with its status
|
||||||
|
RecordEventProcessed(source, eventType, status string, duration time.Duration)
|
||||||
|
|
||||||
|
// UpdateEventQueueSize updates the event queue size metric
|
||||||
|
UpdateEventQueueSize(size int64)
|
||||||
|
|
||||||
|
// RecordPanic records a panic event
|
||||||
|
RecordPanic(methodName string)
|
||||||
|
|
||||||
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
||||||
Handler() http.Handler
|
Handler() http.Handler
|
||||||
}
|
}
|
||||||
@@ -62,6 +74,11 @@ func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Dura
|
|||||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||||
|
func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
|
||||||
|
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||||
|
}
|
||||||
|
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
|
||||||
|
func (n *NoOpProvider) RecordPanic(methodName string) {}
|
||||||
func (n *NoOpProvider) Handler() http.Handler {
|
func (n *NoOpProvider) Handler() http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type PrometheusProvider struct {
|
|||||||
cacheHits *prometheus.CounterVec
|
cacheHits *prometheus.CounterVec
|
||||||
cacheMisses *prometheus.CounterVec
|
cacheMisses *prometheus.CounterVec
|
||||||
cacheSize *prometheus.GaugeVec
|
cacheSize *prometheus.GaugeVec
|
||||||
|
panicsTotal *prometheus.CounterVec
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||||
@@ -83,6 +84,13 @@ func NewPrometheusProvider() *PrometheusProvider {
|
|||||||
},
|
},
|
||||||
[]string{"provider"},
|
[]string{"provider"},
|
||||||
),
|
),
|
||||||
|
panicsTotal: promauto.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "panics_total",
|
||||||
|
Help: "Total number of panics",
|
||||||
|
},
|
||||||
|
[]string{"method"},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,6 +153,11 @@ func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
|
|||||||
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordPanic implements the Provider interface
|
||||||
|
func (p *PrometheusProvider) RecordPanic(methodName string) {
|
||||||
|
p.panicsTotal.WithLabelValues(methodName).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
// Handler implements Provider interface
|
// Handler implements Provider interface
|
||||||
func (p *PrometheusProvider) Handler() http.Handler {
|
func (p *PrometheusProvider) Handler() http.Handler {
|
||||||
return promhttp.Handler()
|
return promhttp.Handler()
|
||||||
|
|||||||
33
pkg/middleware/panic.go
Normal file
33
pkg/middleware/panic.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
const panicMiddlewareMethodName = "PanicMiddleware"
|
||||||
|
|
||||||
|
// PanicRecovery is a middleware that recovers from panics, logs the error,
|
||||||
|
// sends it to an error tracker, records a metric, and returns a 500 Internal Server Error.
|
||||||
|
func PanicRecovery(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
if rcv := recover(); rcv != nil {
|
||||||
|
// Record the panic metric
|
||||||
|
metrics.GetProvider().RecordPanic(panicMiddlewareMethodName)
|
||||||
|
|
||||||
|
// Log the panic and send to error tracker
|
||||||
|
// We pass the request context so the error tracker can potentially
|
||||||
|
// link the panic to the request trace.
|
||||||
|
ctx := r.Context()
|
||||||
|
err := logger.HandlePanic(panicMiddlewareMethodName, rcv, ctx)
|
||||||
|
|
||||||
|
// Respond with a 500 error
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
86
pkg/middleware/panic_test.go
Normal file
86
pkg/middleware/panic_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockMetricsProvider is a mock for the metrics provider to check if methods are called.
|
||||||
|
type mockMetricsProvider struct {
|
||||||
|
metrics.NoOpProvider // Embed NoOpProvider to avoid implementing all methods
|
||||||
|
panicRecorded bool
|
||||||
|
methodName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMetricsProvider) RecordPanic(methodName string) {
|
||||||
|
m.panicRecorded = true
|
||||||
|
m.methodName = methodName
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPanicRecovery(t *testing.T) {
|
||||||
|
// Initialize a mock logger to avoid actual logging output during tests
|
||||||
|
logger.Init(true)
|
||||||
|
|
||||||
|
// Setup mock metrics provider
|
||||||
|
mockProvider := &mockMetricsProvider{}
|
||||||
|
originalProvider := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(mockProvider)
|
||||||
|
defer metrics.SetProvider(originalProvider) // Restore original provider after test
|
||||||
|
|
||||||
|
// 1. Test case: A handler that panics
|
||||||
|
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||||
|
// Reset mock state for this sub-test
|
||||||
|
mockProvider.panicRecorded = false
|
||||||
|
mockProvider.methodName = ""
|
||||||
|
|
||||||
|
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("something went terribly wrong")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create the middleware wrapping the panicking handler
|
||||||
|
testHandler := PanicRecovery(panicHandler)
|
||||||
|
|
||||||
|
// Create a test request and response recorder
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Serve the request
|
||||||
|
testHandler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rr.Code, "expected status code to be 500")
|
||||||
|
assert.Contains(t, rr.Body.String(), "panic in PanicMiddleware: something went terribly wrong", "expected error message in response body")
|
||||||
|
|
||||||
|
// Assert that the metric was recorded
|
||||||
|
assert.True(t, mockProvider.panicRecorded, "expected RecordPanic to be called on metrics provider")
|
||||||
|
assert.Equal(t, panicMiddlewareMethodName, mockProvider.methodName, "expected panic to be recorded with the correct method name")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. Test case: A handler that does NOT panic
|
||||||
|
t.Run("does not interfere with a non-panicking handler", func(t *testing.T) {
|
||||||
|
// Reset mock state for this sub-test
|
||||||
|
mockProvider.panicRecorded = false
|
||||||
|
|
||||||
|
successHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
testHandler := PanicRecovery(successHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
testHandler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert.Equal(t, http.StatusOK, rr.Code, "expected status code to be 200")
|
||||||
|
assert.Equal(t, "OK", rr.Body.String(), "expected 'OK' response body")
|
||||||
|
assert.False(t, mockProvider.panicRecorded, "expected RecordPanic to not be called when there is no panic")
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -6,15 +6,37 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ModelRules defines the permissions and security settings for a model
|
||||||
|
type ModelRules struct {
|
||||||
|
CanRead bool // Whether the model can be read (GET operations)
|
||||||
|
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||||
|
CanCreate bool // Whether the model can be created (POST operations)
|
||||||
|
CanDelete bool // Whether the model can be deleted (DELETE operations)
|
||||||
|
SecurityDisabled bool // Whether security checks are disabled for this model
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelRules returns the default rules for a model (all operations allowed, security enabled)
|
||||||
|
func DefaultModelRules() ModelRules {
|
||||||
|
return ModelRules{
|
||||||
|
CanRead: true,
|
||||||
|
CanUpdate: true,
|
||||||
|
CanCreate: true,
|
||||||
|
CanDelete: true,
|
||||||
|
SecurityDisabled: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultModelRegistry implements ModelRegistry interface
|
// DefaultModelRegistry implements ModelRegistry interface
|
||||||
type DefaultModelRegistry struct {
|
type DefaultModelRegistry struct {
|
||||||
models map[string]interface{}
|
models map[string]interface{}
|
||||||
|
rules map[string]ModelRules
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global default registry instance
|
// Global default registry instance
|
||||||
var defaultRegistry = &DefaultModelRegistry{
|
var defaultRegistry = &DefaultModelRegistry{
|
||||||
models: make(map[string]interface{}),
|
models: make(map[string]interface{}),
|
||||||
|
rules: make(map[string]ModelRules),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global list of registries (searched in order)
|
// Global list of registries (searched in order)
|
||||||
@@ -25,6 +47,7 @@ var registriesMutex sync.RWMutex
|
|||||||
func NewModelRegistry() *DefaultModelRegistry {
|
func NewModelRegistry() *DefaultModelRegistry {
|
||||||
return &DefaultModelRegistry{
|
return &DefaultModelRegistry{
|
||||||
models: make(map[string]interface{}),
|
models: make(map[string]interface{}),
|
||||||
|
rules: make(map[string]ModelRules),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,6 +121,10 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.models[name] = model
|
r.models[name] = model
|
||||||
|
// Initialize with default rules if not already set
|
||||||
|
if _, exists := r.rules[name]; !exists {
|
||||||
|
r.rules[name] = DefaultModelRules()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,6 +162,54 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac
|
|||||||
return r.GetModel(entity)
|
return r.GetModel(entity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRules sets the rules for a specific model
|
||||||
|
func (r *DefaultModelRegistry) SetModelRules(name string, rules ModelRules) error {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
|
// Check if model exists
|
||||||
|
if _, exists := r.models[name]; !exists {
|
||||||
|
return fmt.Errorf("model %s not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[name] = rules
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelRules retrieves the rules for a specific model
|
||||||
|
// Returns default rules if model exists but rules are not set
|
||||||
|
func (r *DefaultModelRegistry) GetModelRules(name string) (ModelRules, error) {
|
||||||
|
r.mutex.RLock()
|
||||||
|
defer r.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Check if model exists
|
||||||
|
if _, exists := r.models[name]; !exists {
|
||||||
|
return ModelRules{}, fmt.Errorf("model %s not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rules if set, otherwise return default rules
|
||||||
|
if rules, exists := r.rules[name]; exists {
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return DefaultModelRules(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModelWithRules registers a model with specific rules
|
||||||
|
func (r *DefaultModelRegistry) RegisterModelWithRules(name string, model interface{}, rules ModelRules) error {
|
||||||
|
// First register the model
|
||||||
|
if err := r.RegisterModel(name, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then set the rules (we need to lock again for rules)
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.rules[name] = rules
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Global convenience functions using the default registry
|
// Global convenience functions using the default registry
|
||||||
|
|
||||||
// RegisterModel registers a model with the default global registry
|
// RegisterModel registers a model with the default global registry
|
||||||
@@ -190,3 +265,34 @@ func GetModels() []interface{} {
|
|||||||
|
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRules sets the rules for a specific model in the default registry
|
||||||
|
func SetModelRules(name string, rules ModelRules) error {
|
||||||
|
return defaultRegistry.SetModelRules(name, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelRules retrieves the rules for a specific model from the default registry
|
||||||
|
func GetModelRules(name string) (ModelRules, error) {
|
||||||
|
return defaultRegistry.GetModelRules(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelRulesByName retrieves the rules for a model by searching through all registries in order
|
||||||
|
// Returns the first match found
|
||||||
|
func GetModelRulesByName(name string) (ModelRules, error) {
|
||||||
|
registriesMutex.RLock()
|
||||||
|
defer registriesMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, registry := range registries {
|
||||||
|
if _, err := registry.GetModel(name); err == nil {
|
||||||
|
// Model found in this registry, get its rules
|
||||||
|
return registry.GetModelRules(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ModelRules{}, fmt.Errorf("model %s not found in any registry", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModelWithRules registers a model with specific rules in the default registry
|
||||||
|
func RegisterModelWithRules(model interface{}, name string, rules ModelRules) error {
|
||||||
|
return defaultRegistry.RegisterModelWithRules(name, model, rules)
|
||||||
|
}
|
||||||
|
|||||||
@@ -273,25 +273,151 @@ handler.SetOpenAPIGenerator(func() (string, error) {
|
|||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
## Using with Swagger UI
|
## Using the Built-in UI Handler
|
||||||
|
|
||||||
You can serve the generated OpenAPI spec with Swagger UI:
|
The package includes a built-in UI handler that serves popular OpenAPI visualization tools. No need to download or manage static files - everything is served from CDN.
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/openapi"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Setup your API routes and OpenAPI generator...
|
||||||
|
// (see examples above)
|
||||||
|
|
||||||
|
// Add the UI handler - defaults to Swagger UI
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API Documentation",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Now visit http://localhost:8080/docs
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported UI Frameworks
|
||||||
|
|
||||||
|
The handler supports four popular OpenAPI UI frameworks:
|
||||||
|
|
||||||
|
#### 1. Swagger UI (Default)
|
||||||
|
The most widely used OpenAPI UI with excellent compatibility and features.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
Theme: "dark", // optional: "light" or "dark"
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. RapiDoc
|
||||||
|
Modern, customizable, and feature-rich OpenAPI UI.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.RapiDoc,
|
||||||
|
Theme: "dark",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Redoc
|
||||||
|
Clean, responsive documentation with great UX.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.Redoc,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. Scalar
|
||||||
|
Modern and sleek OpenAPI documentation.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.Scalar,
|
||||||
|
Theme: "dark",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Options
|
||||||
|
|
||||||
|
```go
|
||||||
|
type UIConfig struct {
|
||||||
|
UIType UIType // SwaggerUI, RapiDoc, Redoc, or Scalar
|
||||||
|
SpecURL string // URL to OpenAPI spec (default: "/openapi")
|
||||||
|
Title string // Page title (default: "API Documentation")
|
||||||
|
FaviconURL string // Custom favicon URL (optional)
|
||||||
|
CustomCSS string // Custom CSS to inject (optional)
|
||||||
|
Theme string // "light" or "dark" (support varies by UI)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Styling Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
Title: "Acme Corp API",
|
||||||
|
CustomCSS: `
|
||||||
|
.swagger-ui .topbar {
|
||||||
|
background-color: #1976d2;
|
||||||
|
}
|
||||||
|
.swagger-ui .info .title {
|
||||||
|
color: #1976d2;
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Multiple UIs
|
||||||
|
|
||||||
|
You can serve different UIs at different paths:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Swagger UI at /docs
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Redoc at /redoc
|
||||||
|
openapi.SetupUIRoute(router, "/redoc", openapi.UIConfig{
|
||||||
|
UIType: openapi.Redoc,
|
||||||
|
})
|
||||||
|
|
||||||
|
// RapiDoc at /api-docs
|
||||||
|
openapi.SetupUIRoute(router, "/api-docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.RapiDoc,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manual Handler Usage
|
||||||
|
|
||||||
|
If you need more control, use the handler directly:
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler := openapi.UIHandler(openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
SpecURL: "/api/openapi.json",
|
||||||
|
})
|
||||||
|
|
||||||
|
router.Handle("/documentation", handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using with External Swagger UI
|
||||||
|
|
||||||
|
Alternatively, you can use an external Swagger UI instance:
|
||||||
|
|
||||||
1. Get the spec from `/openapi`
|
1. Get the spec from `/openapi`
|
||||||
2. Load it in Swagger UI at `https://petstore.swagger.io/`
|
2. Load it in Swagger UI at `https://petstore.swagger.io/`
|
||||||
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
|
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
|
||||||
|
|
||||||
Example with self-hosted Swagger UI:
|
|
||||||
|
|
||||||
```go
|
|
||||||
// Serve Swagger UI static files
|
|
||||||
router.PathPrefix("/swagger/").Handler(
|
|
||||||
http.StripPrefix("/swagger/", http.FileServer(http.Dir("./swagger-ui"))),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Configure Swagger UI to use /openapi
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
You can test the OpenAPI endpoint:
|
You can test the OpenAPI endpoint:
|
||||||
|
|||||||
@@ -183,6 +183,69 @@ func ExampleWithFuncSpec() {
|
|||||||
_ = generatorFunc
|
_ = generatorFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExampleWithUIHandler shows how to serve OpenAPI documentation with a web UI
|
||||||
|
func ExampleWithUIHandler(db *gorm.DB) {
|
||||||
|
// Create handler and configure OpenAPI generator
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
registry := modelregistry.NewModelRegistry()
|
||||||
|
|
||||||
|
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||||
|
generator := NewGenerator(GeneratorConfig{
|
||||||
|
Title: "My API",
|
||||||
|
Description: "API documentation with interactive UI",
|
||||||
|
Version: "1.0.0",
|
||||||
|
BaseURL: "http://localhost:8080",
|
||||||
|
Registry: registry,
|
||||||
|
IncludeRestheadSpec: true,
|
||||||
|
})
|
||||||
|
return generator.GenerateJSON()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup routes
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
|
||||||
|
// Add UI handlers for different frameworks
|
||||||
|
// Swagger UI at /docs (most popular)
|
||||||
|
SetupUIRoute(router, "/docs", UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - Swagger UI",
|
||||||
|
Theme: "light",
|
||||||
|
})
|
||||||
|
|
||||||
|
// RapiDoc at /rapidoc (modern alternative)
|
||||||
|
SetupUIRoute(router, "/rapidoc", UIConfig{
|
||||||
|
UIType: RapiDoc,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - RapiDoc",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Redoc at /redoc (clean and responsive)
|
||||||
|
SetupUIRoute(router, "/redoc", UIConfig{
|
||||||
|
UIType: Redoc,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - Redoc",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Scalar at /scalar (modern and sleek)
|
||||||
|
SetupUIRoute(router, "/scalar", UIConfig{
|
||||||
|
UIType: Scalar,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - Scalar",
|
||||||
|
Theme: "dark",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Now you can access:
|
||||||
|
// http://localhost:8080/docs - Swagger UI
|
||||||
|
// http://localhost:8080/rapidoc - RapiDoc
|
||||||
|
// http://localhost:8080/redoc - Redoc
|
||||||
|
// http://localhost:8080/scalar - Scalar
|
||||||
|
// http://localhost:8080/openapi - Raw OpenAPI JSON
|
||||||
|
|
||||||
|
_ = router
|
||||||
|
}
|
||||||
|
|
||||||
// ExampleCustomization shows advanced customization options
|
// ExampleCustomization shows advanced customization options
|
||||||
func ExampleCustomization() {
|
func ExampleCustomization() {
|
||||||
// Create registry and register models with descriptions using struct tags
|
// Create registry and register models with descriptions using struct tags
|
||||||
|
|||||||
294
pkg/openapi/ui_handler.go
Normal file
294
pkg/openapi/ui_handler.go
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
package openapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UIType represents the type of OpenAPI UI to serve
|
||||||
|
type UIType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SwaggerUI is the most popular OpenAPI UI
|
||||||
|
SwaggerUI UIType = "swagger-ui"
|
||||||
|
// RapiDoc is a modern, customizable OpenAPI UI
|
||||||
|
RapiDoc UIType = "rapidoc"
|
||||||
|
// Redoc is a clean, responsive OpenAPI UI
|
||||||
|
Redoc UIType = "redoc"
|
||||||
|
// Scalar is a modern and sleek OpenAPI UI
|
||||||
|
Scalar UIType = "scalar"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UIConfig holds configuration for the OpenAPI UI handler
|
||||||
|
type UIConfig struct {
|
||||||
|
// UIType specifies which UI framework to use (default: SwaggerUI)
|
||||||
|
UIType UIType
|
||||||
|
// SpecURL is the URL to the OpenAPI spec JSON (default: "/openapi")
|
||||||
|
SpecURL string
|
||||||
|
// Title is the page title (default: "API Documentation")
|
||||||
|
Title string
|
||||||
|
// FaviconURL is the URL to the favicon (optional)
|
||||||
|
FaviconURL string
|
||||||
|
// CustomCSS allows injecting custom CSS (optional)
|
||||||
|
CustomCSS string
|
||||||
|
// Theme for the UI (light/dark, depends on UI type)
|
||||||
|
Theme string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UIHandler creates an HTTP handler that serves an OpenAPI UI
|
||||||
|
func UIHandler(config UIConfig) http.HandlerFunc {
|
||||||
|
// Set defaults
|
||||||
|
if config.UIType == "" {
|
||||||
|
config.UIType = SwaggerUI
|
||||||
|
}
|
||||||
|
if config.SpecURL == "" {
|
||||||
|
config.SpecURL = "/openapi"
|
||||||
|
}
|
||||||
|
if config.Title == "" {
|
||||||
|
config.Title = "API Documentation"
|
||||||
|
}
|
||||||
|
if config.Theme == "" {
|
||||||
|
config.Theme = "light"
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var htmlContent string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch config.UIType {
|
||||||
|
case SwaggerUI:
|
||||||
|
htmlContent, err = generateSwaggerUI(config)
|
||||||
|
case RapiDoc:
|
||||||
|
htmlContent, err = generateRapiDoc(config)
|
||||||
|
case Redoc:
|
||||||
|
htmlContent, err = generateRedoc(config)
|
||||||
|
case Scalar:
|
||||||
|
htmlContent, err = generateScalar(config)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Unsupported UI type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to generate UI: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, err = w.Write([]byte(htmlContent))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to write response: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// templateData wraps UIConfig to properly handle CSS in templates
|
||||||
|
type templateData struct {
|
||||||
|
UIConfig
|
||||||
|
SafeCustomCSS template.CSS
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSwaggerUI generates the HTML for Swagger UI
|
||||||
|
func generateSwaggerUI(config UIConfig) (string, error) {
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css">
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
<style>
|
||||||
|
html { box-sizing: border-box; overflow: -moz-scrollbars-vertical; overflow-y: scroll; }
|
||||||
|
*, *:before, *:after { box-sizing: inherit; }
|
||||||
|
body { margin: 0; padding: 0; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="swagger-ui"></div>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-standalone-preset.js"></script>
|
||||||
|
<script>
|
||||||
|
window.onload = function() {
|
||||||
|
const ui = SwaggerUIBundle({
|
||||||
|
url: "{{.SpecURL}}",
|
||||||
|
dom_id: '#swagger-ui',
|
||||||
|
deepLinking: true,
|
||||||
|
presets: [
|
||||||
|
SwaggerUIBundle.presets.apis,
|
||||||
|
SwaggerUIStandalonePreset
|
||||||
|
],
|
||||||
|
plugins: [
|
||||||
|
SwaggerUIBundle.plugins.DownloadUrl
|
||||||
|
],
|
||||||
|
layout: "StandaloneLayout",
|
||||||
|
{{if eq .Theme "dark"}}
|
||||||
|
syntaxHighlight: {
|
||||||
|
activate: true,
|
||||||
|
theme: "monokai"
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
});
|
||||||
|
window.ui = ui;
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("swagger").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRapiDoc generates the HTML for RapiDoc
|
||||||
|
func generateRapiDoc(config UIConfig) (string, error) {
|
||||||
|
theme := "light"
|
||||||
|
if config.Theme == "dark" {
|
||||||
|
theme = "dark"
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
<script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<rapi-doc
|
||||||
|
spec-url="{{.SpecURL}}"
|
||||||
|
theme="` + theme + `"
|
||||||
|
render-style="read"
|
||||||
|
show-header="true"
|
||||||
|
show-info="true"
|
||||||
|
allow-try="true"
|
||||||
|
allow-server-selection="true"
|
||||||
|
allow-authentication="true"
|
||||||
|
api-key-name="Authorization"
|
||||||
|
api-key-location="header"
|
||||||
|
></rapi-doc>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("rapidoc").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRedoc generates the HTML for Redoc
|
||||||
|
func generateRedoc(config UIConfig) (string, error) {
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
<style>
|
||||||
|
body { margin: 0; padding: 0; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<redoc spec-url="{{.SpecURL}}" {{if eq .Theme "dark"}}theme='{"colors": {"primary": {"main": "#dd5522"}}}'{{end}}></redoc>
|
||||||
|
<script src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("redoc").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateScalar generates the HTML for Scalar
|
||||||
|
func generateScalar(config UIConfig) (string, error) {
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
<style>
|
||||||
|
body { margin: 0; padding: 0; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<script id="api-reference" data-url="{{.SpecURL}}" {{if eq .Theme "dark"}}data-theme="dark"{{end}}></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("scalar").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupUIRoute adds the OpenAPI UI route to a mux router
|
||||||
|
// This is a convenience function for the most common use case
|
||||||
|
func SetupUIRoute(router *mux.Router, path string, config UIConfig) {
|
||||||
|
router.Handle(path, UIHandler(config))
|
||||||
|
}
|
||||||
308
pkg/openapi/ui_handler_test.go
Normal file
308
pkg/openapi/ui_handler_test.go
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
package openapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUIHandler_SwaggerUI(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "Test API Docs",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for Swagger UI specific content
|
||||||
|
if !strings.Contains(body, "swagger-ui") {
|
||||||
|
t.Error("Expected Swagger UI content")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "SwaggerUIBundle") {
|
||||||
|
t.Error("Expected SwaggerUIBundle script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "swagger-ui-dist") {
|
||||||
|
t.Error("Expected Swagger UI CDN link")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_RapiDoc(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: RapiDoc,
|
||||||
|
SpecURL: "/api/spec",
|
||||||
|
Title: "RapiDoc Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for RapiDoc specific content
|
||||||
|
if !strings.Contains(body, "rapi-doc") {
|
||||||
|
t.Error("Expected rapi-doc element")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "rapidoc-min.js") {
|
||||||
|
t.Error("Expected RapiDoc script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_Redoc(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: Redoc,
|
||||||
|
SpecURL: "/spec.json",
|
||||||
|
Title: "Redoc Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for Redoc specific content
|
||||||
|
if !strings.Contains(body, "<redoc") {
|
||||||
|
t.Error("Expected redoc element")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "redoc.standalone.js") {
|
||||||
|
t.Error("Expected Redoc script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_Scalar(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: Scalar,
|
||||||
|
SpecURL: "/openapi.json",
|
||||||
|
Title: "Scalar Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for Scalar specific content
|
||||||
|
if !strings.Contains(body, "api-reference") {
|
||||||
|
t.Error("Expected api-reference element")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "@scalar/api-reference") {
|
||||||
|
t.Error("Expected Scalar script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_DefaultValues(t *testing.T) {
|
||||||
|
// Test with empty config to check defaults
|
||||||
|
config := UIConfig{}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Should default to Swagger UI
|
||||||
|
if !strings.Contains(body, "swagger-ui") {
|
||||||
|
t.Error("Expected default to Swagger UI")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should default to /openapi spec URL
|
||||||
|
if !strings.Contains(body, "/openapi") {
|
||||||
|
t.Error("Expected default spec URL '/openapi'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should default to "API Documentation" title
|
||||||
|
if !strings.Contains(body, "API Documentation") {
|
||||||
|
t.Error("Expected default title 'API Documentation'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_CustomCSS(t *testing.T) {
|
||||||
|
customCSS := ".custom-class { color: red; }"
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
CustomCSS: customCSS,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
if !strings.Contains(body, customCSS) {
|
||||||
|
t.Errorf("Expected custom CSS to be included. Body:\n%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_Favicon(t *testing.T) {
|
||||||
|
faviconURL := "https://example.com/favicon.ico"
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
FaviconURL: faviconURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
if !strings.Contains(body, faviconURL) {
|
||||||
|
t.Error("Expected favicon URL to be included")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_DarkTheme(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
Theme: "dark",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// SwaggerUI uses monokai theme for dark mode
|
||||||
|
if !strings.Contains(body, "monokai") {
|
||||||
|
t.Error("Expected dark theme configuration for Swagger UI")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_InvalidUIType(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: "invalid-ui-type",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("Expected status 400 for invalid UI type, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_ContentType(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
contentType := w.Header().Get("Content-Type")
|
||||||
|
if !strings.Contains(contentType, "text/html") {
|
||||||
|
t.Errorf("Expected Content-Type to contain 'text/html', got '%s'", contentType)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentType, "charset=utf-8") {
|
||||||
|
t.Errorf("Expected Content-Type to contain 'charset=utf-8', got '%s'", contentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupUIRoute(t *testing.T) {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
}
|
||||||
|
|
||||||
|
SetupUIRoute(router, "/api-docs", config)
|
||||||
|
|
||||||
|
// Test that the route was added and works
|
||||||
|
req := httptest.NewRequest("GET", "/api-docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it returns HTML
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.Contains(body, "swagger-ui") {
|
||||||
|
t.Error("Expected Swagger UI content")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
package reflection
|
package reflection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
)
|
)
|
||||||
@@ -897,6 +899,368 @@ func GetRelationModel(model interface{}, fieldName string) interface{} {
|
|||||||
return currentModel
|
return currentModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MapToStruct populates a struct from a map while preserving custom types
|
||||||
|
// It uses reflection to set struct fields based on map keys, matching by:
|
||||||
|
// 1. Bun tag column name
|
||||||
|
// 2. Gorm tag column name
|
||||||
|
// 3. JSON tag name
|
||||||
|
// 4. Field name (case-insensitive)
|
||||||
|
// This preserves custom types that implement driver.Valuer like SqlJSONB
|
||||||
|
func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
||||||
|
if dataMap == nil || target == nil {
|
||||||
|
return fmt.Errorf("dataMap and target cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
targetValue := reflect.ValueOf(target)
|
||||||
|
if targetValue.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("target must be a pointer to a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
targetValue = targetValue.Elem()
|
||||||
|
if targetValue.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("target must be a pointer to a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
targetType := targetValue.Type()
|
||||||
|
|
||||||
|
// Create a map of column names to field indices for faster lookup
|
||||||
|
columnToField := make(map[string]int)
|
||||||
|
for i := 0; i < targetType.NumField(); i++ {
|
||||||
|
field := targetType.Field(i)
|
||||||
|
|
||||||
|
// Skip unexported fields
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build list of possible column names for this field
|
||||||
|
var columnNames []string
|
||||||
|
|
||||||
|
// 1. Bun tag
|
||||||
|
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||||
|
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||||
|
columnNames = append(columnNames, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Gorm tag
|
||||||
|
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||||
|
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||||
|
columnNames = append(columnNames, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. JSON tag
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
columnNames = append(columnNames, parts[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Field name variations
|
||||||
|
columnNames = append(columnNames, field.Name)
|
||||||
|
columnNames = append(columnNames, strings.ToLower(field.Name))
|
||||||
|
columnNames = append(columnNames, ToSnakeCase(field.Name))
|
||||||
|
|
||||||
|
// Map all column name variations to this field index
|
||||||
|
for _, colName := range columnNames {
|
||||||
|
columnToField[strings.ToLower(colName)] = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate through the map and set struct fields
|
||||||
|
for key, value := range dataMap {
|
||||||
|
// Find the field index for this key
|
||||||
|
fieldIndex, found := columnToField[strings.ToLower(key)]
|
||||||
|
if !found {
|
||||||
|
// Skip keys that don't map to any field
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
field := targetValue.Field(fieldIndex)
|
||||||
|
if !field.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the value, preserving custom types
|
||||||
|
if err := setFieldValue(field, value); err != nil {
|
||||||
|
return fmt.Errorf("failed to set field %s: %w", targetType.Field(fieldIndex).Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setFieldValue sets a reflect.Value from an interface{} value, handling type conversions
|
||||||
|
func setFieldValue(field reflect.Value, value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
// Set zero value for nil
|
||||||
|
field.Set(reflect.Zero(field.Type()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
valueReflect := reflect.ValueOf(value)
|
||||||
|
|
||||||
|
// If types match exactly, just set it
|
||||||
|
if valueReflect.Type().AssignableTo(field.Type()) {
|
||||||
|
field.Set(valueReflect)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pointer fields
|
||||||
|
if field.Kind() == reflect.Ptr {
|
||||||
|
if valueReflect.Kind() != reflect.Ptr {
|
||||||
|
// Create a new pointer and set its value
|
||||||
|
newPtr := reflect.New(field.Type().Elem())
|
||||||
|
if err := setFieldValue(newPtr.Elem(), value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
field.Set(newPtr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle conversions for basic types
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
field.SetString(str)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
if num, ok := convertToInt64(value); ok {
|
||||||
|
if field.OverflowInt(num) {
|
||||||
|
return fmt.Errorf("integer overflow")
|
||||||
|
}
|
||||||
|
field.SetInt(num)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
if num, ok := convertToUint64(value); ok {
|
||||||
|
if field.OverflowUint(num) {
|
||||||
|
return fmt.Errorf("unsigned integer overflow")
|
||||||
|
}
|
||||||
|
field.SetUint(num)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
if num, ok := convertToFloat64(value); ok {
|
||||||
|
if field.OverflowFloat(num) {
|
||||||
|
return fmt.Errorf("float overflow")
|
||||||
|
}
|
||||||
|
field.SetFloat(num)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Bool:
|
||||||
|
if b, ok := value.(bool); ok {
|
||||||
|
field.SetBool(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
// Handle []byte specially (for types like SqlJSONB)
|
||||||
|
if field.Type().Elem().Kind() == reflect.Uint8 {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []byte:
|
||||||
|
field.SetBytes(v)
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
field.SetBytes([]byte(v))
|
||||||
|
return nil
|
||||||
|
case map[string]interface{}, []interface{}:
|
||||||
|
// Marshal complex types to JSON for SqlJSONB fields
|
||||||
|
jsonBytes, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal value to JSON: %w", err)
|
||||||
|
}
|
||||||
|
field.SetBytes(jsonBytes)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
|
||||||
|
if field.Kind() == reflect.Struct {
|
||||||
|
|
||||||
|
// Handle datatypes.SqlNull[T] and wrapped types (SqlTimeStamp, SqlDate, SqlTime)
|
||||||
|
// Check if the type has a Scan method (sql.Scanner interface)
|
||||||
|
if field.CanAddr() {
|
||||||
|
scanMethod := field.Addr().MethodByName("Scan")
|
||||||
|
if scanMethod.IsValid() {
|
||||||
|
// Call the Scan method with the value
|
||||||
|
results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)})
|
||||||
|
if len(results) > 0 {
|
||||||
|
// Check if there was an error
|
||||||
|
if err, ok := results[0].Interface().(error); ok && err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle time.Time with ISO string fallback
|
||||||
|
if field.Type() == reflect.TypeOf(time.Time{}) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case time.Time:
|
||||||
|
field.Set(reflect.ValueOf(v))
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
// Try parsing as ISO 8601 / RFC3339
|
||||||
|
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||||
|
field.Set(reflect.ValueOf(t))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Try other common formats
|
||||||
|
formats := []string{
|
||||||
|
"2006-01-02T15:04:05.000-0700",
|
||||||
|
"2006-01-02T15:04:05.000",
|
||||||
|
"2006-01-02T15:04:05",
|
||||||
|
"2006-01-02 15:04:05",
|
||||||
|
"2006-01-02",
|
||||||
|
}
|
||||||
|
for _, format := range formats {
|
||||||
|
if t, err := time.Parse(format, v); err == nil {
|
||||||
|
field.Set(reflect.ValueOf(t))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot parse time string: %s", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: Try to find a "Val" field (for SqlNull types) and set it directly
|
||||||
|
valField := field.FieldByName("Val")
|
||||||
|
if valField.IsValid() && valField.CanSet() {
|
||||||
|
// Also set Valid field to true
|
||||||
|
validField := field.FieldByName("Valid")
|
||||||
|
if validField.IsValid() && validField.CanSet() && validField.Kind() == reflect.Bool {
|
||||||
|
// Set the Val field
|
||||||
|
if err := setFieldValue(valField, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Set Valid to true
|
||||||
|
validField.SetBool(true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can convert the type, do it
|
||||||
|
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||||
|
field.Set(valueReflect.Convert(field.Type()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToInt64 attempts to convert various types to int64
|
||||||
|
func convertToInt64(value interface{}) (int64, bool) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
return int64(v), true
|
||||||
|
case int8:
|
||||||
|
return int64(v), true
|
||||||
|
case int16:
|
||||||
|
return int64(v), true
|
||||||
|
case int32:
|
||||||
|
return int64(v), true
|
||||||
|
case int64:
|
||||||
|
return v, true
|
||||||
|
case uint:
|
||||||
|
return int64(v), true
|
||||||
|
case uint8:
|
||||||
|
return int64(v), true
|
||||||
|
case uint16:
|
||||||
|
return int64(v), true
|
||||||
|
case uint32:
|
||||||
|
return int64(v), true
|
||||||
|
case uint64:
|
||||||
|
return int64(v), true
|
||||||
|
case float32:
|
||||||
|
return int64(v), true
|
||||||
|
case float64:
|
||||||
|
return int64(v), true
|
||||||
|
case string:
|
||||||
|
if num, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||||
|
return num, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToUint64 attempts to convert various types to uint64
|
||||||
|
func convertToUint64(value interface{}) (uint64, bool) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
return uint64(v), true
|
||||||
|
case int8:
|
||||||
|
return uint64(v), true
|
||||||
|
case int16:
|
||||||
|
return uint64(v), true
|
||||||
|
case int32:
|
||||||
|
return uint64(v), true
|
||||||
|
case int64:
|
||||||
|
return uint64(v), true
|
||||||
|
case uint:
|
||||||
|
return uint64(v), true
|
||||||
|
case uint8:
|
||||||
|
return uint64(v), true
|
||||||
|
case uint16:
|
||||||
|
return uint64(v), true
|
||||||
|
case uint32:
|
||||||
|
return uint64(v), true
|
||||||
|
case uint64:
|
||||||
|
return v, true
|
||||||
|
case float32:
|
||||||
|
return uint64(v), true
|
||||||
|
case float64:
|
||||||
|
return uint64(v), true
|
||||||
|
case string:
|
||||||
|
if num, err := strconv.ParseUint(v, 10, 64); err == nil {
|
||||||
|
return num, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToFloat64 attempts to convert various types to float64
|
||||||
|
func convertToFloat64(value interface{}) (float64, bool) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
return float64(v), true
|
||||||
|
case int8:
|
||||||
|
return float64(v), true
|
||||||
|
case int16:
|
||||||
|
return float64(v), true
|
||||||
|
case int32:
|
||||||
|
return float64(v), true
|
||||||
|
case int64:
|
||||||
|
return float64(v), true
|
||||||
|
case uint:
|
||||||
|
return float64(v), true
|
||||||
|
case uint8:
|
||||||
|
return float64(v), true
|
||||||
|
case uint16:
|
||||||
|
return float64(v), true
|
||||||
|
case uint32:
|
||||||
|
return float64(v), true
|
||||||
|
case uint64:
|
||||||
|
return float64(v), true
|
||||||
|
case float32:
|
||||||
|
return float64(v), true
|
||||||
|
case float64:
|
||||||
|
return v, true
|
||||||
|
case string:
|
||||||
|
if num, err := strconv.ParseFloat(v, 64); err == nil {
|
||||||
|
return num, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||||
|
|||||||
266
pkg/reflection/model_utils_sqltypes_test.go
Normal file
266
pkg/reflection/model_utils_sqltypes_test.go
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
package reflection_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMapToStruct_SqlJSONB_PreservesDriverValuer(t *testing.T) {
|
||||||
|
// Test that SqlJSONB type preserves driver.Valuer interface
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Meta spectypes.SqlJSONB `bun:"meta" json:"meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(123),
|
||||||
|
"meta": map[string]interface{}{
|
||||||
|
"key": "value",
|
||||||
|
"num": 42,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the field was set
|
||||||
|
if result.ID != 123 {
|
||||||
|
t.Errorf("ID = %v, want 123", result.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJSONB was populated
|
||||||
|
if len(result.Meta) == 0 {
|
||||||
|
t.Error("Meta is empty, want non-empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most importantly: verify driver.Valuer interface works
|
||||||
|
value, err := result.Meta.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Meta.Value() error = %v, want nil", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value should return a string representation of the JSON
|
||||||
|
if value == nil {
|
||||||
|
t.Error("Meta.Value() returned nil, want non-nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check it's a valid JSON string
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
if len(str) == 0 {
|
||||||
|
t.Error("Meta.Value() returned empty string, want valid JSON")
|
||||||
|
}
|
||||||
|
t.Logf("SqlJSONB.Value() returned: %s", str)
|
||||||
|
} else {
|
||||||
|
t.Errorf("Meta.Value() returned type %T, want string", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_SqlJSONB_FromBytes(t *testing.T) {
|
||||||
|
// Test that SqlJSONB can be set from []byte directly
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Meta spectypes.SqlJSONB `bun:"meta" json:"meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes := []byte(`{"direct":"bytes"}`)
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(456),
|
||||||
|
"meta": jsonBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ID != 456 {
|
||||||
|
t.Errorf("ID = %v, want 456", result.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Meta) != string(jsonBytes) {
|
||||||
|
t.Errorf("Meta = %s, want %s", string(result.Meta), string(jsonBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer works
|
||||||
|
value, err := result.Meta.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Meta.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if value == nil {
|
||||||
|
t.Error("Meta.Value() returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_AllSqlTypes(t *testing.T) {
|
||||||
|
// Test model with all SQL custom types
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"`
|
||||||
|
BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"`
|
||||||
|
LoginTime spectypes.SqlTime `bun:"login_time" json:"login_time"`
|
||||||
|
Meta spectypes.SqlJSONB `bun:"meta" json:"meta"`
|
||||||
|
Tags spectypes.SqlJSONB `bun:"tags" json:"tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
birthDate := time.Date(1990, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||||
|
loginTime := time.Date(0, 1, 1, 14, 30, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(100),
|
||||||
|
"name": "Test User",
|
||||||
|
"created_at": now,
|
||||||
|
"birth_date": birthDate,
|
||||||
|
"login_time": loginTime,
|
||||||
|
"meta": map[string]interface{}{
|
||||||
|
"role": "admin",
|
||||||
|
"active": true,
|
||||||
|
},
|
||||||
|
"tags": []interface{}{"golang", "testing", "sql"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify basic fields
|
||||||
|
if result.ID != 100 {
|
||||||
|
t.Errorf("ID = %v, want 100", result.ID)
|
||||||
|
}
|
||||||
|
if result.Name != "Test User" {
|
||||||
|
t.Errorf("Name = %v, want 'Test User'", result.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlTimeStamp
|
||||||
|
if !result.CreatedAt.Valid {
|
||||||
|
t.Error("CreatedAt.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.CreatedAt.Val.Equal(now) {
|
||||||
|
t.Errorf("CreatedAt.Val = %v, want %v", result.CreatedAt.Val, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for SqlTimeStamp
|
||||||
|
tsValue, err := result.CreatedAt.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("CreatedAt.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if tsValue == nil {
|
||||||
|
t.Error("CreatedAt.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlDate
|
||||||
|
if !result.BirthDate.Valid {
|
||||||
|
t.Error("BirthDate.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.BirthDate.Val.Equal(birthDate) {
|
||||||
|
t.Errorf("BirthDate.Val = %v, want %v", result.BirthDate.Val, birthDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for SqlDate
|
||||||
|
dateValue, err := result.BirthDate.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("BirthDate.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if dateValue == nil {
|
||||||
|
t.Error("BirthDate.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlTime
|
||||||
|
if !result.LoginTime.Valid {
|
||||||
|
t.Error("LoginTime.Valid = false, want true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for SqlTime
|
||||||
|
timeValue, err := result.LoginTime.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("LoginTime.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if timeValue == nil {
|
||||||
|
t.Error("LoginTime.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJSONB for Meta
|
||||||
|
if len(result.Meta) == 0 {
|
||||||
|
t.Error("Meta is empty")
|
||||||
|
}
|
||||||
|
metaValue, err := result.Meta.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Meta.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if metaValue == nil {
|
||||||
|
t.Error("Meta.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJSONB for Tags
|
||||||
|
if len(result.Tags) == 0 {
|
||||||
|
t.Error("Tags is empty")
|
||||||
|
}
|
||||||
|
tagsValue, err := result.Tags.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Tags.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if tagsValue == nil {
|
||||||
|
t.Error("Tags.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("All SQL types successfully preserved driver.Valuer interface:")
|
||||||
|
t.Logf(" - SqlTimeStamp: %v", tsValue)
|
||||||
|
t.Logf(" - SqlDate: %v", dateValue)
|
||||||
|
t.Logf(" - SqlTime: %v", timeValue)
|
||||||
|
t.Logf(" - SqlJSONB (Meta): %v", metaValue)
|
||||||
|
t.Logf(" - SqlJSONB (Tags): %v", tagsValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_SqlNull_NilValues(t *testing.T) {
|
||||||
|
// Test that SqlNull types handle nil values correctly
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
UpdatedAt spectypes.SqlTimeStamp `bun:"updated_at" json:"updated_at"`
|
||||||
|
DeletedAt spectypes.SqlTimeStamp `bun:"deleted_at" json:"deleted_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(200),
|
||||||
|
"updated_at": now,
|
||||||
|
"deleted_at": nil, // Explicitly nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAt should be valid
|
||||||
|
if !result.UpdatedAt.Valid {
|
||||||
|
t.Error("UpdatedAt.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.UpdatedAt.Val.Equal(now) {
|
||||||
|
t.Errorf("UpdatedAt.Val = %v, want %v", result.UpdatedAt.Val, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAt should be invalid (null)
|
||||||
|
if result.DeletedAt.Valid {
|
||||||
|
t.Error("DeletedAt.Valid = true, want false (null)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for null SqlTimeStamp
|
||||||
|
deletedValue, err := result.DeletedAt.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("DeletedAt.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if deletedValue != nil {
|
||||||
|
t.Errorf("DeletedAt.Value() = %v, want nil", deletedValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1687,3 +1687,201 @@ func TestGetRelationModel_WithTags(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct(t *testing.T) {
|
||||||
|
// Test model with various field types
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Age int `bun:"age" json:"age"`
|
||||||
|
Active bool `bun:"active" json:"active"`
|
||||||
|
Score float64 `bun:"score" json:"score"`
|
||||||
|
Data []byte `bun:"data" json:"data"`
|
||||||
|
MetaJSON []byte `bun:"meta_json" json:"meta_json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataMap map[string]interface{}
|
||||||
|
expected TestModel
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic types conversion",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(123),
|
||||||
|
"name": "Test User",
|
||||||
|
"age": 30,
|
||||||
|
"active": true,
|
||||||
|
"score": 95.5,
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 123,
|
||||||
|
Name: "Test User",
|
||||||
|
Age: 30,
|
||||||
|
Active: true,
|
||||||
|
Score: 95.5,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice (SqlJSONB-like) from []byte",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(456),
|
||||||
|
"name": "JSON Test",
|
||||||
|
"data": []byte(`{"key":"value"}`),
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 456,
|
||||||
|
Name: "JSON Test",
|
||||||
|
Data: []byte(`{"key":"value"}`),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(789),
|
||||||
|
"data": "string data",
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 789,
|
||||||
|
Data: []byte("string data"),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice from map (JSON marshal)",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(999),
|
||||||
|
"meta_json": map[string]interface{}{
|
||||||
|
"field1": "value1",
|
||||||
|
"field2": 42,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 999,
|
||||||
|
MetaJSON: []byte(`{"field1":"value1","field2":42}`),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice from slice (JSON marshal)",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(111),
|
||||||
|
"meta_json": []interface{}{"item1", "item2", 3},
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 111,
|
||||||
|
MetaJSON: []byte(`["item1","item2",3]`),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Field matching by bun tag",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(222),
|
||||||
|
"name": "Tagged Field",
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 222,
|
||||||
|
Name: "Tagged Field",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil values",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(333),
|
||||||
|
"data": nil,
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 333,
|
||||||
|
Data: nil,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var result TestModel
|
||||||
|
err := MapToStruct(tt.dataMap, &result)
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare fields individually for better error messages
|
||||||
|
if result.ID != tt.expected.ID {
|
||||||
|
t.Errorf("ID = %v, want %v", result.ID, tt.expected.ID)
|
||||||
|
}
|
||||||
|
if result.Name != tt.expected.Name {
|
||||||
|
t.Errorf("Name = %v, want %v", result.Name, tt.expected.Name)
|
||||||
|
}
|
||||||
|
if result.Age != tt.expected.Age {
|
||||||
|
t.Errorf("Age = %v, want %v", result.Age, tt.expected.Age)
|
||||||
|
}
|
||||||
|
if result.Active != tt.expected.Active {
|
||||||
|
t.Errorf("Active = %v, want %v", result.Active, tt.expected.Active)
|
||||||
|
}
|
||||||
|
if result.Score != tt.expected.Score {
|
||||||
|
t.Errorf("Score = %v, want %v", result.Score, tt.expected.Score)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For byte slices, compare as strings for JSON data
|
||||||
|
if tt.expected.Data != nil {
|
||||||
|
if string(result.Data) != string(tt.expected.Data) {
|
||||||
|
t.Errorf("Data = %s, want %s", string(result.Data), string(tt.expected.Data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tt.expected.MetaJSON != nil {
|
||||||
|
if string(result.MetaJSON) != string(tt.expected.MetaJSON) {
|
||||||
|
t.Errorf("MetaJSON = %s, want %s", string(result.MetaJSON), string(tt.expected.MetaJSON))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_Errors(t *testing.T) {
|
||||||
|
type TestModel struct {
|
||||||
|
ID int `bun:"id" json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataMap map[string]interface{}
|
||||||
|
target interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Nil dataMap",
|
||||||
|
dataMap: nil,
|
||||||
|
target: &TestModel{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil target",
|
||||||
|
dataMap: map[string]interface{}{"id": 1},
|
||||||
|
target: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-pointer target",
|
||||||
|
dataMap: map[string]interface{}{"id": 1},
|
||||||
|
target: TestModel{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := MapToStruct(tt.dataMap, tt.target)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
118
pkg/resolvespec/cache_helpers.go
Normal file
118
pkg/resolvespec/cache_helpers.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package resolvespec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// queryCacheKey represents the components used to build a cache key for query total count
|
||||||
|
type queryCacheKey struct {
|
||||||
|
TableName string `json:"table_name"`
|
||||||
|
Filters []common.FilterOption `json:"filters"`
|
||||||
|
Sort []common.SortOption `json:"sort"`
|
||||||
|
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||||
|
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||||
|
CursorForward string `json:"cursor_forward,omitempty"`
|
||||||
|
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// cachedTotal represents a cached total count
|
||||||
|
type cachedTotal struct {
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildQueryCacheKey builds a cache key from query parameters for total count caching
|
||||||
|
func buildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
||||||
|
key := queryCacheKey{
|
||||||
|
TableName: tableName,
|
||||||
|
Filters: filters,
|
||||||
|
Sort: sort,
|
||||||
|
CustomSQLWhere: customWhere,
|
||||||
|
CustomSQLOr: customOr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize to JSON for consistent hashing
|
||||||
|
jsonData, err := json.Marshal(key)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to simple string concatenation if JSON fails
|
||||||
|
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashString(string(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildExtendedQueryCacheKey builds a cache key for extended query options with cursor pagination
|
||||||
|
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||||
|
customWhere, customOr string, cursorFwd, cursorBwd string) string {
|
||||||
|
|
||||||
|
key := queryCacheKey{
|
||||||
|
TableName: tableName,
|
||||||
|
Filters: filters,
|
||||||
|
Sort: sort,
|
||||||
|
CustomSQLWhere: customWhere,
|
||||||
|
CustomSQLOr: customOr,
|
||||||
|
CursorForward: cursorFwd,
|
||||||
|
CursorBackward: cursorBwd,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize to JSON for consistent hashing
|
||||||
|
jsonData, err := json.Marshal(key)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to simple string concatenation if JSON fails
|
||||||
|
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%s_%s",
|
||||||
|
tableName, filters, sort, customWhere, customOr, cursorFwd, cursorBwd))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashString(string(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashString computes SHA256 hash of a string
|
||||||
|
func hashString(s string) string {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(s))
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||||
|
func getQueryTotalCacheKey(hash string) string {
|
||||||
|
return fmt.Sprintf("query_total:%s", hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildCacheTags creates cache tags from schema and table name
|
||||||
|
func buildCacheTags(schema, tableName string) []string {
|
||||||
|
return []string{
|
||||||
|
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
|
||||||
|
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setQueryTotalCache stores a query total in the cache with schema and table tags
|
||||||
|
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
|
||||||
|
c := cache.GetDefaultCache()
|
||||||
|
cacheData := cachedTotal{Total: total}
|
||||||
|
tags := buildCacheTags(schema, tableName)
|
||||||
|
|
||||||
|
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidateCacheForTags removes all cached items matching the specified tags
|
||||||
|
func invalidateCacheForTags(ctx context.Context, tags []string) error {
|
||||||
|
c := cache.GetDefaultCache()
|
||||||
|
|
||||||
|
// Invalidate for each tag
|
||||||
|
for _, tag := range tags {
|
||||||
|
if err := c.DeleteByTag(ctx, tag); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package resolvespec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -330,19 +331,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Use extended cache key if cursors are present
|
// Use extended cache key if cursors are present
|
||||||
var cacheKeyHash string
|
var cacheKeyHash string
|
||||||
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
||||||
cacheKeyHash = cache.BuildExtendedQueryCacheKey(
|
cacheKeyHash = buildExtendedQueryCacheKey(
|
||||||
tableName,
|
tableName,
|
||||||
options.Filters,
|
options.Filters,
|
||||||
options.Sort,
|
options.Sort,
|
||||||
"", // No custom SQL WHERE in resolvespec
|
"", // No custom SQL WHERE in resolvespec
|
||||||
"", // No custom SQL OR in resolvespec
|
"", // No custom SQL OR in resolvespec
|
||||||
nil, // No expand options in resolvespec
|
|
||||||
false, // distinct not used here
|
|
||||||
options.CursorForward,
|
options.CursorForward,
|
||||||
options.CursorBackward,
|
options.CursorBackward,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
cacheKeyHash = cache.BuildQueryCacheKey(
|
cacheKeyHash = buildQueryCacheKey(
|
||||||
tableName,
|
tableName,
|
||||||
options.Filters,
|
options.Filters,
|
||||||
options.Sort,
|
options.Sort,
|
||||||
@@ -350,10 +349,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
"", // No custom SQL OR in resolvespec
|
"", // No custom SQL OR in resolvespec
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
|
cacheKey := getQueryTotalCacheKey(cacheKeyHash)
|
||||||
|
|
||||||
// Try to retrieve from cache
|
// Try to retrieve from cache
|
||||||
var cachedTotal cache.CachedTotal
|
var cachedTotal cachedTotal
|
||||||
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total = cachedTotal.Total
|
total = cachedTotal.Total
|
||||||
@@ -370,10 +369,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
total = count
|
total = count
|
||||||
logger.Debug("Total records (from query): %d", total)
|
logger.Debug("Total records (from query): %d", total)
|
||||||
|
|
||||||
// Store in cache
|
// Store in cache with schema and table tags
|
||||||
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||||
cacheData := cache.CachedTotal{Total: total}
|
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
|
||||||
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
|
||||||
logger.Warn("Failed to cache query total: %v", err)
|
logger.Warn("Failed to cache query total: %v", err)
|
||||||
// Don't fail the request if caching fails
|
// Don't fail the request if caching fails
|
||||||
} else {
|
} else {
|
||||||
@@ -463,6 +461,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
|
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, result.Data, nil)
|
h.sendResponse(w, result.Data, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -479,6 +482,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, v, nil)
|
h.sendResponse(w, v, nil)
|
||||||
|
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
@@ -517,6 +525,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records with nested data", len(results))
|
logger.Info("Successfully created %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -540,6 +553,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records", len(v))
|
logger.Info("Successfully created %d records", len(v))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, v, nil)
|
h.sendResponse(w, v, nil)
|
||||||
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
@@ -583,6 +601,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records with nested data", len(results))
|
logger.Info("Successfully created %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -610,6 +633,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records", len(v))
|
logger.Info("Successfully created %d records", len(v))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, list, nil)
|
h.sendResponse(w, list, nil)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -660,6 +688,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, result.Data, nil)
|
h.sendResponse(w, result.Data, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -696,6 +729,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated %d records", result.RowsAffected())
|
logger.Info("Successfully updated %d records", result.RowsAffected())
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, data, nil)
|
h.sendResponse(w, data, nil)
|
||||||
|
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
@@ -734,6 +772,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records with nested data", len(results))
|
logger.Info("Successfully updated %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -757,6 +800,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records", len(updates))
|
logger.Info("Successfully updated %d records", len(updates))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, updates, nil)
|
h.sendResponse(w, updates, nil)
|
||||||
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
@@ -799,6 +847,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records with nested data", len(results))
|
logger.Info("Successfully updated %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -826,6 +879,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records", len(list))
|
logger.Info("Successfully updated %d records", len(list))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, list, nil)
|
h.sendResponse(w, list, nil)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -872,6 +930,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", len(v))
|
logger.Info("Successfully deleted %d records", len(v))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": len(v)}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": len(v)}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -913,6 +976,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -939,6 +1007,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -957,7 +1030,29 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
// Get primary key name
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
// First, fetch the record that will be deleted
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
recordToDelete := reflect.New(modelType).Interface()
|
||||||
|
|
||||||
|
selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
logger.Warn("Record not found for delete: %s = %s", pkName, id)
|
||||||
|
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error("Error fetching record for delete: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Error fetching record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -966,14 +1061,21 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the record was actually deleted
|
||||||
if result.RowsAffected() == 0 {
|
if result.RowsAffected() == 0 {
|
||||||
logger.Warn("No record found to delete with ID: %s", id)
|
logger.Warn("No rows deleted for ID: %s", id)
|
||||||
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil)
|
h.sendError(w, http.StatusNotFound, "not_found", "Record not found or already deleted", nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully deleted record with ID: %s", id)
|
logger.Info("Successfully deleted record with ID: %s", id)
|
||||||
h.sendResponse(w, nil, nil)
|
// Return the deleted record data
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
|
h.sendResponse(w, recordToDelete, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
||||||
|
|||||||
@@ -56,6 +56,10 @@ type HookContext struct {
|
|||||||
Abort bool // If set to true, the operation will be aborted
|
Abort bool // If set to true, the operation will be aborted
|
||||||
AbortMessage string // Message to return if aborted
|
AbortMessage string // Message to return if aborted
|
||||||
AbortCode int // HTTP status code if aborted
|
AbortCode int // HTTP status code if aborted
|
||||||
|
|
||||||
|
// Tx provides access to the database/transaction for executing additional SQL
|
||||||
|
// This allows hooks to run custom queries in addition to the main Query chain
|
||||||
|
Tx common.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
// HookFunc is the signature for hook functions
|
// HookFunc is the signature for hook functions
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package cache
|
package restheadspec
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,56 +7,42 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QueryCacheKey represents the components used to build a cache key for query total count
|
// expandOptionKey represents expand options for cache key
|
||||||
type QueryCacheKey struct {
|
type expandOptionKey struct {
|
||||||
|
Relation string `json:"relation"`
|
||||||
|
Where string `json:"where,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryCacheKey represents the components used to build a cache key for query total count
|
||||||
|
type queryCacheKey struct {
|
||||||
TableName string `json:"table_name"`
|
TableName string `json:"table_name"`
|
||||||
Filters []common.FilterOption `json:"filters"`
|
Filters []common.FilterOption `json:"filters"`
|
||||||
Sort []common.SortOption `json:"sort"`
|
Sort []common.SortOption `json:"sort"`
|
||||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||||
Expand []ExpandOptionKey `json:"expand,omitempty"`
|
Expand []expandOptionKey `json:"expand,omitempty"`
|
||||||
Distinct bool `json:"distinct,omitempty"`
|
Distinct bool `json:"distinct,omitempty"`
|
||||||
CursorForward string `json:"cursor_forward,omitempty"`
|
CursorForward string `json:"cursor_forward,omitempty"`
|
||||||
CursorBackward string `json:"cursor_backward,omitempty"`
|
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpandOptionKey represents expand options for cache key
|
// cachedTotal represents a cached total count
|
||||||
type ExpandOptionKey struct {
|
type cachedTotal struct {
|
||||||
Relation string `json:"relation"`
|
Total int `json:"total"`
|
||||||
Where string `json:"where,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
|
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||||
// This is used to cache the total count of records matching a query
|
|
||||||
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
|
||||||
key := QueryCacheKey{
|
|
||||||
TableName: tableName,
|
|
||||||
Filters: filters,
|
|
||||||
Sort: sort,
|
|
||||||
CustomSQLWhere: customWhere,
|
|
||||||
CustomSQLOr: customOr,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize to JSON for consistent hashing
|
|
||||||
jsonData, err := json.Marshal(key)
|
|
||||||
if err != nil {
|
|
||||||
// Fallback to simple string concatenation if JSON fails
|
|
||||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
|
||||||
}
|
|
||||||
|
|
||||||
return hashString(string(jsonData))
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
|
||||||
// Includes expand, distinct, and cursor pagination options
|
// Includes expand, distinct, and cursor pagination options
|
||||||
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||||
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||||
|
|
||||||
key := QueryCacheKey{
|
key := queryCacheKey{
|
||||||
TableName: tableName,
|
TableName: tableName,
|
||||||
Filters: filters,
|
Filters: filters,
|
||||||
Sort: sort,
|
Sort: sort,
|
||||||
@@ -69,11 +55,11 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
|||||||
|
|
||||||
// Convert expand options to cache key format
|
// Convert expand options to cache key format
|
||||||
if len(expandOpts) > 0 {
|
if len(expandOpts) > 0 {
|
||||||
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
|
key.Expand = make([]expandOptionKey, 0, len(expandOpts))
|
||||||
for _, exp := range expandOpts {
|
for _, exp := range expandOpts {
|
||||||
// Type assert to get the expand option fields we care about for caching
|
// Type assert to get the expand option fields we care about for caching
|
||||||
if expMap, ok := exp.(map[string]interface{}); ok {
|
if expMap, ok := exp.(map[string]interface{}); ok {
|
||||||
expKey := ExpandOptionKey{}
|
expKey := expandOptionKey{}
|
||||||
if rel, ok := expMap["relation"].(string); ok {
|
if rel, ok := expMap["relation"].(string); ok {
|
||||||
expKey.Relation = rel
|
expKey.Relation = rel
|
||||||
}
|
}
|
||||||
@@ -83,7 +69,6 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
|||||||
key.Expand = append(key.Expand, expKey)
|
key.Expand = append(key.Expand, expKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Sort expand options for consistent hashing (already sorted by relation name above)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize to JSON for consistent hashing
|
// Serialize to JSON for consistent hashing
|
||||||
@@ -104,24 +89,38 @@ func hashString(s string) string {
|
|||||||
return hex.EncodeToString(h.Sum(nil))
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||||
func GetQueryTotalCacheKey(hash string) string {
|
func getQueryTotalCacheKey(hash string) string {
|
||||||
return fmt.Sprintf("query_total:%s", hash)
|
return fmt.Sprintf("query_total:%s", hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CachedTotal represents a cached total count
|
// buildCacheTags creates cache tags from schema and table name
|
||||||
type CachedTotal struct {
|
func buildCacheTags(schema, tableName string) []string {
|
||||||
Total int `json:"total"`
|
return []string{
|
||||||
|
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
|
||||||
|
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateCacheForTable removes all cached totals for a specific table
|
// setQueryTotalCache stores a query total in the cache with schema and table tags
|
||||||
// This should be called when data in the table changes (insert/update/delete)
|
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
|
||||||
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
|
c := cache.GetDefaultCache()
|
||||||
cache := GetDefaultCache()
|
cacheData := cachedTotal{Total: total}
|
||||||
|
tags := buildCacheTags(schema, tableName)
|
||||||
|
|
||||||
// Build a pattern to match all query totals for this table
|
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
|
||||||
// Note: This requires pattern matching support in the provider
|
}
|
||||||
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
|
|
||||||
|
// invalidateCacheForTags removes all cached items matching the specified tags
|
||||||
return cache.DeleteByPattern(ctx, pattern)
|
func invalidateCacheForTags(ctx context.Context, tags []string) error {
|
||||||
|
c := cache.GetDefaultCache()
|
||||||
|
|
||||||
|
// Invalidate for each tag
|
||||||
|
for _, tag := range tags {
|
||||||
|
if err := c.DeleteByTag(ctx, tag); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
@@ -2,6 +2,7 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -300,6 +301,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
Options: options,
|
Options: options,
|
||||||
ID: id,
|
ID: id,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||||
@@ -480,8 +482,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (AND condition)
|
// Apply custom SQL WHERE clause (AND condition)
|
||||||
if options.CustomSQLWhere != "" {
|
if options.CustomSQLWhere != "" {
|
||||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
// First add table prefixes to unqualified columns (but skip columns inside function calls)
|
||||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||||
|
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
if sanitizedWhere != "" {
|
if sanitizedWhere != "" {
|
||||||
query = query.Where(sanitizedWhere)
|
query = query.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -490,8 +494,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (OR condition)
|
// Apply custom SQL WHERE clause (OR condition)
|
||||||
if options.CustomSQLOr != "" {
|
if options.CustomSQLOr != "" {
|
||||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||||
|
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
if sanitizedOr != "" {
|
if sanitizedOr != "" {
|
||||||
query = query.WhereOr(sanitizedOr)
|
query = query.WhereOr(sanitizedOr)
|
||||||
}
|
}
|
||||||
@@ -512,14 +517,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
direction = "DESC"
|
direction = "DESC"
|
||||||
}
|
}
|
||||||
logger.Debug("Applying sort: %s %s", sort.Column, direction)
|
logger.Debug("Applying sort: %s %s", sort.Column, direction)
|
||||||
|
|
||||||
|
// Check if it's an expression (enclosed in brackets) - use directly without quoting
|
||||||
|
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
|
// For expressions, pass as raw SQL to prevent auto-quoting
|
||||||
|
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||||
|
} else {
|
||||||
|
// Regular column - let Bun handle quoting
|
||||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get total count before pagination (unless skip count is requested)
|
// Get total count before pagination (unless skip count is requested)
|
||||||
var total int
|
var total int
|
||||||
if !options.SkipCount {
|
if !options.SkipCount {
|
||||||
// Try to get from cache first (unless SkipCache is true)
|
// Try to get from cache first (unless SkipCache is true)
|
||||||
var cachedTotal *cache.CachedTotal
|
var cachedTotalData *cachedTotal
|
||||||
var cacheKey string
|
var cacheKey string
|
||||||
|
|
||||||
if !options.SkipCache {
|
if !options.SkipCache {
|
||||||
@@ -533,7 +546,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
|
cacheKeyHash := buildExtendedQueryCacheKey(
|
||||||
tableName,
|
tableName,
|
||||||
options.Filters,
|
options.Filters,
|
||||||
options.Sort,
|
options.Sort,
|
||||||
@@ -544,22 +557,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
options.CursorForward,
|
options.CursorForward,
|
||||||
options.CursorBackward,
|
options.CursorBackward,
|
||||||
)
|
)
|
||||||
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
|
cacheKey = getQueryTotalCacheKey(cacheKeyHash)
|
||||||
|
|
||||||
// Try to retrieve from cache
|
// Try to retrieve from cache
|
||||||
cachedTotal = &cache.CachedTotal{}
|
cachedTotalData = &cachedTotal{}
|
||||||
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
|
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotalData)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total = cachedTotal.Total
|
total = cachedTotalData.Total
|
||||||
logger.Debug("Total records (from cache): %d", total)
|
logger.Debug("Total records (from cache): %d", total)
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Cache miss for query total")
|
logger.Debug("Cache miss for query total")
|
||||||
cachedTotal = nil
|
cachedTotalData = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If not in cache or cache skip, execute count query
|
// If not in cache or cache skip, execute count query
|
||||||
if cachedTotal == nil {
|
if cachedTotalData == nil {
|
||||||
count, err := query.Count(ctx)
|
count, err := query.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error counting records: %v", err)
|
logger.Error("Error counting records: %v", err)
|
||||||
@@ -569,11 +582,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
total = count
|
total = count
|
||||||
logger.Debug("Total records (from query): %d", total)
|
logger.Debug("Total records (from query): %d", total)
|
||||||
|
|
||||||
// Store in cache (if caching is enabled)
|
// Store in cache with schema and table tags (if caching is enabled)
|
||||||
if !options.SkipCache && cacheKey != "" {
|
if !options.SkipCache && cacheKey != "" {
|
||||||
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||||
cacheData := &cache.CachedTotal{Total: total}
|
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
|
||||||
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
|
||||||
logger.Warn("Failed to cache query total: %v", err)
|
logger.Warn("Failed to cache query total: %v", err)
|
||||||
// Don't fail the request if caching fails
|
// Don't fail the request if caching fails
|
||||||
} else {
|
} else {
|
||||||
@@ -652,6 +664,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if a specific ID was requested but no record was found
|
||||||
|
resultCount := reflection.Len(modelPtr)
|
||||||
|
if id != "" && resultCount == 0 {
|
||||||
|
logger.Warn("Record not found for ID: %s", id)
|
||||||
|
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
limit := 0
|
limit := 0
|
||||||
if options.Limit != nil {
|
if options.Limit != nil {
|
||||||
limit = *options.Limit
|
limit = *options.Limit
|
||||||
@@ -666,7 +686,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
metadata := &common.Metadata{
|
metadata := &common.Metadata{
|
||||||
Total: int64(total),
|
Total: int64(total),
|
||||||
Count: int64(reflection.Len(modelPtr)),
|
Count: int64(resultCount),
|
||||||
Filtered: int64(total),
|
Filtered: int64(total),
|
||||||
Limit: limit,
|
Limit: limit,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
@@ -745,9 +765,42 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
|
|
||||||
// Apply ComputedQL fields if any
|
// Apply ComputedQL fields if any
|
||||||
if len(preload.ComputedQL) > 0 {
|
if len(preload.ComputedQL) > 0 {
|
||||||
|
// Get the base table name from the related model
|
||||||
|
baseTableName := getTableNameFromModel(relatedModel)
|
||||||
|
|
||||||
|
// Convert the preload relation path to the appropriate alias format
|
||||||
|
// This is ORM-specific. Currently we only support Bun's format.
|
||||||
|
// TODO: Add support for other ORMs if needed
|
||||||
|
preloadAlias := ""
|
||||||
|
if h.db.GetUnderlyingDB() != nil {
|
||||||
|
// Check if we're using Bun by checking the type name
|
||||||
|
underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB())
|
||||||
|
if strings.Contains(underlyingType, "bun.DB") {
|
||||||
|
// Use Bun's alias format: lowercase with double underscores
|
||||||
|
preloadAlias = relationPathToBunAlias(preload.Relation)
|
||||||
|
}
|
||||||
|
// For GORM: GORM doesn't use the same alias format, and this fix
|
||||||
|
// may not be needed since GORM handles preloads differently
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Applying computed columns to preload %s (alias: %s, base table: %s)",
|
||||||
|
preload.Relation, preloadAlias, baseTableName)
|
||||||
|
|
||||||
for colName, colExpr := range preload.ComputedQL {
|
for colName, colExpr := range preload.ComputedQL {
|
||||||
|
// Replace table references in the expression with the preload alias
|
||||||
|
// This fixes the ambiguous column reference issue when there are multiple
|
||||||
|
// levels of recursive/nested preloads
|
||||||
|
adjustedExpr := colExpr
|
||||||
|
if baseTableName != "" && preloadAlias != "" {
|
||||||
|
adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
||||||
|
if adjustedExpr != colExpr {
|
||||||
|
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
|
||||||
|
colName, colExpr, adjustedExpr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
||||||
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", adjustedExpr, colName))
|
||||||
// Remove the computed column from selected columns to avoid duplication
|
// Remove the computed column from selected columns to avoid duplication
|
||||||
for colIndex := range preload.Columns {
|
for colIndex := range preload.Columns {
|
||||||
if preload.Columns[colIndex] == colName {
|
if preload.Columns[colIndex] == colName {
|
||||||
@@ -793,9 +846,16 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
// Apply sorting
|
// Apply sorting
|
||||||
if len(preload.Sort) > 0 {
|
if len(preload.Sort) > 0 {
|
||||||
for _, sort := range preload.Sort {
|
for _, sort := range preload.Sort {
|
||||||
|
// Check if it's an expression (enclosed in brackets) - use directly without quoting
|
||||||
|
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
|
// For expressions, pass as raw SQL to prevent auto-quoting
|
||||||
|
sq = sq.OrderExpr(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||||
|
} else {
|
||||||
|
// Regular column - let ORM handle quoting
|
||||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply WHERE clause
|
// Apply WHERE clause
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
@@ -840,6 +900,73 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def"
|
||||||
|
// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores
|
||||||
|
func relationPathToBunAlias(relationPath string) string {
|
||||||
|
if relationPath == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Convert to lowercase and replace dots with double underscores
|
||||||
|
alias := strings.ToLower(relationPath)
|
||||||
|
alias = strings.ReplaceAll(alias, ".", "__")
|
||||||
|
return alias
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression
|
||||||
|
// with the appropriate alias for the current preload level
|
||||||
|
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
|
||||||
|
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
|
||||||
|
func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
|
||||||
|
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
|
||||||
|
return sqlExpr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace both quoted and unquoted table references
|
||||||
|
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
|
||||||
|
|
||||||
|
// Pattern 1: tablename.column (unquoted)
|
||||||
|
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
|
||||||
|
|
||||||
|
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
|
||||||
|
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableNameFromModel extracts the table name from a model
|
||||||
|
// It checks the bun tag first, then falls back to converting the struct name to snake_case
|
||||||
|
func getTableNameFromModel(model interface{}) string {
|
||||||
|
if model == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers
|
||||||
|
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for bun tag on embedded BaseModel
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if field.Anonymous {
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.HasPrefix(bunTag, "table:") {
|
||||||
|
return strings.TrimPrefix(bunTag, "table:")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: convert struct name to lowercase (simple heuristic)
|
||||||
|
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||||
|
return strings.ToLower(modelType.Name())
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -866,6 +993,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
Options: options,
|
Options: options,
|
||||||
Data: data,
|
Data: data,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||||
@@ -955,6 +1083,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
Data: modelValue,
|
Data: modelValue,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
Query: query,
|
Query: query,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
|
||||||
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
|
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
|
||||||
@@ -1022,6 +1151,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponseWithOptions(w, responseData, nil, &options)
|
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1047,6 +1181,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Schema: schema,
|
Schema: schema,
|
||||||
Entity: entity,
|
Entity: entity,
|
||||||
TableName: tableName,
|
TableName: tableName,
|
||||||
|
Tx: h.db,
|
||||||
Model: model,
|
Model: model,
|
||||||
Options: options,
|
Options: options,
|
||||||
ID: id,
|
ID: id,
|
||||||
@@ -1116,12 +1251,24 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Ensure ID is in the data map for the update
|
// Ensure ID is in the data map for the update
|
||||||
dataMap[pkName] = targetID
|
dataMap[pkName] = targetID
|
||||||
|
|
||||||
// Create update query
|
// Populate model instance from dataMap to preserve custom types (like SqlJSONB)
|
||||||
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
// Get the type of the model, handling both pointer and non-pointer types
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
modelInstance := reflect.New(modelType).Interface()
|
||||||
|
if err := reflection.MapToStruct(dataMap, modelInstance); err != nil {
|
||||||
|
return fmt.Errorf("failed to populate model from data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create update query using Model() to preserve custom types and driver.Valuer interfaces
|
||||||
|
query := tx.NewUpdate().Model(modelInstance)
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
hookCtx.Query = query
|
hookCtx.Query = query
|
||||||
|
hookCtx.Tx = tx
|
||||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1180,6 +1327,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated record with ID: %v", targetID)
|
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1217,6 +1369,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: itemID,
|
ID: itemID,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@@ -1247,6 +1400,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1285,6 +1443,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: itemIDStr,
|
ID: itemIDStr,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@@ -1314,6 +1473,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1337,6 +1501,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: itemIDStr,
|
ID: itemIDStr,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@@ -1367,6 +1532,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1380,7 +1550,34 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Single delete with URL ID
|
// Single delete with URL ID
|
||||||
// Execute BeforeDelete hooks
|
if id == "" {
|
||||||
|
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for delete", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get primary key name
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
// First, fetch the record that will be deleted
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
recordToDelete := reflect.New(modelType).Interface()
|
||||||
|
|
||||||
|
selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
logger.Warn("Record not found for delete: %s = %s", pkName, id)
|
||||||
|
h.sendError(w, http.StatusNotFound, "not_found", "Record not found", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error("Error fetching record for delete: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Error fetching record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeDelete hooks with the record data
|
||||||
hookCtx := &HookContext{
|
hookCtx := &HookContext{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Handler: h,
|
Handler: h,
|
||||||
@@ -1390,6 +1587,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: id,
|
ID: id,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
|
Data: recordToDelete,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@@ -1399,13 +1598,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewDelete().Table(tableName)
|
query := h.db.NewDelete().Table(tableName)
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
if id == "" {
|
|
||||||
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for delete", nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
hookCtx.Query = query
|
hookCtx.Query = query
|
||||||
@@ -1427,11 +1620,15 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute AfterDelete hooks
|
// Check if the record was actually deleted
|
||||||
responseData := map[string]interface{}{
|
if result.RowsAffected() == 0 {
|
||||||
"deleted": result.RowsAffected(),
|
logger.Warn("No rows deleted for ID: %s", id)
|
||||||
|
h.sendError(w, http.StatusNotFound, "not_found", "Record not found or already deleted", nil)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
hookCtx.Result = responseData
|
|
||||||
|
// Execute AfterDelete hooks with the deleted record data
|
||||||
|
hookCtx.Result = recordToDelete
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
|
||||||
@@ -1440,7 +1637,13 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendResponse(w, responseData, nil)
|
// Return the deleted record data
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
|
h.sendResponse(w, recordToDelete, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeRecordWithRequest merges a database record with the original request data
|
// mergeRecordWithRequest merges a database record with the original request data
|
||||||
@@ -1936,14 +2139,20 @@ func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metada
|
|||||||
|
|
||||||
// sendResponseWithOptions sends a response with optional formatting
|
// sendResponseWithOptions sends a response with optional formatting
|
||||||
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
|
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
|
||||||
|
w.SetHeader("Content-Type", "application/json")
|
||||||
|
if data == nil {
|
||||||
|
data = map[string]interface{}{}
|
||||||
|
w.WriteHeader(http.StatusPartialContent)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
// Normalize single-record arrays to objects if requested
|
// Normalize single-record arrays to objects if requested
|
||||||
if options != nil && options.SingleRecordAsObject {
|
if options != nil && options.SingleRecordAsObject {
|
||||||
data = h.normalizeResultArray(data)
|
data = h.normalizeResultArray(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return data as-is without wrapping in common.Response
|
// Return data as-is without wrapping in common.Response
|
||||||
w.SetHeader("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
if err := w.WriteJSON(data); err != nil {
|
if err := w.WriteJSON(data); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1953,7 +2162,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||||
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
return nil
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use reflection to check if data is a slice or array
|
// Use reflection to check if data is a slice or array
|
||||||
@@ -1962,18 +2171,41 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
|||||||
dataValue = dataValue.Elem()
|
dataValue = dataValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if it's a slice or array with exactly one element
|
// Check if it's a slice or array
|
||||||
if (dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array) && dataValue.Len() == 1 {
|
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
||||||
|
if dataValue.Len() == 1 {
|
||||||
// Return the single element
|
// Return the single element
|
||||||
return dataValue.Index(0).Interface()
|
return dataValue.Index(0).Interface()
|
||||||
|
} else if dataValue.Len() == 0 {
|
||||||
|
// Return empty object instead of empty array
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if dataValue.Kind() == reflect.String {
|
||||||
|
str := dataValue.String()
|
||||||
|
if str == "" || str == "null" {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendFormattedResponse sends response with formatting options
|
// sendFormattedResponse sends response with formatting options
|
||||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
||||||
// Normalize single-record arrays to objects if requested
|
// Normalize single-record arrays to objects if requested
|
||||||
|
httpStatus := http.StatusOK
|
||||||
|
if data == nil {
|
||||||
|
data = map[string]interface{}{}
|
||||||
|
httpStatus = http.StatusPartialContent
|
||||||
|
} else {
|
||||||
|
dataLen := reflection.Len(data)
|
||||||
|
if dataLen == 0 {
|
||||||
|
httpStatus = http.StatusPartialContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if options.SingleRecordAsObject {
|
if options.SingleRecordAsObject {
|
||||||
data = h.normalizeResultArray(data)
|
data = h.normalizeResultArray(data)
|
||||||
}
|
}
|
||||||
@@ -1992,7 +2224,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
switch options.ResponseFormat {
|
switch options.ResponseFormat {
|
||||||
case "simple":
|
case "simple":
|
||||||
// Simple format: just return the data array
|
// Simple format: just return the data array
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(httpStatus)
|
||||||
if err := w.WriteJSON(data); err != nil {
|
if err := w.WriteJSON(data); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -2004,7 +2236,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
if metadata != nil {
|
if metadata != nil {
|
||||||
response["count"] = metadata.Total
|
response["count"] = metadata.Total
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(httpStatus)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if err := w.WriteJSON(response); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -2015,7 +2247,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
Data: data,
|
Data: data,
|
||||||
Metadata: metadata,
|
Metadata: metadata,
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(httpStatus)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if err := w.WriteJSON(response); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -2073,8 +2305,15 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
|||||||
if strings.EqualFold(sort.Direction, "desc") {
|
if strings.EqualFold(sort.Direction, "desc") {
|
||||||
direction = "DESC"
|
direction = "DESC"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if it's an expression (enclosed in brackets) - use directly without table prefix
|
||||||
|
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
|
sortParts = append(sortParts, fmt.Sprintf("%s %s", sort.Column, direction))
|
||||||
|
} else {
|
||||||
|
// Regular column - add table prefix
|
||||||
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
sortSQL = strings.Join(sortParts, ", ")
|
sortSQL = strings.Join(sortParts, ", ")
|
||||||
} else {
|
} else {
|
||||||
// Default sort by primary key
|
// Default sort by primary key
|
||||||
@@ -2282,6 +2521,55 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
|||||||
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
||||||
// Filter columns using the related model's validator
|
// Filter columns using the related model's validator
|
||||||
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||||
|
|
||||||
|
// Filter sort columns in the expand Sort string
|
||||||
|
if expand.Sort != "" {
|
||||||
|
sortFields := strings.Split(expand.Sort, ",")
|
||||||
|
validSortFields := make([]string, 0, len(sortFields))
|
||||||
|
for _, sortField := range sortFields {
|
||||||
|
sortField = strings.TrimSpace(sortField)
|
||||||
|
if sortField == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract column name (remove direction prefixes/suffixes)
|
||||||
|
colName := sortField
|
||||||
|
direction := ""
|
||||||
|
|
||||||
|
if strings.HasPrefix(sortField, "-") {
|
||||||
|
direction = "-"
|
||||||
|
colName = strings.TrimPrefix(sortField, "-")
|
||||||
|
} else if strings.HasPrefix(sortField, "+") {
|
||||||
|
direction = "+"
|
||||||
|
colName = strings.TrimPrefix(sortField, "+")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(strings.ToLower(colName), " desc") {
|
||||||
|
direction = " desc"
|
||||||
|
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
|
||||||
|
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
|
||||||
|
direction = " asc"
|
||||||
|
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
|
||||||
|
}
|
||||||
|
|
||||||
|
colName = strings.TrimSpace(colName)
|
||||||
|
|
||||||
|
// Validate the column name
|
||||||
|
if expandValidator.IsValidColumn(colName) {
|
||||||
|
validSortFields = append(validSortFields, direction+colName)
|
||||||
|
} else if strings.HasPrefix(colName, "(") && strings.HasSuffix(colName, ")") {
|
||||||
|
// Allow sort by expression/subquery, but validate for security
|
||||||
|
if common.IsSafeSortExpression(colName) {
|
||||||
|
validSortFields = append(validSortFields, direction+colName)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Unsafe sort expression in expand '%s' removed: '%s'", expand.Relation, colName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in expand '%s' sort '%s' removed", expand.Relation, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filteredExpand.Sort = strings.Join(validSortFields, ",")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// If we can't find the relationship, log a warning and skip column filtering
|
// If we can't find the relationship, log a warning and skip column filtering
|
||||||
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
|
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
|
||||||
|
|||||||
@@ -529,19 +529,47 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseCommaSeparated parses comma-separated values and trims whitespace
|
// parseCommaSeparated parses comma-separated values and trims whitespace
|
||||||
|
// It respects bracket nesting and only splits on commas outside of parentheses
|
||||||
func (h *Handler) parseCommaSeparated(value string) []string {
|
func (h *Handler) parseCommaSeparated(value string) []string {
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(value, ",")
|
result := make([]string, 0)
|
||||||
result := make([]string, 0, len(parts))
|
var current strings.Builder
|
||||||
for _, part := range parts {
|
nestingLevel := 0
|
||||||
part = strings.TrimSpace(part)
|
|
||||||
|
for _, char := range value {
|
||||||
|
switch char {
|
||||||
|
case '(':
|
||||||
|
nestingLevel++
|
||||||
|
current.WriteRune(char)
|
||||||
|
case ')':
|
||||||
|
nestingLevel--
|
||||||
|
current.WriteRune(char)
|
||||||
|
case ',':
|
||||||
|
if nestingLevel == 0 {
|
||||||
|
// We're outside all brackets, so split here
|
||||||
|
part := strings.TrimSpace(current.String())
|
||||||
if part != "" {
|
if part != "" {
|
||||||
result = append(result, part)
|
result = append(result, part)
|
||||||
}
|
}
|
||||||
|
current.Reset()
|
||||||
|
} else {
|
||||||
|
// Inside brackets, keep the comma
|
||||||
|
current.WriteRune(char)
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
current.WriteRune(char)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the last part
|
||||||
|
part := strings.TrimSpace(current.String())
|
||||||
|
if part != "" {
|
||||||
|
result = append(result, part)
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,10 @@ type HookContext struct {
|
|||||||
|
|
||||||
// Response writer - allows hooks to modify response
|
// Response writer - allows hooks to modify response
|
||||||
Writer common.ResponseWriter
|
Writer common.ResponseWriter
|
||||||
|
|
||||||
|
// Tx provides access to the database/transaction for executing additional SQL
|
||||||
|
// This allows hooks to run custom queries in addition to the main Query chain
|
||||||
|
Tx common.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
// HookFunc is the signature for hook functions
|
// HookFunc is the signature for hook functions
|
||||||
|
|||||||
@@ -150,6 +150,50 @@ func ExampleRelatedDataHook(ctx *HookContext) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExampleTxHook demonstrates using the Tx field to execute additional SQL queries
|
||||||
|
// The Tx field provides access to the database/transaction for custom queries
|
||||||
|
func ExampleTxHook(ctx *HookContext) error {
|
||||||
|
// Example: Execute additional SQL operations alongside the main query
|
||||||
|
// This is useful for maintaining data consistency, updating related records, etc.
|
||||||
|
|
||||||
|
if ctx.Entity == "orders" && ctx.Data != nil {
|
||||||
|
// Example: Update inventory when an order is created
|
||||||
|
// Extract product ID and quantity from the order data
|
||||||
|
// dataMap, ok := ctx.Data.(map[string]interface{})
|
||||||
|
// if !ok {
|
||||||
|
// return fmt.Errorf("invalid data format")
|
||||||
|
// }
|
||||||
|
// productID := dataMap["product_id"]
|
||||||
|
// quantity := dataMap["quantity"]
|
||||||
|
|
||||||
|
// Use ctx.Tx to execute additional SQL queries
|
||||||
|
// The Tx field contains the same database/transaction as the main operation
|
||||||
|
// If inside a transaction, your queries will be part of the same transaction
|
||||||
|
// query := ctx.Tx.NewUpdate().
|
||||||
|
// Table("inventory").
|
||||||
|
// Set("quantity = quantity - ?", quantity).
|
||||||
|
// Where("product_id = ?", productID)
|
||||||
|
//
|
||||||
|
// if _, err := query.Exec(ctx.Context); err != nil {
|
||||||
|
// logger.Error("Failed to update inventory: %v", err)
|
||||||
|
// return fmt.Errorf("failed to update inventory: %w", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// You can also execute raw SQL using ctx.Tx
|
||||||
|
// var result []map[string]interface{}
|
||||||
|
// err := ctx.Tx.Query(ctx.Context, &result,
|
||||||
|
// "INSERT INTO order_history (order_id, status) VALUES (?, ?)",
|
||||||
|
// orderID, "pending")
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to insert order history: %w", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
logger.Debug("Executed additional SQL for order entity")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetupExampleHooks demonstrates how to register hooks on a handler
|
// SetupExampleHooks demonstrates how to register hooks on a handler
|
||||||
func SetupExampleHooks(handler *Handler) {
|
func SetupExampleHooks(handler *Handler) {
|
||||||
hooks := handler.Hooks()
|
hooks := handler.Hooks()
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
//go:build integration
|
||||||
// +build integration
|
// +build integration
|
||||||
|
|
||||||
package restheadspec
|
package restheadspec
|
||||||
@@ -401,7 +402,7 @@ func TestIntegration_GetMetadata(t *testing.T) {
|
|||||||
|
|
||||||
muxRouter.ServeHTTP(w, req)
|
muxRouter.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
|
||||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -492,7 +493,7 @@ func TestIntegration_QueryParamsOverHeaders(t *testing.T) {
|
|||||||
|
|
||||||
muxRouter.ServeHTTP(w, req)
|
muxRouter.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
|
||||||
t.Errorf("Expected status 200, got %d", w.Code)
|
t.Errorf("Expected status 200, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,233 +1,314 @@
|
|||||||
# Server Package
|
# Server Package
|
||||||
|
|
||||||
Graceful HTTP server with request draining and shutdown coordination.
|
Production-ready HTTP server manager with graceful shutdown, request draining, and comprehensive TLS/HTTPS support.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
✅ **Multiple Server Management** - Run multiple HTTP/HTTPS servers concurrently
|
||||||
|
✅ **Graceful Shutdown** - Handles SIGINT/SIGTERM with request draining
|
||||||
|
✅ **Automatic Request Rejection** - New requests get 503 during shutdown
|
||||||
|
✅ **Health & Readiness Endpoints** - Kubernetes-ready health checks
|
||||||
|
✅ **Shutdown Callbacks** - Register cleanup functions (DB, cache, metrics)
|
||||||
|
✅ **Comprehensive TLS Support**:
|
||||||
|
- Certificate files (production)
|
||||||
|
- Self-signed certificates (development/testing)
|
||||||
|
- Let's Encrypt / AutoTLS (automatic certificate management)
|
||||||
|
✅ **GZIP Compression** - Optional response compression
|
||||||
|
✅ **Panic Recovery** - Automatic panic recovery middleware
|
||||||
|
✅ **Configurable Timeouts** - Read, write, idle, drain, and shutdown timeouts
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
|
### Single Server
|
||||||
|
|
||||||
```go
|
```go
|
||||||
import "github.com/bitechdev/ResolveSpec/pkg/server"
|
import "github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
|
|
||||||
// Create server
|
// Create server manager
|
||||||
srv := server.NewGracefulServer(server.Config{
|
mgr := server.NewManager()
|
||||||
Addr: ":8080",
|
|
||||||
Handler: router,
|
// Add server
|
||||||
|
_, err := mgr.Add(server.Config{
|
||||||
|
Name: "api-server",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: myRouter,
|
||||||
|
GZIP: true,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Start server (blocks until shutdown signal)
|
// Start and wait for shutdown signal
|
||||||
if err := srv.ListenAndServe(); err != nil {
|
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Features
|
### Multiple Servers
|
||||||
|
|
||||||
✅ Graceful shutdown on SIGINT/SIGTERM
|
```go
|
||||||
✅ Request draining (waits for in-flight requests)
|
mgr := server.NewManager()
|
||||||
✅ Automatic request rejection during shutdown
|
|
||||||
✅ Health and readiness endpoints
|
// Public API
|
||||||
✅ Shutdown callbacks for cleanup
|
mgr.Add(server.Config{
|
||||||
✅ Configurable timeouts
|
Name: "public-api",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: publicRouter,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Admin API
|
||||||
|
mgr.Add(server.Config{
|
||||||
|
Name: "admin-api",
|
||||||
|
Port: 8081,
|
||||||
|
Handler: adminRouter,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start all and wait
|
||||||
|
mgr.ServeWithGracefulShutdown()
|
||||||
|
```
|
||||||
|
|
||||||
|
## HTTPS/TLS Configuration
|
||||||
|
|
||||||
|
### Option 1: Certificate Files (Production)
|
||||||
|
|
||||||
|
```go
|
||||||
|
mgr.Add(server.Config{
|
||||||
|
Name: "https-server",
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
Port: 443,
|
||||||
|
Handler: handler,
|
||||||
|
SSLCert: "/etc/ssl/certs/server.crt",
|
||||||
|
SSLKey: "/etc/ssl/private/server.key",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option 2: Self-Signed Certificate (Development)
|
||||||
|
|
||||||
|
```go
|
||||||
|
mgr.Add(server.Config{
|
||||||
|
Name: "dev-server",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8443,
|
||||||
|
Handler: handler,
|
||||||
|
SelfSignedSSL: true, // Auto-generates certificate
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option 3: Let's Encrypt / AutoTLS (Production)
|
||||||
|
|
||||||
|
```go
|
||||||
|
mgr.Add(server.Config{
|
||||||
|
Name: "prod-server",
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
Port: 443,
|
||||||
|
Handler: handler,
|
||||||
|
AutoTLS: true,
|
||||||
|
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||||
|
AutoTLSEmail: "admin@example.com",
|
||||||
|
AutoTLSCacheDir: "./certs-cache", // Certificate cache directory
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
```go
|
```go
|
||||||
config := server.Config{
|
server.Config{
|
||||||
// Server address
|
// Basic configuration
|
||||||
Addr: ":8080",
|
Name: "my-server", // Server name (required)
|
||||||
|
Host: "0.0.0.0", // Bind address
|
||||||
|
Port: 8080, // Port (required)
|
||||||
|
Handler: myRouter, // HTTP handler (required)
|
||||||
|
Description: "My API server", // Optional description
|
||||||
|
|
||||||
// HTTP handler
|
// Features
|
||||||
Handler: myRouter,
|
GZIP: true, // Enable GZIP compression
|
||||||
|
|
||||||
// Maximum time for graceful shutdown (default: 30s)
|
// TLS/HTTPS (choose one option)
|
||||||
ShutdownTimeout: 30 * time.Second,
|
SSLCert: "/path/to/cert.pem", // Certificate file
|
||||||
|
SSLKey: "/path/to/key.pem", // Key file
|
||||||
|
SelfSignedSSL: false, // Auto-generate self-signed cert
|
||||||
|
AutoTLS: false, // Let's Encrypt
|
||||||
|
AutoTLSDomains: []string{}, // Domains for AutoTLS
|
||||||
|
AutoTLSEmail: "", // Email for Let's Encrypt
|
||||||
|
AutoTLSCacheDir: "./certs-cache", // Cert cache directory
|
||||||
|
|
||||||
// Time to wait for in-flight requests (default: 25s)
|
// Timeouts
|
||||||
DrainTimeout: 25 * time.Second,
|
ShutdownTimeout: 30 * time.Second, // Max shutdown time
|
||||||
|
DrainTimeout: 25 * time.Second, // Request drain timeout
|
||||||
// Request read timeout (default: 10s)
|
ReadTimeout: 15 * time.Second, // Request read timeout
|
||||||
ReadTimeout: 10 * time.Second,
|
WriteTimeout: 15 * time.Second, // Response write timeout
|
||||||
|
IdleTimeout: 60 * time.Second, // Idle connection timeout
|
||||||
// Response write timeout (default: 10s)
|
|
||||||
WriteTimeout: 10 * time.Second,
|
|
||||||
|
|
||||||
// Idle connection timeout (default: 120s)
|
|
||||||
IdleTimeout: 120 * time.Second,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := server.NewGracefulServer(config)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Shutdown Behavior
|
## Graceful Shutdown
|
||||||
|
|
||||||
**Signal received (SIGINT/SIGTERM):**
|
### Automatic (Recommended)
|
||||||
|
|
||||||
1. **Mark as shutting down** - New requests get 503
|
```go
|
||||||
2. **Drain requests** - Wait up to `DrainTimeout` for in-flight requests
|
mgr := server.NewManager()
|
||||||
3. **Shutdown server** - Close listeners and connections
|
|
||||||
4. **Execute callbacks** - Run registered cleanup functions
|
|
||||||
|
|
||||||
|
// Add servers...
|
||||||
|
|
||||||
|
// This blocks until SIGINT/SIGTERM
|
||||||
|
mgr.ServeWithGracefulShutdown()
|
||||||
```
|
```
|
||||||
Time Event
|
|
||||||
─────────────────────────────────────────
|
|
||||||
0s Signal received: SIGTERM
|
|
||||||
├─ Mark as shutting down
|
|
||||||
├─ Reject new requests (503)
|
|
||||||
└─ Start draining...
|
|
||||||
|
|
||||||
1s In-flight: 50 requests
|
### Manual Control
|
||||||
2s In-flight: 32 requests
|
|
||||||
3s In-flight: 12 requests
|
|
||||||
4s In-flight: 3 requests
|
|
||||||
5s In-flight: 0 requests ✓
|
|
||||||
└─ All requests drained
|
|
||||||
|
|
||||||
5s Execute shutdown callbacks
|
```go
|
||||||
6s Shutdown complete
|
mgr := server.NewManager()
|
||||||
|
|
||||||
|
// Add and start servers
|
||||||
|
mgr.StartAll()
|
||||||
|
|
||||||
|
// Later... stop gracefully
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := mgr.StopAllWithContext(ctx); err != nil {
|
||||||
|
log.Printf("Shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Shutdown Callbacks
|
||||||
|
|
||||||
|
Register cleanup functions to run during shutdown:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Close database
|
||||||
|
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
log.Println("Closing database...")
|
||||||
|
return db.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Flush metrics
|
||||||
|
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
log.Println("Flushing metrics...")
|
||||||
|
return metrics.Flush(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Close cache
|
||||||
|
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
log.Println("Closing cache...")
|
||||||
|
return cache.Close()
|
||||||
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
## Health Checks
|
## Health Checks
|
||||||
|
|
||||||
### Health Endpoint
|
### Adding Health Endpoints
|
||||||
|
|
||||||
Returns 200 when healthy, 503 when shutting down:
|
|
||||||
|
|
||||||
```go
|
```go
|
||||||
router.HandleFunc("/health", srv.HealthCheckHandler())
|
instance, _ := mgr.Add(server.Config{
|
||||||
|
Name: "api-server",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: router,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add health endpoints to your router
|
||||||
|
router.HandleFunc("/health", instance.HealthCheckHandler())
|
||||||
|
router.HandleFunc("/ready", instance.ReadinessHandler())
|
||||||
```
|
```
|
||||||
|
|
||||||
**Response (healthy):**
|
### Health Endpoint
|
||||||
|
|
||||||
|
Returns server health status:
|
||||||
|
|
||||||
|
**Healthy (200 OK):**
|
||||||
```json
|
```json
|
||||||
{"status":"healthy"}
|
{"status":"healthy"}
|
||||||
```
|
```
|
||||||
|
|
||||||
**Response (shutting down):**
|
**Shutting Down (503 Service Unavailable):**
|
||||||
```json
|
```json
|
||||||
{"status":"shutting_down"}
|
{"status":"shutting_down"}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Readiness Endpoint
|
### Readiness Endpoint
|
||||||
|
|
||||||
Includes in-flight request count:
|
Returns readiness with in-flight request count:
|
||||||
|
|
||||||
```go
|
**Ready (200 OK):**
|
||||||
router.HandleFunc("/ready", srv.ReadinessHandler())
|
|
||||||
```
|
|
||||||
|
|
||||||
**Response:**
|
|
||||||
```json
|
```json
|
||||||
{"ready":true,"in_flight_requests":12}
|
{"ready":true,"in_flight_requests":12}
|
||||||
```
|
```
|
||||||
|
|
||||||
**During shutdown:**
|
**Not Ready (503 Service Unavailable):**
|
||||||
```json
|
```json
|
||||||
{"ready":false,"reason":"shutting_down"}
|
{"ready":false,"reason":"shutting_down"}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Shutdown Callbacks
|
## Shutdown Behavior
|
||||||
|
|
||||||
Register cleanup functions to run during shutdown:
|
When a shutdown signal (SIGINT/SIGTERM) is received:
|
||||||
|
|
||||||
```go
|
1. **Mark as shutting down** → New requests get 503
|
||||||
// Close database
|
2. **Execute callbacks** → Run cleanup functions
|
||||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
3. **Drain requests** → Wait up to `DrainTimeout` for in-flight requests
|
||||||
logger.Info("Closing database connection...")
|
4. **Shutdown servers** → Close listeners and connections
|
||||||
return db.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Flush metrics
|
```
|
||||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
Time Event
|
||||||
logger.Info("Flushing metrics...")
|
─────────────────────────────────────────
|
||||||
return metricsProvider.Flush(ctx)
|
0s Signal received: SIGTERM
|
||||||
})
|
├─ Mark servers as shutting down
|
||||||
|
├─ Reject new requests (503)
|
||||||
|
└─ Execute shutdown callbacks
|
||||||
|
|
||||||
// Close cache
|
1s Callbacks complete
|
||||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
└─ Start draining requests...
|
||||||
logger.Info("Closing cache...")
|
|
||||||
return cache.Close()
|
2s In-flight: 50 requests
|
||||||
})
|
3s In-flight: 32 requests
|
||||||
|
4s In-flight: 12 requests
|
||||||
|
5s In-flight: 3 requests
|
||||||
|
6s In-flight: 0 requests ✓
|
||||||
|
└─ All requests drained
|
||||||
|
|
||||||
|
6s Shutdown servers
|
||||||
|
7s All servers stopped ✓
|
||||||
```
|
```
|
||||||
|
|
||||||
## Complete Example
|
## Server Management
|
||||||
|
|
||||||
|
### Get Server Instance
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package main
|
instance, err := mgr.Get("api-server")
|
||||||
|
if err != nil {
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// Initialize metrics
|
|
||||||
metricsProvider := metrics.NewPrometheusProvider()
|
|
||||||
metrics.SetProvider(metricsProvider)
|
|
||||||
|
|
||||||
// Create router
|
|
||||||
router := mux.NewRouter()
|
|
||||||
|
|
||||||
// Apply middleware
|
|
||||||
rateLimiter := middleware.NewRateLimiter(100, 20)
|
|
||||||
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size10MB)
|
|
||||||
sanitizer := middleware.DefaultSanitizer()
|
|
||||||
|
|
||||||
router.Use(rateLimiter.Middleware)
|
|
||||||
router.Use(sizeLimiter.Middleware)
|
|
||||||
router.Use(sanitizer.Middleware)
|
|
||||||
router.Use(metricsProvider.Middleware)
|
|
||||||
|
|
||||||
// API routes
|
|
||||||
router.HandleFunc("/api/data", dataHandler)
|
|
||||||
|
|
||||||
// Create graceful server
|
|
||||||
srv := server.NewGracefulServer(server.Config{
|
|
||||||
Addr: ":8080",
|
|
||||||
Handler: router,
|
|
||||||
ShutdownTimeout: 30 * time.Second,
|
|
||||||
DrainTimeout: 25 * time.Second,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Health checks
|
|
||||||
router.HandleFunc("/health", srv.HealthCheckHandler())
|
|
||||||
router.HandleFunc("/ready", srv.ReadinessHandler())
|
|
||||||
|
|
||||||
// Metrics endpoint
|
|
||||||
router.Handle("/metrics", metricsProvider.Handler())
|
|
||||||
|
|
||||||
// Register shutdown callbacks
|
|
||||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
|
||||||
log.Println("Cleanup: Flushing metrics...")
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
|
||||||
log.Println("Cleanup: Closing database...")
|
|
||||||
// return db.Close()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// Start server (blocks until shutdown)
|
|
||||||
log.Printf("Starting server on :8080")
|
|
||||||
if err := srv.ListenAndServe(); err != nil {
|
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for shutdown to complete
|
|
||||||
srv.Wait()
|
|
||||||
log.Println("Server stopped")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
// Check status
|
||||||
// Your handler logic
|
fmt.Printf("Address: %s\n", instance.Addr())
|
||||||
time.Sleep(100 * time.Millisecond) // Simulate work
|
fmt.Printf("Name: %s\n", instance.Name())
|
||||||
w.WriteHeader(http.StatusOK)
|
fmt.Printf("In-flight: %d\n", instance.InFlightRequests())
|
||||||
w.Write([]byte(`{"message":"success"}`))
|
fmt.Printf("Shutting down: %v\n", instance.IsShuttingDown())
|
||||||
|
```
|
||||||
|
|
||||||
|
### List All Servers
|
||||||
|
|
||||||
|
```go
|
||||||
|
instances := mgr.List()
|
||||||
|
for _, instance := range instances {
|
||||||
|
fmt.Printf("Server: %s at %s\n", instance.Name(), instance.Addr())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Remove Server
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Stop and remove a server
|
||||||
|
if err := mgr.Remove("api-server"); err != nil {
|
||||||
|
log.Printf("Error removing server: %v", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Restart All Servers
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Gracefully restart all servers
|
||||||
|
if err := mgr.RestartAll(); err != nil {
|
||||||
|
log.Printf("Error restarting: %v", err)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -250,23 +331,21 @@ spec:
|
|||||||
ports:
|
ports:
|
||||||
- containerPort: 8080
|
- containerPort: 8080
|
||||||
|
|
||||||
# Liveness probe - is app running?
|
# Liveness probe
|
||||||
livenessProbe:
|
livenessProbe:
|
||||||
httpGet:
|
httpGet:
|
||||||
path: /health
|
path: /health
|
||||||
port: 8080
|
port: 8080
|
||||||
initialDelaySeconds: 10
|
initialDelaySeconds: 10
|
||||||
periodSeconds: 10
|
periodSeconds: 10
|
||||||
timeoutSeconds: 5
|
|
||||||
|
|
||||||
# Readiness probe - can app handle traffic?
|
# Readiness probe
|
||||||
readinessProbe:
|
readinessProbe:
|
||||||
httpGet:
|
httpGet:
|
||||||
path: /ready
|
path: /ready
|
||||||
port: 8080
|
port: 8080
|
||||||
initialDelaySeconds: 5
|
initialDelaySeconds: 5
|
||||||
periodSeconds: 5
|
periodSeconds: 5
|
||||||
timeoutSeconds: 3
|
|
||||||
|
|
||||||
# Graceful shutdown
|
# Graceful shutdown
|
||||||
lifecycle:
|
lifecycle:
|
||||||
@@ -274,26 +353,12 @@ spec:
|
|||||||
exec:
|
exec:
|
||||||
command: ["/bin/sh", "-c", "sleep 5"]
|
command: ["/bin/sh", "-c", "sleep 5"]
|
||||||
|
|
||||||
# Environment
|
|
||||||
env:
|
env:
|
||||||
- name: SHUTDOWN_TIMEOUT
|
- name: SHUTDOWN_TIMEOUT
|
||||||
value: "30"
|
value: "30"
|
||||||
```
|
|
||||||
|
|
||||||
### Service
|
# Allow time for graceful shutdown
|
||||||
|
terminationGracePeriodSeconds: 35
|
||||||
```yaml
|
|
||||||
apiVersion: v1
|
|
||||||
kind: Service
|
|
||||||
metadata:
|
|
||||||
name: myapp
|
|
||||||
spec:
|
|
||||||
selector:
|
|
||||||
app: myapp
|
|
||||||
ports:
|
|
||||||
- port: 80
|
|
||||||
targetPort: 8080
|
|
||||||
type: LoadBalancer
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Docker Compose
|
## Docker Compose
|
||||||
@@ -312,8 +377,70 @@ services:
|
|||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 10s
|
stop_grace_period: 35s
|
||||||
stop_grace_period: 35s # Slightly longer than shutdown timeout
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Create server manager
|
||||||
|
mgr := server.NewManager()
|
||||||
|
|
||||||
|
// Register shutdown callbacks
|
||||||
|
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
log.Println("Cleanup: Closing database...")
|
||||||
|
// return db.Close()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create router
|
||||||
|
router := http.NewServeMux()
|
||||||
|
router.HandleFunc("/api/data", dataHandler)
|
||||||
|
|
||||||
|
// Add server
|
||||||
|
instance, err := mgr.Add(server.Config{
|
||||||
|
Name: "api-server",
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: router,
|
||||||
|
GZIP: true,
|
||||||
|
ShutdownTimeout: 30 * time.Second,
|
||||||
|
DrainTimeout: 25 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add health endpoints
|
||||||
|
router.HandleFunc("/health", instance.HealthCheckHandler())
|
||||||
|
router.HandleFunc("/ready", instance.ReadinessHandler())
|
||||||
|
|
||||||
|
// Start and wait for shutdown
|
||||||
|
log.Println("Starting server on :8080")
|
||||||
|
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||||
|
log.Printf("Server stopped: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Server shutdown complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(100 * time.Millisecond) // Simulate work
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(`{"message":"success"}`))
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testing Graceful Shutdown
|
## Testing Graceful Shutdown
|
||||||
@@ -330,7 +457,7 @@ SERVER_PID=$!
|
|||||||
# Wait for server to start
|
# Wait for server to start
|
||||||
sleep 2
|
sleep 2
|
||||||
|
|
||||||
# Send some requests
|
# Send requests
|
||||||
for i in {1..10}; do
|
for i in {1..10}; do
|
||||||
curl http://localhost:8080/api/data &
|
curl http://localhost:8080/api/data &
|
||||||
done
|
done
|
||||||
@@ -341,7 +468,7 @@ sleep 1
|
|||||||
# Send shutdown signal
|
# Send shutdown signal
|
||||||
kill -TERM $SERVER_PID
|
kill -TERM $SERVER_PID
|
||||||
|
|
||||||
# Try to send more requests (should get 503)
|
# Try more requests (should get 503)
|
||||||
curl -v http://localhost:8080/api/data
|
curl -v http://localhost:8080/api/data
|
||||||
|
|
||||||
# Wait for server to stop
|
# Wait for server to stop
|
||||||
@@ -349,101 +476,13 @@ wait $SERVER_PID
|
|||||||
echo "Server stopped gracefully"
|
echo "Server stopped gracefully"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Expected Output
|
|
||||||
|
|
||||||
```
|
|
||||||
Starting server on :8080
|
|
||||||
Received signal: terminated, initiating graceful shutdown
|
|
||||||
Starting graceful shutdown...
|
|
||||||
Waiting for 8 in-flight requests to complete...
|
|
||||||
Waiting for 4 in-flight requests to complete...
|
|
||||||
Waiting for 1 in-flight requests to complete...
|
|
||||||
All requests drained in 2.3s
|
|
||||||
Cleanup: Flushing metrics...
|
|
||||||
Cleanup: Closing database...
|
|
||||||
Shutting down HTTP server...
|
|
||||||
Graceful shutdown complete
|
|
||||||
Server stopped
|
|
||||||
```
|
|
||||||
|
|
||||||
## Monitoring In-Flight Requests
|
|
||||||
|
|
||||||
```go
|
|
||||||
// Get current in-flight count
|
|
||||||
count := srv.InFlightRequests()
|
|
||||||
fmt.Printf("In-flight requests: %d\n", count)
|
|
||||||
|
|
||||||
// Check if shutting down
|
|
||||||
if srv.IsShuttingDown() {
|
|
||||||
fmt.Println("Server is shutting down")
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Advanced Usage
|
|
||||||
|
|
||||||
### Custom Shutdown Logic
|
|
||||||
|
|
||||||
```go
|
|
||||||
// Implement custom shutdown
|
|
||||||
go func() {
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
|
||||||
|
|
||||||
<-sigChan
|
|
||||||
log.Println("Shutdown signal received")
|
|
||||||
|
|
||||||
// Custom pre-shutdown logic
|
|
||||||
log.Println("Running custom cleanup...")
|
|
||||||
|
|
||||||
// Shutdown with callbacks
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := srv.ShutdownWithCallbacks(ctx); err != nil {
|
|
||||||
log.Printf("Shutdown error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start server
|
|
||||||
srv.server.ListenAndServe()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multiple Servers
|
|
||||||
|
|
||||||
```go
|
|
||||||
// HTTP server
|
|
||||||
httpSrv := server.NewGracefulServer(server.Config{
|
|
||||||
Addr: ":8080",
|
|
||||||
Handler: httpRouter,
|
|
||||||
})
|
|
||||||
|
|
||||||
// HTTPS server
|
|
||||||
httpsSrv := server.NewGracefulServer(server.Config{
|
|
||||||
Addr: ":8443",
|
|
||||||
Handler: httpsRouter,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Start both
|
|
||||||
go httpSrv.ListenAndServe()
|
|
||||||
go httpsSrv.ListenAndServe()
|
|
||||||
|
|
||||||
// Shutdown both on signal
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, os.Interrupt)
|
|
||||||
<-sigChan
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
httpSrv.Shutdown(ctx)
|
|
||||||
httpsSrv.Shutdown(ctx)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
||||||
1. **Set appropriate timeouts**
|
1. **Set appropriate timeouts**
|
||||||
- `DrainTimeout` < `ShutdownTimeout`
|
- `DrainTimeout` < `ShutdownTimeout`
|
||||||
- `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds`
|
- `ShutdownTimeout` < Kubernetes `terminationGracePeriodSeconds`
|
||||||
|
|
||||||
2. **Register cleanup callbacks** for:
|
2. **Use shutdown callbacks** for:
|
||||||
- Database connections
|
- Database connections
|
||||||
- Message queues
|
- Message queues
|
||||||
- Metrics flushing
|
- Metrics flushing
|
||||||
@@ -458,7 +497,12 @@ httpsSrv.Shutdown(ctx)
|
|||||||
- Set `preStop` hook in Kubernetes (5-10s delay)
|
- Set `preStop` hook in Kubernetes (5-10s delay)
|
||||||
- Allows load balancer to deregister before shutdown
|
- Allows load balancer to deregister before shutdown
|
||||||
|
|
||||||
5. **Monitoring**
|
5. **HTTPS in production**
|
||||||
|
- Use AutoTLS for public-facing services
|
||||||
|
- Use certificate files for enterprise PKI
|
||||||
|
- Use self-signed only for development/testing
|
||||||
|
|
||||||
|
6. **Monitoring**
|
||||||
- Track in-flight requests in metrics
|
- Track in-flight requests in metrics
|
||||||
- Alert on slow drains
|
- Alert on slow drains
|
||||||
- Monitor shutdown duration
|
- Monitor shutdown duration
|
||||||
@@ -470,24 +514,63 @@ httpsSrv.Shutdown(ctx)
|
|||||||
```go
|
```go
|
||||||
// Increase drain timeout
|
// Increase drain timeout
|
||||||
config.DrainTimeout = 60 * time.Second
|
config.DrainTimeout = 60 * time.Second
|
||||||
|
config.ShutdownTimeout = 65 * time.Second
|
||||||
```
|
```
|
||||||
|
|
||||||
### Requests Still Timing Out
|
### Requests Timing Out
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// Increase write timeout
|
// Increase write timeout
|
||||||
config.WriteTimeout = 30 * time.Second
|
config.WriteTimeout = 30 * time.Second
|
||||||
```
|
```
|
||||||
|
|
||||||
### Force Shutdown Not Working
|
### Certificate Issues
|
||||||
|
|
||||||
The server will force shutdown after `ShutdownTimeout` even if requests are still in-flight. Adjust timeouts as needed.
|
```go
|
||||||
|
// Verify certificate files exist and are readable
|
||||||
### Debugging Shutdown
|
if _, err := os.Stat(config.SSLCert); err != nil {
|
||||||
|
log.Fatalf("Certificate not found: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For AutoTLS, ensure:
|
||||||
|
// - Port 443 is accessible
|
||||||
|
// - Domains resolve to server IP
|
||||||
|
// - Cache directory is writable
|
||||||
|
```
|
||||||
|
|
||||||
|
### Debug Logging
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// Enable debug logging
|
|
||||||
import "github.com/bitechdev/ResolveSpec/pkg/logger"
|
import "github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
|
||||||
|
// Enable debug logging
|
||||||
logger.SetLevel("debug")
|
logger.SetLevel("debug")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Manager Methods
|
||||||
|
|
||||||
|
- `NewManager()` - Create new server manager
|
||||||
|
- `Add(cfg Config)` - Register server instance
|
||||||
|
- `Get(name string)` - Get server by name
|
||||||
|
- `Remove(name string)` - Stop and remove server
|
||||||
|
- `StartAll()` - Start all registered servers
|
||||||
|
- `StopAll()` - Stop all servers gracefully
|
||||||
|
- `StopAllWithContext(ctx)` - Stop with timeout
|
||||||
|
- `RestartAll()` - Restart all servers
|
||||||
|
- `List()` - Get all server instances
|
||||||
|
- `ServeWithGracefulShutdown()` - Start and block until shutdown
|
||||||
|
- `RegisterShutdownCallback(cb)` - Register cleanup function
|
||||||
|
|
||||||
|
### Instance Methods
|
||||||
|
|
||||||
|
- `Start()` - Start the server
|
||||||
|
- `Stop(ctx)` - Stop gracefully
|
||||||
|
- `Addr()` - Get server address
|
||||||
|
- `Name()` - Get server name
|
||||||
|
- `HealthCheckHandler()` - Get health handler
|
||||||
|
- `ReadinessHandler()` - Get readiness handler
|
||||||
|
- `InFlightRequests()` - Get in-flight count
|
||||||
|
- `IsShuttingDown()` - Check shutdown status
|
||||||
|
- `Wait()` - Block until shutdown complete
|
||||||
|
|||||||
294
pkg/server/example_test.go
Normal file
294
pkg/server/example_test.go
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExampleManager_basic demonstrates basic server manager usage
|
||||||
|
func ExampleManager_basic() {
|
||||||
|
// Create a server manager
|
||||||
|
mgr := server.NewManager()
|
||||||
|
|
||||||
|
// Define a simple handler
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintln(w, "Hello from server!")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add an HTTP server
|
||||||
|
_, err := mgr.Add(server.Config{
|
||||||
|
Name: "api-server",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: handler,
|
||||||
|
GZIP: true, // Enable GZIP compression
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start all servers
|
||||||
|
if err := mgr.StartAll(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server is now running...
|
||||||
|
// When done, stop gracefully
|
||||||
|
if err := mgr.StopAll(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleManager_https demonstrates HTTPS configurations
|
||||||
|
func ExampleManager_https() {
|
||||||
|
mgr := server.NewManager()
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "Secure connection!")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Option 1: Use certificate files
|
||||||
|
_, err := mgr.Add(server.Config{
|
||||||
|
Name: "https-server-files",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8443,
|
||||||
|
Handler: handler,
|
||||||
|
SSLCert: "/path/to/cert.pem",
|
||||||
|
SSLKey: "/path/to/key.pem",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option 2: Self-signed certificate (for development)
|
||||||
|
_, err = mgr.Add(server.Config{
|
||||||
|
Name: "https-server-self-signed",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8444,
|
||||||
|
Handler: handler,
|
||||||
|
SelfSignedSSL: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option 3: Let's Encrypt / AutoTLS (for production)
|
||||||
|
_, err = mgr.Add(server.Config{
|
||||||
|
Name: "https-server-letsencrypt",
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
Port: 443,
|
||||||
|
Handler: handler,
|
||||||
|
AutoTLS: true,
|
||||||
|
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||||
|
AutoTLSEmail: "admin@example.com",
|
||||||
|
AutoTLSCacheDir: "./certs-cache",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start all servers
|
||||||
|
if err := mgr.StartAll(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
mgr.StopAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleManager_gracefulShutdown demonstrates graceful shutdown with callbacks
|
||||||
|
func ExampleManager_gracefulShutdown() {
|
||||||
|
mgr := server.NewManager()
|
||||||
|
|
||||||
|
// Register shutdown callbacks for cleanup tasks
|
||||||
|
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
fmt.Println("Closing database connections...")
|
||||||
|
// Close your database here
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
mgr.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
fmt.Println("Flushing metrics...")
|
||||||
|
// Flush metrics here
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add server with custom timeouts
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Simulate some work
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
fmt.Fprintln(w, "Done!")
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := mgr.Add(server.Config{
|
||||||
|
Name: "api-server",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: handler,
|
||||||
|
ShutdownTimeout: 30 * time.Second, // Max time for shutdown
|
||||||
|
DrainTimeout: 25 * time.Second, // Time to wait for in-flight requests
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
|
||||||
|
// This will automatically handle graceful shutdown with callbacks
|
||||||
|
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||||
|
fmt.Printf("Shutdown completed: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleManager_healthChecks demonstrates health and readiness endpoints
|
||||||
|
func ExampleManager_healthChecks() {
|
||||||
|
mgr := server.NewManager()
|
||||||
|
|
||||||
|
// Create a router with health endpoints
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "Data endpoint")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add server
|
||||||
|
instance, err := mgr.Add(server.Config{
|
||||||
|
Name: "api-server",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: mux,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add health and readiness endpoints
|
||||||
|
mux.HandleFunc("/health", instance.HealthCheckHandler())
|
||||||
|
mux.HandleFunc("/ready", instance.ReadinessHandler())
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
if err := mgr.StartAll(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Health check returns:
|
||||||
|
// - 200 OK with {"status":"healthy"} when healthy
|
||||||
|
// - 503 Service Unavailable with {"status":"shutting_down"} when shutting down
|
||||||
|
|
||||||
|
// Readiness check returns:
|
||||||
|
// - 200 OK with {"ready":true,"in_flight_requests":N} when ready
|
||||||
|
// - 503 Service Unavailable with {"ready":false,"reason":"shutting_down"} when shutting down
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
mgr.StopAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleManager_multipleServers demonstrates running multiple servers
|
||||||
|
func ExampleManager_multipleServers() {
|
||||||
|
mgr := server.NewManager()
|
||||||
|
|
||||||
|
// Public API server
|
||||||
|
publicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "Public API")
|
||||||
|
})
|
||||||
|
_, err := mgr.Add(server.Config{
|
||||||
|
Name: "public-api",
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: publicHandler,
|
||||||
|
GZIP: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin API server (different port)
|
||||||
|
adminHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "Admin API")
|
||||||
|
})
|
||||||
|
_, err = mgr.Add(server.Config{
|
||||||
|
Name: "admin-api",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8081,
|
||||||
|
Handler: adminHandler,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metrics server (internal only)
|
||||||
|
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "Metrics data")
|
||||||
|
})
|
||||||
|
_, err = mgr.Add(server.Config{
|
||||||
|
Name: "metrics",
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 9090,
|
||||||
|
Handler: metricsHandler,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start all servers at once
|
||||||
|
if err := mgr.StartAll(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get specific server instance
|
||||||
|
publicInstance, err := mgr.Get("public-api")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
|
||||||
|
|
||||||
|
// List all servers
|
||||||
|
instances := mgr.List()
|
||||||
|
fmt.Printf("Running %d servers\n", len(instances))
|
||||||
|
|
||||||
|
// Stop all servers gracefully (in parallel)
|
||||||
|
if err := mgr.StopAll(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleManager_monitoring demonstrates monitoring server state
|
||||||
|
func ExampleManager_monitoring() {
|
||||||
|
mgr := server.NewManager()
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||||
|
fmt.Fprintln(w, "Done")
|
||||||
|
})
|
||||||
|
|
||||||
|
instance, err := mgr.Add(server.Config{
|
||||||
|
Name: "api-server",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 8080,
|
||||||
|
Handler: handler,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.StartAll(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check server status
|
||||||
|
fmt.Printf("Server address: %s\n", instance.Addr())
|
||||||
|
fmt.Printf("Server name: %s\n", instance.Name())
|
||||||
|
fmt.Printf("Is shutting down: %v\n", instance.IsShuttingDown())
|
||||||
|
fmt.Printf("In-flight requests: %d\n", instance.InFlightRequests())
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
mgr.StopAll()
|
||||||
|
|
||||||
|
// Wait for complete shutdown
|
||||||
|
instance.Wait()
|
||||||
|
}
|
||||||
137
pkg/server/interfaces.go
Normal file
137
pkg/server/interfaces.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds the configuration for a single web server instance.
|
||||||
|
type Config struct {
|
||||||
|
Name string
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Description string
|
||||||
|
|
||||||
|
// Handler is the http.Handler (e.g., a router) to be served.
|
||||||
|
Handler http.Handler
|
||||||
|
|
||||||
|
// GZIP compression support
|
||||||
|
GZIP bool
|
||||||
|
|
||||||
|
// TLS/HTTPS configuration options (mutually exclusive)
|
||||||
|
// Option 1: Provide certificate and key files directly
|
||||||
|
SSLCert string
|
||||||
|
SSLKey string
|
||||||
|
|
||||||
|
// Option 2: Use self-signed certificate (for development/testing)
|
||||||
|
// Generates a self-signed certificate automatically if no SSLCert/SSLKey provided
|
||||||
|
SelfSignedSSL bool
|
||||||
|
|
||||||
|
// Option 3: Use Let's Encrypt / Certbot for automatic TLS
|
||||||
|
// AutoTLS enables automatic certificate management via Let's Encrypt
|
||||||
|
AutoTLS bool
|
||||||
|
// AutoTLSDomains specifies the domains for Let's Encrypt certificates
|
||||||
|
AutoTLSDomains []string
|
||||||
|
// AutoTLSCacheDir specifies where to cache certificates (default: "./certs-cache")
|
||||||
|
AutoTLSCacheDir string
|
||||||
|
// AutoTLSEmail is the email for Let's Encrypt registration (optional but recommended)
|
||||||
|
AutoTLSEmail string
|
||||||
|
|
||||||
|
// Graceful shutdown configuration
|
||||||
|
// ShutdownTimeout is the maximum time to wait for graceful shutdown
|
||||||
|
// Default: 30 seconds
|
||||||
|
ShutdownTimeout time.Duration
|
||||||
|
|
||||||
|
// DrainTimeout is the time to wait for in-flight requests to complete
|
||||||
|
// before forcing shutdown. Default: 25 seconds
|
||||||
|
DrainTimeout time.Duration
|
||||||
|
|
||||||
|
// ReadTimeout is the maximum duration for reading the entire request
|
||||||
|
// Default: 15 seconds
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
|
||||||
|
// WriteTimeout is the maximum duration before timing out writes of the response
|
||||||
|
// Default: 15 seconds
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
|
||||||
|
// IdleTimeout is the maximum amount of time to wait for the next request
|
||||||
|
// Default: 60 seconds
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instance defines the interface for a single server instance.
|
||||||
|
// It abstracts the underlying http.Server, allowing for easier management and testing.
|
||||||
|
type Instance interface {
|
||||||
|
// Start begins serving requests. This method should be non-blocking and
|
||||||
|
// run the server in a separate goroutine.
|
||||||
|
Start() error
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the server without interrupting any active connections.
|
||||||
|
// It accepts a context to allow for a timeout.
|
||||||
|
Stop(ctx context.Context) error
|
||||||
|
|
||||||
|
// Addr returns the network address the server is listening on.
|
||||||
|
Addr() string
|
||||||
|
|
||||||
|
// Name returns the server instance name.
|
||||||
|
Name() string
|
||||||
|
|
||||||
|
// HealthCheckHandler returns a handler that responds to health checks.
|
||||||
|
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down.
|
||||||
|
HealthCheckHandler() http.HandlerFunc
|
||||||
|
|
||||||
|
// ReadinessHandler returns a handler for readiness checks.
|
||||||
|
// Includes in-flight request count.
|
||||||
|
ReadinessHandler() http.HandlerFunc
|
||||||
|
|
||||||
|
// InFlightRequests returns the current number of in-flight requests.
|
||||||
|
InFlightRequests() int64
|
||||||
|
|
||||||
|
// IsShuttingDown returns true if the server is shutting down.
|
||||||
|
IsShuttingDown() bool
|
||||||
|
|
||||||
|
// Wait blocks until shutdown is complete.
|
||||||
|
Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager defines the interface for a server manager.
|
||||||
|
// It is responsible for managing the lifecycle of multiple server instances.
|
||||||
|
type Manager interface {
|
||||||
|
// Add registers a new server instance based on the provided configuration.
|
||||||
|
// The server is not started until StartAll or Start is called on the instance.
|
||||||
|
Add(cfg Config) (Instance, error)
|
||||||
|
|
||||||
|
// Get returns a server instance by its name.
|
||||||
|
Get(name string) (Instance, error)
|
||||||
|
|
||||||
|
// Remove stops and removes a server instance by its name.
|
||||||
|
Remove(name string) error
|
||||||
|
|
||||||
|
// StartAll starts all registered server instances that are not already running.
|
||||||
|
StartAll() error
|
||||||
|
|
||||||
|
// StopAll gracefully shuts down all running server instances.
|
||||||
|
// Executes shutdown callbacks and drains in-flight requests.
|
||||||
|
StopAll() error
|
||||||
|
|
||||||
|
// StopAllWithContext gracefully shuts down all running server instances with a context.
|
||||||
|
StopAllWithContext(ctx context.Context) error
|
||||||
|
|
||||||
|
// RestartAll gracefully restarts all running server instances.
|
||||||
|
RestartAll() error
|
||||||
|
|
||||||
|
// List returns all registered server instances.
|
||||||
|
List() []Instance
|
||||||
|
|
||||||
|
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
|
||||||
|
// It handles SIGINT and SIGTERM signals and performs graceful shutdown with callbacks.
|
||||||
|
ServeWithGracefulShutdown() error
|
||||||
|
|
||||||
|
// RegisterShutdownCallback registers a callback to be called during shutdown.
|
||||||
|
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
|
||||||
|
RegisterShutdownCallback(cb ShutdownCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutdownCallback is a function called during graceful shutdown.
|
||||||
|
type ShutdownCallback func(context.Context) error
|
||||||
572
pkg/server/manager.go
Normal file
572
pkg/server/manager.go
Normal file
@@ -0,0 +1,572 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||||
|
"github.com/klauspost/compress/gzhttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// gracefulServer wraps http.Server with graceful shutdown capabilities (internal type)
|
||||||
|
type gracefulServer struct {
|
||||||
|
server *http.Server
|
||||||
|
shutdownTimeout time.Duration
|
||||||
|
drainTimeout time.Duration
|
||||||
|
inFlightRequests atomic.Int64
|
||||||
|
isShuttingDown atomic.Bool
|
||||||
|
shutdownOnce sync.Once
|
||||||
|
shutdownComplete chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// trackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
|
||||||
|
func (gs *gracefulServer) trackRequestsMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check if shutting down
|
||||||
|
if gs.isShuttingDown.Load() {
|
||||||
|
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment in-flight counter
|
||||||
|
gs.inFlightRequests.Add(1)
|
||||||
|
defer gs.inFlightRequests.Add(-1)
|
||||||
|
|
||||||
|
// Serve the request
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown performs graceful shutdown with request draining
|
||||||
|
func (gs *gracefulServer) shutdown(ctx context.Context) error {
|
||||||
|
var shutdownErr error
|
||||||
|
|
||||||
|
gs.shutdownOnce.Do(func() {
|
||||||
|
logger.Info("Starting graceful shutdown...")
|
||||||
|
|
||||||
|
// Mark as shutting down (new requests will be rejected)
|
||||||
|
gs.isShuttingDown.Store(true)
|
||||||
|
|
||||||
|
// Create context with timeout
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Wait for in-flight requests to complete (with drain timeout)
|
||||||
|
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
|
||||||
|
defer drainCancel()
|
||||||
|
|
||||||
|
shutdownErr = gs.drainRequests(drainCtx)
|
||||||
|
if shutdownErr != nil {
|
||||||
|
logger.Error("Error draining requests: %v", shutdownErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown the server
|
||||||
|
logger.Info("Shutting down HTTP server...")
|
||||||
|
if err := gs.server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
logger.Error("Error shutting down server: %v", err)
|
||||||
|
if shutdownErr == nil {
|
||||||
|
shutdownErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Graceful shutdown complete")
|
||||||
|
close(gs.shutdownComplete)
|
||||||
|
})
|
||||||
|
|
||||||
|
return shutdownErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainRequests waits for in-flight requests to complete
|
||||||
|
func (gs *gracefulServer) drainRequests(ctx context.Context) error {
|
||||||
|
ticker := time.NewTicker(100 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
for {
|
||||||
|
inFlight := gs.inFlightRequests.Load()
|
||||||
|
|
||||||
|
if inFlight == 0 {
|
||||||
|
logger.Info("All requests drained in %v", time.Since(startTime))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
|
||||||
|
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
|
||||||
|
case <-ticker.C:
|
||||||
|
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inFlightRequests returns the current number of in-flight requests
|
||||||
|
func (gs *gracefulServer) inFlightRequestsCount() int64 {
|
||||||
|
return gs.inFlightRequests.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isShutdown returns true if the server is shutting down
|
||||||
|
func (gs *gracefulServer) isShutdown() bool {
|
||||||
|
return gs.isShuttingDown.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait blocks until shutdown is complete
|
||||||
|
func (gs *gracefulServer) wait() {
|
||||||
|
<-gs.shutdownComplete
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthCheckHandler returns a handler that responds to health checks
|
||||||
|
func (gs *gracefulServer) healthCheckHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if gs.isShutdown() {
|
||||||
|
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to write health check response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readinessHandler returns a handler for readiness checks
|
||||||
|
func (gs *gracefulServer) readinessHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if gs.isShutdown() {
|
||||||
|
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inFlight := gs.inFlightRequestsCount()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// serverManager manages a collection of server instances with graceful shutdown support.
|
||||||
|
type serverManager struct {
|
||||||
|
instances map[string]Instance
|
||||||
|
mu sync.RWMutex
|
||||||
|
shutdownCallbacks []ShutdownCallback
|
||||||
|
callbacksMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new server manager.
|
||||||
|
func NewManager() Manager {
|
||||||
|
return &serverManager{
|
||||||
|
instances: make(map[string]Instance),
|
||||||
|
shutdownCallbacks: make([]ShutdownCallback, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add registers a new server instance.
|
||||||
|
func (sm *serverManager) Add(cfg Config) (Instance, error) {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
if cfg.Name == "" {
|
||||||
|
return nil, fmt.Errorf("server name cannot be empty")
|
||||||
|
}
|
||||||
|
if _, exists := sm.instances[cfg.Name]; exists {
|
||||||
|
return nil, fmt.Errorf("server with name '%s' already exists", cfg.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
instance, err := newInstance(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sm.instances[cfg.Name] = instance
|
||||||
|
return instance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a server instance by its name.
|
||||||
|
func (sm *serverManager) Get(name string) (Instance, error) {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
instance, exists := sm.instances[name]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("server with name '%s' not found", name)
|
||||||
|
}
|
||||||
|
return instance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove stops and removes a server instance by its name.
|
||||||
|
func (sm *serverManager) Remove(name string) error {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
instance, exists := sm.instances[name]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("server with name '%s' not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the server if it's running
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := instance.Stop(ctx); err != nil {
|
||||||
|
logger.Warn("Failed to gracefully stop server '%s' on remove: %v", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(sm.instances, name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartAll starts all registered server instances.
|
||||||
|
func (sm *serverManager) StartAll() error {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
var startErrors []error
|
||||||
|
for name, instance := range sm.instances {
|
||||||
|
if err := instance.Start(); err != nil {
|
||||||
|
startErrors = append(startErrors, fmt.Errorf("failed to start server '%s': %w", name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(startErrors) > 0 {
|
||||||
|
return fmt.Errorf("encountered errors while starting servers: %v", startErrors)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopAll gracefully shuts down all running server instances.
|
||||||
|
func (sm *serverManager) StopAll() error {
|
||||||
|
return sm.StopAllWithContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopAllWithContext gracefully shuts down all running server instances with a context.
|
||||||
|
func (sm *serverManager) StopAllWithContext(ctx context.Context) error {
|
||||||
|
sm.mu.RLock()
|
||||||
|
instancesToStop := make([]Instance, 0, len(sm.instances))
|
||||||
|
for _, instance := range sm.instances {
|
||||||
|
instancesToStop = append(instancesToStop, instance)
|
||||||
|
}
|
||||||
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
|
logger.Info("Shutting down all servers...")
|
||||||
|
|
||||||
|
// Execute shutdown callbacks first
|
||||||
|
sm.callbacksMu.Lock()
|
||||||
|
callbacks := make([]ShutdownCallback, len(sm.shutdownCallbacks))
|
||||||
|
copy(callbacks, sm.shutdownCallbacks)
|
||||||
|
sm.callbacksMu.Unlock()
|
||||||
|
|
||||||
|
if len(callbacks) > 0 {
|
||||||
|
logger.Info("Executing %d shutdown callbacks...", len(callbacks))
|
||||||
|
for i, cb := range callbacks {
|
||||||
|
if err := cb(ctx); err != nil {
|
||||||
|
logger.Error("Shutdown callback %d failed: %v", i+1, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop all instances in parallel
|
||||||
|
var shutdownErrors []error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var errorsMu sync.Mutex
|
||||||
|
|
||||||
|
for _, instance := range instancesToStop {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(inst Instance) {
|
||||||
|
defer wg.Done()
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := inst.Stop(shutdownCtx); err != nil {
|
||||||
|
errorsMu.Lock()
|
||||||
|
shutdownErrors = append(shutdownErrors, fmt.Errorf("failed to stop server '%s': %w", inst.Name(), err))
|
||||||
|
errorsMu.Unlock()
|
||||||
|
}
|
||||||
|
}(instance)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if len(shutdownErrors) > 0 {
|
||||||
|
return fmt.Errorf("encountered errors while stopping servers: %v", shutdownErrors)
|
||||||
|
}
|
||||||
|
logger.Info("All servers stopped gracefully.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestartAll gracefully restarts all running server instances.
|
||||||
|
func (sm *serverManager) RestartAll() error {
|
||||||
|
logger.Info("Restarting all servers...")
|
||||||
|
if err := sm.StopAll(); err != nil {
|
||||||
|
return fmt.Errorf("failed to stop servers during restart: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give ports time to be released
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
if err := sm.StartAll(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start servers during restart: %w", err)
|
||||||
|
}
|
||||||
|
logger.Info("All servers restarted successfully.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns all registered server instances.
|
||||||
|
func (sm *serverManager) List() []Instance {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
instances := make([]Instance, 0, len(sm.instances))
|
||||||
|
for _, instance := range sm.instances {
|
||||||
|
instances = append(instances, instance)
|
||||||
|
}
|
||||||
|
return instances
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterShutdownCallback registers a callback to be called during shutdown.
|
||||||
|
func (sm *serverManager) RegisterShutdownCallback(cb ShutdownCallback) {
|
||||||
|
sm.callbacksMu.Lock()
|
||||||
|
defer sm.callbacksMu.Unlock()
|
||||||
|
sm.shutdownCallbacks = append(sm.shutdownCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeWithGracefulShutdown starts all servers and blocks until a shutdown signal is received.
|
||||||
|
func (sm *serverManager) ServeWithGracefulShutdown() error {
|
||||||
|
// Start all servers
|
||||||
|
if err := sm.StartAll(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start servers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("All servers started. Waiting for shutdown signal...")
|
||||||
|
|
||||||
|
// Wait for interrupt signal
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
|
||||||
|
sig := <-sigChan
|
||||||
|
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
|
||||||
|
|
||||||
|
// Create context with timeout for shutdown
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return sm.StopAllWithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serverInstance is a concrete implementation of the Instance interface.
|
||||||
|
// It wraps gracefulServer to provide graceful shutdown capabilities.
|
||||||
|
type serverInstance struct {
|
||||||
|
cfg Config
|
||||||
|
gracefulServer *gracefulServer
|
||||||
|
certFile string // Path to certificate file (may be temporary for self-signed)
|
||||||
|
keyFile string // Path to key file (may be temporary for self-signed)
|
||||||
|
mu sync.RWMutex
|
||||||
|
running bool
|
||||||
|
serverErr chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
// newInstance creates a new, unstarted server instance from a config.
|
||||||
|
func newInstance(cfg Config) (*serverInstance, error) {
|
||||||
|
if cfg.Handler == nil {
|
||||||
|
return nil, fmt.Errorf("handler cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default timeouts
|
||||||
|
if cfg.ShutdownTimeout == 0 {
|
||||||
|
cfg.ShutdownTimeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
if cfg.DrainTimeout == 0 {
|
||||||
|
cfg.DrainTimeout = 25 * time.Second
|
||||||
|
}
|
||||||
|
if cfg.ReadTimeout == 0 {
|
||||||
|
cfg.ReadTimeout = 15 * time.Second
|
||||||
|
}
|
||||||
|
if cfg.WriteTimeout == 0 {
|
||||||
|
cfg.WriteTimeout = 15 * time.Second
|
||||||
|
}
|
||||||
|
if cfg.IdleTimeout == 0 {
|
||||||
|
cfg.IdleTimeout = 60 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
|
||||||
|
var handler http.Handler = cfg.Handler
|
||||||
|
|
||||||
|
// Wrap with GZIP handler if enabled
|
||||||
|
if cfg.GZIP {
|
||||||
|
gz, err := gzhttp.NewWrapper()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create GZIP wrapper: %w", err)
|
||||||
|
}
|
||||||
|
handler = gz(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap with the panic recovery middleware
|
||||||
|
handler = middleware.PanicRecovery(handler)
|
||||||
|
|
||||||
|
// Configure TLS if any TLS option is enabled
|
||||||
|
tlsConfig, certFile, keyFile, err := configureTLS(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to configure TLS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create gracefulServer
|
||||||
|
gracefulSrv := &gracefulServer{
|
||||||
|
server: &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: cfg.ReadTimeout,
|
||||||
|
WriteTimeout: cfg.WriteTimeout,
|
||||||
|
IdleTimeout: cfg.IdleTimeout,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
},
|
||||||
|
shutdownTimeout: cfg.ShutdownTimeout,
|
||||||
|
drainTimeout: cfg.DrainTimeout,
|
||||||
|
shutdownComplete: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &serverInstance{
|
||||||
|
cfg: cfg,
|
||||||
|
gracefulServer: gracefulSrv,
|
||||||
|
certFile: certFile,
|
||||||
|
keyFile: keyFile,
|
||||||
|
serverErr: make(chan error, 1),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins serving requests in a new goroutine.
|
||||||
|
func (s *serverInstance) Start() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.running {
|
||||||
|
return fmt.Errorf("server '%s' is already running", s.cfg.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine if we're using TLS
|
||||||
|
useTLS := s.cfg.SSLCert != "" || s.cfg.SSLKey != "" || s.cfg.SelfSignedSSL || s.cfg.AutoTLS
|
||||||
|
|
||||||
|
// Wrap handler with request tracking
|
||||||
|
s.gracefulServer.server.Handler = s.gracefulServer.trackRequestsMiddleware(s.gracefulServer.server.Handler)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.running = false
|
||||||
|
s.mu.Unlock()
|
||||||
|
logger.Info("Server '%s' stopped.", s.cfg.Name)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
protocol := "HTTP"
|
||||||
|
|
||||||
|
if useTLS {
|
||||||
|
protocol = "HTTPS"
|
||||||
|
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||||
|
|
||||||
|
// For AutoTLS, we need to use a TLS listener
|
||||||
|
if s.cfg.AutoTLS {
|
||||||
|
// Create listener
|
||||||
|
ln, lnErr := net.Listen("tcp", s.gracefulServer.server.Addr)
|
||||||
|
if lnErr != nil {
|
||||||
|
err = fmt.Errorf("failed to create listener: %w", lnErr)
|
||||||
|
} else {
|
||||||
|
// Wrap with TLS
|
||||||
|
tlsListener := tls.NewListener(ln, s.gracefulServer.server.TLSConfig)
|
||||||
|
err = s.gracefulServer.server.Serve(tlsListener)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Use certificate files (regular SSL or self-signed)
|
||||||
|
err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||||
|
err = s.gracefulServer.server.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the server stopped for a reason other than a graceful shutdown, log and report the error.
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
logger.Error("Server '%s' failed: %v", s.cfg.Name, err)
|
||||||
|
select {
|
||||||
|
case s.serverErr <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.running = true
|
||||||
|
// A small delay to allow the goroutine to start and potentially fail on binding.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check if the server failed to start
|
||||||
|
select {
|
||||||
|
case err := <-s.serverErr:
|
||||||
|
s.running = false
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the server.
|
||||||
|
func (s *serverInstance) Stop(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if !s.running {
|
||||||
|
return nil // Already stopped
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Gracefully shutting down server '%s'...", s.cfg.Name)
|
||||||
|
err := s.gracefulServer.shutdown(ctx)
|
||||||
|
if err == nil {
|
||||||
|
s.running = false
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns the network address the server is listening on.
|
||||||
|
func (s *serverInstance) Addr() string {
|
||||||
|
return s.gracefulServer.server.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the server instance name.
|
||||||
|
func (s *serverInstance) Name() string {
|
||||||
|
return s.cfg.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheckHandler returns a handler that responds to health checks.
|
||||||
|
func (s *serverInstance) HealthCheckHandler() http.HandlerFunc {
|
||||||
|
return s.gracefulServer.healthCheckHandler()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadinessHandler returns a handler for readiness checks.
|
||||||
|
func (s *serverInstance) ReadinessHandler() http.HandlerFunc {
|
||||||
|
return s.gracefulServer.readinessHandler()
|
||||||
|
}
|
||||||
|
|
||||||
|
// InFlightRequests returns the current number of in-flight requests.
|
||||||
|
func (s *serverInstance) InFlightRequests() int64 {
|
||||||
|
return s.gracefulServer.inFlightRequestsCount()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsShuttingDown returns true if the server is shutting down.
|
||||||
|
func (s *serverInstance) IsShuttingDown() bool {
|
||||||
|
return s.gracefulServer.isShutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks until shutdown is complete.
|
||||||
|
func (s *serverInstance) Wait() {
|
||||||
|
s.gracefulServer.wait()
|
||||||
|
}
|
||||||
328
pkg/server/manager_test.go
Normal file
328
pkg/server/manager_test.go
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getFreePort asks the kernel for a free open port that is ready to use.
|
||||||
|
func getFreePort(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
l, err := net.ListenTCP("tcp", addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer l.Close()
|
||||||
|
return l.Addr().(*net.TCPAddr).Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerManagerLifecycle(t *testing.T) {
|
||||||
|
// Initialize logger for test output
|
||||||
|
logger.Init(true)
|
||||||
|
|
||||||
|
// Create a new server manager
|
||||||
|
sm := NewManager()
|
||||||
|
|
||||||
|
// Define a simple test handler
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("Hello, World!"))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get a free port for the server to listen on to avoid conflicts
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
|
||||||
|
// Add a new server configuration
|
||||||
|
serverConfig := Config{
|
||||||
|
Name: "TestServer",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: testPort,
|
||||||
|
Handler: testHandler,
|
||||||
|
}
|
||||||
|
instance, err := sm.Add(serverConfig)
|
||||||
|
require.NoError(t, err, "should be able to add a new server")
|
||||||
|
require.NotNil(t, instance, "added instance should not be nil")
|
||||||
|
|
||||||
|
// --- Test StartAll ---
|
||||||
|
err = sm.StartAll()
|
||||||
|
require.NoError(t, err, "StartAll should not return an error")
|
||||||
|
|
||||||
|
// Give the server a moment to start up
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// --- Verify Server is Running ---
|
||||||
|
client := &http.Client{Timeout: 2 * time.Second}
|
||||||
|
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||||
|
resp, err := client.Get(url)
|
||||||
|
require.NoError(t, err, "should be able to make a request to the running server")
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode, "expected status OK from the test server")
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp.Body.Close()
|
||||||
|
assert.Equal(t, "Hello, World!", string(body), "response body should match expected value")
|
||||||
|
|
||||||
|
// --- Test Get ---
|
||||||
|
retrievedInstance, err := sm.Get("TestServer")
|
||||||
|
require.NoError(t, err, "should be able to get server by name")
|
||||||
|
assert.Equal(t, instance.Addr(), retrievedInstance.Addr(), "retrieved instance should be the same")
|
||||||
|
|
||||||
|
// --- Test List ---
|
||||||
|
instanceList := sm.List()
|
||||||
|
require.Len(t, instanceList, 1, "list should contain one instance")
|
||||||
|
assert.Equal(t, instance.Addr(), instanceList[0].Addr(), "listed instance should be the same")
|
||||||
|
|
||||||
|
// --- Test StopAll ---
|
||||||
|
err = sm.StopAll()
|
||||||
|
require.NoError(t, err, "StopAll should not return an error")
|
||||||
|
|
||||||
|
// Give the server a moment to shut down
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// --- Verify Server is Stopped ---
|
||||||
|
_, err = client.Get(url)
|
||||||
|
require.Error(t, err, "should not be able to make a request to a stopped server")
|
||||||
|
|
||||||
|
// --- Test Remove ---
|
||||||
|
err = sm.Remove("TestServer")
|
||||||
|
require.NoError(t, err, "should be able to remove a server")
|
||||||
|
|
||||||
|
_, err = sm.Get("TestServer")
|
||||||
|
require.Error(t, err, "should not be able to get a removed server")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerErrorCases(t *testing.T) {
|
||||||
|
logger.Init(true)
|
||||||
|
sm := NewManager()
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
|
||||||
|
// --- Test Add Duplicate Name ---
|
||||||
|
config1 := Config{Name: "Duplicate", Host: "localhost", Port: testPort, Handler: http.NewServeMux()}
|
||||||
|
_, err := sm.Add(config1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config2 := Config{Name: "Duplicate", Host: "localhost", Port: getFreePort(t), Handler: http.NewServeMux()}
|
||||||
|
_, err = sm.Add(config2)
|
||||||
|
require.Error(t, err, "should not be able to add a server with a duplicate name")
|
||||||
|
|
||||||
|
// --- Test Get Non-existent ---
|
||||||
|
_, err = sm.Get("NonExistent")
|
||||||
|
require.Error(t, err, "should get an error for a non-existent server")
|
||||||
|
|
||||||
|
// --- Test Add with Nil Handler ---
|
||||||
|
config3 := Config{Name: "NilHandler", Host: "localhost", Port: getFreePort(t), Handler: nil}
|
||||||
|
_, err = sm.Add(config3)
|
||||||
|
require.Error(t, err, "should not be able to add a server with a nil handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGracefulShutdown(t *testing.T) {
|
||||||
|
logger.Init(true)
|
||||||
|
sm := NewManager()
|
||||||
|
|
||||||
|
requestsHandled := 0
|
||||||
|
var requestsMu sync.Mutex
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestsMu.Lock()
|
||||||
|
requestsHandled++
|
||||||
|
requestsMu.Unlock()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
instance, err := sm.Add(Config{
|
||||||
|
Name: "TestServer",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: testPort,
|
||||||
|
Handler: handler,
|
||||||
|
DrainTimeout: 2 * time.Second,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sm.StartAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Give server time to start
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send some concurrent requests
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
|
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||||
|
resp, err := client.Get(url)
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit for requests to start
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check in-flight requests
|
||||||
|
inFlight := instance.InFlightRequests()
|
||||||
|
assert.Greater(t, inFlight, int64(0), "Should have in-flight requests")
|
||||||
|
|
||||||
|
// Stop the server
|
||||||
|
err = sm.StopAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for all requests to complete
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Verify all requests were handled
|
||||||
|
requestsMu.Lock()
|
||||||
|
handled := requestsHandled
|
||||||
|
requestsMu.Unlock()
|
||||||
|
assert.GreaterOrEqual(t, handled, 1, "At least some requests should have been handled")
|
||||||
|
|
||||||
|
// Verify no in-flight requests
|
||||||
|
assert.Equal(t, int64(0), instance.InFlightRequests(), "Should have no in-flight requests after shutdown")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthAndReadinessEndpoints(t *testing.T) {
|
||||||
|
logger.Init(true)
|
||||||
|
sm := NewManager()
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
|
||||||
|
instance, err := sm.Add(Config{
|
||||||
|
Name: "TestServer",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: testPort,
|
||||||
|
Handler: mux,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add health and readiness endpoints
|
||||||
|
mux.HandleFunc("/health", instance.HealthCheckHandler())
|
||||||
|
mux.HandleFunc("/ready", instance.ReadinessHandler())
|
||||||
|
|
||||||
|
err = sm.StartAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 2 * time.Second}
|
||||||
|
baseURL := fmt.Sprintf("http://localhost:%d", testPort)
|
||||||
|
|
||||||
|
// Test health endpoint
|
||||||
|
resp, err := client.Get(baseURL + "/health")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
assert.Contains(t, string(body), "healthy")
|
||||||
|
|
||||||
|
// Test readiness endpoint
|
||||||
|
resp, err = client.Get(baseURL + "/ready")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
body, _ = io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
assert.Contains(t, string(body), "ready")
|
||||||
|
assert.Contains(t, string(body), "in_flight_requests")
|
||||||
|
|
||||||
|
// Stop the server
|
||||||
|
sm.StopAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestRejectionDuringShutdown(t *testing.T) {
|
||||||
|
logger.Init(true)
|
||||||
|
sm := NewManager()
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
_, err := sm.Add(Config{
|
||||||
|
Name: "TestServer",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: testPort,
|
||||||
|
Handler: handler,
|
||||||
|
DrainTimeout: 1 * time.Second,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sm.StartAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Start shutdown in background
|
||||||
|
go func() {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
sm.StopAll()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give shutdown time to start
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Try to make a request after shutdown started
|
||||||
|
client := &http.Client{Timeout: 2 * time.Second}
|
||||||
|
url := fmt.Sprintf("http://localhost:%d", testPort)
|
||||||
|
resp, err := client.Get(url)
|
||||||
|
|
||||||
|
// The request should either fail (connection refused) or get 503
|
||||||
|
if err == nil {
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Should get 503 during shutdown")
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownCallbacks(t *testing.T) {
|
||||||
|
logger.Init(true)
|
||||||
|
sm := NewManager()
|
||||||
|
|
||||||
|
callbackExecuted := false
|
||||||
|
var callbackMu sync.Mutex
|
||||||
|
|
||||||
|
sm.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
callbackMu.Lock()
|
||||||
|
callbackExecuted = true
|
||||||
|
callbackMu.Unlock()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
testPort := getFreePort(t)
|
||||||
|
_, err := sm.Add(Config{
|
||||||
|
Name: "TestServer",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: testPort,
|
||||||
|
Handler: http.NewServeMux(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sm.StartAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
err = sm.StopAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
callbackMu.Lock()
|
||||||
|
executed := callbackExecuted
|
||||||
|
callbackMu.Unlock()
|
||||||
|
|
||||||
|
assert.True(t, executed, "Shutdown callback should have been executed")
|
||||||
|
}
|
||||||
@@ -1,296 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GracefulServer wraps http.Server with graceful shutdown capabilities
|
|
||||||
type GracefulServer struct {
|
|
||||||
server *http.Server
|
|
||||||
shutdownTimeout time.Duration
|
|
||||||
drainTimeout time.Duration
|
|
||||||
inFlightRequests atomic.Int64
|
|
||||||
isShuttingDown atomic.Bool
|
|
||||||
shutdownOnce sync.Once
|
|
||||||
shutdownComplete chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config holds configuration for the graceful server
|
|
||||||
type Config struct {
|
|
||||||
// Addr is the server address (e.g., ":8080")
|
|
||||||
Addr string
|
|
||||||
|
|
||||||
// Handler is the HTTP handler
|
|
||||||
Handler http.Handler
|
|
||||||
|
|
||||||
// ShutdownTimeout is the maximum time to wait for graceful shutdown
|
|
||||||
// Default: 30 seconds
|
|
||||||
ShutdownTimeout time.Duration
|
|
||||||
|
|
||||||
// DrainTimeout is the time to wait for in-flight requests to complete
|
|
||||||
// before forcing shutdown. Default: 25 seconds
|
|
||||||
DrainTimeout time.Duration
|
|
||||||
|
|
||||||
// ReadTimeout is the maximum duration for reading the entire request
|
|
||||||
ReadTimeout time.Duration
|
|
||||||
|
|
||||||
// WriteTimeout is the maximum duration before timing out writes of the response
|
|
||||||
WriteTimeout time.Duration
|
|
||||||
|
|
||||||
// IdleTimeout is the maximum amount of time to wait for the next request
|
|
||||||
IdleTimeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewGracefulServer creates a new graceful server
|
|
||||||
func NewGracefulServer(config Config) *GracefulServer {
|
|
||||||
if config.ShutdownTimeout == 0 {
|
|
||||||
config.ShutdownTimeout = 30 * time.Second
|
|
||||||
}
|
|
||||||
if config.DrainTimeout == 0 {
|
|
||||||
config.DrainTimeout = 25 * time.Second
|
|
||||||
}
|
|
||||||
if config.ReadTimeout == 0 {
|
|
||||||
config.ReadTimeout = 10 * time.Second
|
|
||||||
}
|
|
||||||
if config.WriteTimeout == 0 {
|
|
||||||
config.WriteTimeout = 10 * time.Second
|
|
||||||
}
|
|
||||||
if config.IdleTimeout == 0 {
|
|
||||||
config.IdleTimeout = 120 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
gs := &GracefulServer{
|
|
||||||
server: &http.Server{
|
|
||||||
Addr: config.Addr,
|
|
||||||
Handler: config.Handler,
|
|
||||||
ReadTimeout: config.ReadTimeout,
|
|
||||||
WriteTimeout: config.WriteTimeout,
|
|
||||||
IdleTimeout: config.IdleTimeout,
|
|
||||||
},
|
|
||||||
shutdownTimeout: config.ShutdownTimeout,
|
|
||||||
drainTimeout: config.DrainTimeout,
|
|
||||||
shutdownComplete: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
return gs
|
|
||||||
}
|
|
||||||
|
|
||||||
// TrackRequestsMiddleware tracks in-flight requests and blocks new requests during shutdown
|
|
||||||
func (gs *GracefulServer) TrackRequestsMiddleware(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Check if shutting down
|
|
||||||
if gs.isShuttingDown.Load() {
|
|
||||||
http.Error(w, `{"error":"service_unavailable","message":"Server is shutting down"}`, http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increment in-flight counter
|
|
||||||
gs.inFlightRequests.Add(1)
|
|
||||||
defer gs.inFlightRequests.Add(-1)
|
|
||||||
|
|
||||||
// Serve the request
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenAndServe starts the server and handles graceful shutdown
|
|
||||||
func (gs *GracefulServer) ListenAndServe() error {
|
|
||||||
// Wrap handler with request tracking
|
|
||||||
gs.server.Handler = gs.TrackRequestsMiddleware(gs.server.Handler)
|
|
||||||
|
|
||||||
// Start server in goroutine
|
|
||||||
serverErr := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
logger.Info("Starting server on %s", gs.server.Addr)
|
|
||||||
if err := gs.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
serverErr <- err
|
|
||||||
}
|
|
||||||
close(serverErr)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for interrupt signal
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-serverErr:
|
|
||||||
return err
|
|
||||||
case sig := <-sigChan:
|
|
||||||
logger.Info("Received signal: %v, initiating graceful shutdown", sig)
|
|
||||||
return gs.Shutdown(context.Background())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown performs graceful shutdown with request draining
|
|
||||||
func (gs *GracefulServer) Shutdown(ctx context.Context) error {
|
|
||||||
var shutdownErr error
|
|
||||||
|
|
||||||
gs.shutdownOnce.Do(func() {
|
|
||||||
logger.Info("Starting graceful shutdown...")
|
|
||||||
|
|
||||||
// Mark as shutting down (new requests will be rejected)
|
|
||||||
gs.isShuttingDown.Store(true)
|
|
||||||
|
|
||||||
// Create context with timeout
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(ctx, gs.shutdownTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Wait for in-flight requests to complete (with drain timeout)
|
|
||||||
drainCtx, drainCancel := context.WithTimeout(shutdownCtx, gs.drainTimeout)
|
|
||||||
defer drainCancel()
|
|
||||||
|
|
||||||
shutdownErr = gs.drainRequests(drainCtx)
|
|
||||||
if shutdownErr != nil {
|
|
||||||
logger.Error("Error draining requests: %v", shutdownErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown the server
|
|
||||||
logger.Info("Shutting down HTTP server...")
|
|
||||||
if err := gs.server.Shutdown(shutdownCtx); err != nil {
|
|
||||||
logger.Error("Error shutting down server: %v", err)
|
|
||||||
if shutdownErr == nil {
|
|
||||||
shutdownErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Graceful shutdown complete")
|
|
||||||
close(gs.shutdownComplete)
|
|
||||||
})
|
|
||||||
|
|
||||||
return shutdownErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// drainRequests waits for in-flight requests to complete
|
|
||||||
func (gs *GracefulServer) drainRequests(ctx context.Context) error {
|
|
||||||
ticker := time.NewTicker(100 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
for {
|
|
||||||
inFlight := gs.inFlightRequests.Load()
|
|
||||||
|
|
||||||
if inFlight == 0 {
|
|
||||||
logger.Info("All requests drained in %v", time.Since(startTime))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Warn("Drain timeout exceeded with %d requests still in flight", inFlight)
|
|
||||||
return fmt.Errorf("drain timeout exceeded: %d requests still in flight", inFlight)
|
|
||||||
case <-ticker.C:
|
|
||||||
logger.Debug("Waiting for %d in-flight requests to complete...", inFlight)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InFlightRequests returns the current number of in-flight requests
|
|
||||||
func (gs *GracefulServer) InFlightRequests() int64 {
|
|
||||||
return gs.inFlightRequests.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsShuttingDown returns true if the server is shutting down
|
|
||||||
func (gs *GracefulServer) IsShuttingDown() bool {
|
|
||||||
return gs.isShuttingDown.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait blocks until shutdown is complete
|
|
||||||
func (gs *GracefulServer) Wait() {
|
|
||||||
<-gs.shutdownComplete
|
|
||||||
}
|
|
||||||
|
|
||||||
// HealthCheckHandler returns a handler that responds to health checks
|
|
||||||
// Returns 200 OK when healthy, 503 Service Unavailable when shutting down
|
|
||||||
func (gs *GracefulServer) HealthCheckHandler() http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if gs.IsShuttingDown() {
|
|
||||||
http.Error(w, `{"status":"shutting_down"}`, http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, err := w.Write([]byte(`{"status":"healthy"}`))
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to write. %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadinessHandler returns a handler for readiness checks
|
|
||||||
// Includes in-flight request count
|
|
||||||
func (gs *GracefulServer) ReadinessHandler() http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if gs.IsShuttingDown() {
|
|
||||||
http.Error(w, `{"ready":false,"reason":"shutting_down"}`, http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
inFlight := gs.InFlightRequests()
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
fmt.Fprintf(w, `{"ready":true,"in_flight_requests":%d}`, inFlight)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShutdownCallback is a function called during shutdown
|
|
||||||
type ShutdownCallback func(context.Context) error
|
|
||||||
|
|
||||||
// shutdownCallbacks stores registered shutdown callbacks
|
|
||||||
var (
|
|
||||||
shutdownCallbacks []ShutdownCallback
|
|
||||||
shutdownCallbacksMu sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
// RegisterShutdownCallback registers a callback to be called during shutdown
|
|
||||||
// Useful for cleanup tasks like closing database connections, flushing metrics, etc.
|
|
||||||
func RegisterShutdownCallback(cb ShutdownCallback) {
|
|
||||||
shutdownCallbacksMu.Lock()
|
|
||||||
defer shutdownCallbacksMu.Unlock()
|
|
||||||
shutdownCallbacks = append(shutdownCallbacks, cb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeShutdownCallbacks runs all registered shutdown callbacks
|
|
||||||
func executeShutdownCallbacks(ctx context.Context) error {
|
|
||||||
shutdownCallbacksMu.Lock()
|
|
||||||
callbacks := make([]ShutdownCallback, len(shutdownCallbacks))
|
|
||||||
copy(callbacks, shutdownCallbacks)
|
|
||||||
shutdownCallbacksMu.Unlock()
|
|
||||||
|
|
||||||
var errors []error
|
|
||||||
for i, cb := range callbacks {
|
|
||||||
logger.Debug("Executing shutdown callback %d/%d", i+1, len(callbacks))
|
|
||||||
if err := cb(ctx); err != nil {
|
|
||||||
logger.Error("Shutdown callback %d failed: %v", i+1, err)
|
|
||||||
errors = append(errors, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
return fmt.Errorf("shutdown callbacks failed: %v", errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShutdownWithCallbacks performs shutdown and executes all registered callbacks
|
|
||||||
func (gs *GracefulServer) ShutdownWithCallbacks(ctx context.Context) error {
|
|
||||||
// Execute callbacks first
|
|
||||||
if err := executeShutdownCallbacks(ctx); err != nil {
|
|
||||||
logger.Error("Error executing shutdown callbacks: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then shutdown the server
|
|
||||||
return gs.Shutdown(ctx)
|
|
||||||
}
|
|
||||||
@@ -1,231 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGracefulServerTrackRequests(t *testing.T) {
|
|
||||||
srv := NewGracefulServer(Config{
|
|
||||||
Addr: ":0",
|
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
|
|
||||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
|
||||||
|
|
||||||
// Start some requests
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait a bit for requests to start
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
|
|
||||||
// Check in-flight count
|
|
||||||
inFlight := srv.InFlightRequests()
|
|
||||||
if inFlight == 0 {
|
|
||||||
t.Error("Should have in-flight requests")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for all requests to complete
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// Check that counter is back to zero
|
|
||||||
inFlight = srv.InFlightRequests()
|
|
||||||
if inFlight != 0 {
|
|
||||||
t.Errorf("In-flight requests should be 0, got %d", inFlight)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGracefulServerRejectsRequestsDuringShutdown(t *testing.T) {
|
|
||||||
srv := NewGracefulServer(Config{
|
|
||||||
Addr: ":0",
|
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
|
|
||||||
handler := srv.TrackRequestsMiddleware(srv.server.Handler)
|
|
||||||
|
|
||||||
// Mark as shutting down
|
|
||||||
srv.isShuttingDown.Store(true)
|
|
||||||
|
|
||||||
// Try to make a request
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
// Should get 503
|
|
||||||
if w.Code != http.StatusServiceUnavailable {
|
|
||||||
t.Errorf("Expected 503, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthCheckHandler(t *testing.T) {
|
|
||||||
srv := NewGracefulServer(Config{
|
|
||||||
Addr: ":0",
|
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
|
||||||
})
|
|
||||||
|
|
||||||
handler := srv.HealthCheckHandler()
|
|
||||||
|
|
||||||
// Healthy
|
|
||||||
t.Run("Healthy", func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/health", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if w.Body.String() != `{"status":"healthy"}` {
|
|
||||||
t.Errorf("Unexpected body: %s", w.Body.String())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Shutting down
|
|
||||||
t.Run("ShuttingDown", func(t *testing.T) {
|
|
||||||
srv.isShuttingDown.Store(true)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/health", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusServiceUnavailable {
|
|
||||||
t.Errorf("Expected 503, got %d", w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadinessHandler(t *testing.T) {
|
|
||||||
srv := NewGracefulServer(Config{
|
|
||||||
Addr: ":0",
|
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
|
||||||
})
|
|
||||||
|
|
||||||
handler := srv.ReadinessHandler()
|
|
||||||
|
|
||||||
// Ready with no in-flight requests
|
|
||||||
t.Run("Ready", func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/ready", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected 200, got %d", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
body := w.Body.String()
|
|
||||||
if body != `{"ready":true,"in_flight_requests":0}` {
|
|
||||||
t.Errorf("Unexpected body: %s", body)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Not ready during shutdown
|
|
||||||
t.Run("NotReady", func(t *testing.T) {
|
|
||||||
srv.isShuttingDown.Store(true)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/ready", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusServiceUnavailable {
|
|
||||||
t.Errorf("Expected 503, got %d", w.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShutdownCallbacks(t *testing.T) {
|
|
||||||
callbackExecuted := false
|
|
||||||
|
|
||||||
RegisterShutdownCallback(func(ctx context.Context) error {
|
|
||||||
callbackExecuted = true
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
err := executeShutdownCallbacks(ctx)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("executeShutdownCallbacks() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !callbackExecuted {
|
|
||||||
t.Error("Shutdown callback was not executed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset for other tests
|
|
||||||
shutdownCallbacks = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDrainRequests(t *testing.T) {
|
|
||||||
srv := NewGracefulServer(Config{
|
|
||||||
Addr: ":0",
|
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
|
||||||
DrainTimeout: 1 * time.Second,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Simulate in-flight requests
|
|
||||||
srv.inFlightRequests.Add(3)
|
|
||||||
|
|
||||||
// Start draining in background
|
|
||||||
go func() {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
// Simulate requests completing
|
|
||||||
srv.inFlightRequests.Add(-3)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := srv.drainRequests(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("drainRequests() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if srv.InFlightRequests() != 0 {
|
|
||||||
t.Errorf("In-flight requests should be 0, got %d", srv.InFlightRequests())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDrainRequestsTimeout(t *testing.T) {
|
|
||||||
srv := NewGracefulServer(Config{
|
|
||||||
Addr: ":0",
|
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
|
||||||
DrainTimeout: 100 * time.Millisecond,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Simulate in-flight requests that don't complete
|
|
||||||
srv.inFlightRequests.Add(5)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := srv.drainRequests(ctx)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("drainRequests() should timeout with error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
srv.inFlightRequests.Add(-5)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetClientIP(t *testing.T) {
|
|
||||||
// This test is in ratelimit_test.go since getClientIP is used by rate limiter
|
|
||||||
// Including here for completeness of server tests
|
|
||||||
}
|
|
||||||
190
pkg/server/tls.go
Normal file
190
pkg/server/tls.go
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateSelfSignedCert generates a self-signed certificate for the given host.
|
||||||
|
// Returns the certificate and private key in PEM format.
|
||||||
|
func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) {
|
||||||
|
// Generate private key
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate template
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
|
||||||
|
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"ResolveSpec Self-Signed"},
|
||||||
|
CommonName: host,
|
||||||
|
},
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add host as DNS name or IP address
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
template.IPAddresses = []net.IP{ip}
|
||||||
|
} else {
|
||||||
|
template.DNSNames = []string{host}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode certificate to PEM
|
||||||
|
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||||
|
|
||||||
|
// Encode private key to PEM
|
||||||
|
privBytes, err := x509.MarshalECPrivateKey(priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
|
||||||
|
}
|
||||||
|
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
|
||||||
|
|
||||||
|
return certPEM, keyPEM, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveCertToTempFiles saves certificate and key PEM data to temporary files.
|
||||||
|
// Returns the file paths for the certificate and key.
|
||||||
|
func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile string, err error) {
|
||||||
|
// Create temporary directory
|
||||||
|
tmpDir, err := os.MkdirTemp("", "resolvespec-certs-*")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certFile = filepath.Join(tmpDir, "cert.pem")
|
||||||
|
keyFile = filepath.Join(tmpDir, "key.pem")
|
||||||
|
|
||||||
|
// Write certificate
|
||||||
|
if err := os.WriteFile(certFile, certPEM, 0600); err != nil {
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
return "", "", fmt.Errorf("failed to write certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write key
|
||||||
|
if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil {
|
||||||
|
os.RemoveAll(tmpDir)
|
||||||
|
return "", "", fmt.Errorf("failed to write private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return certFile, keyFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupAutoTLS configures automatic TLS certificate management using Let's Encrypt.
|
||||||
|
// Returns a TLS config that can be used with http.Server.
|
||||||
|
func setupAutoTLS(domains []string, email, cacheDir string) (*tls.Config, error) {
|
||||||
|
if len(domains) == 0 {
|
||||||
|
return nil, fmt.Errorf("at least one domain must be specified for AutoTLS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default cache directory
|
||||||
|
if cacheDir == "" {
|
||||||
|
cacheDir = "./certs-cache"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create cache directory if it doesn't exist
|
||||||
|
if err := os.MkdirAll(cacheDir, 0700); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create certificate cache directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create autocert manager
|
||||||
|
m := &autocert.Manager{
|
||||||
|
Prompt: autocert.AcceptTOS,
|
||||||
|
Cache: autocert.DirCache(cacheDir),
|
||||||
|
HostPolicy: autocert.HostWhitelist(domains...),
|
||||||
|
Email: email,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS config
|
||||||
|
tlsConfig := m.TLSConfig()
|
||||||
|
tlsConfig.MinVersion = tls.VersionTLS12
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureTLS configures TLS for the server based on the provided configuration.
|
||||||
|
// Returns the TLS config and certificate/key file paths (if applicable).
|
||||||
|
func configureTLS(cfg Config) (*tls.Config, string, string, error) {
|
||||||
|
// Option 1: Certificate files provided
|
||||||
|
if cfg.SSLCert != "" && cfg.SSLKey != "" {
|
||||||
|
// Validate that files exist
|
||||||
|
if _, err := os.Stat(cfg.SSLCert); os.IsNotExist(err) {
|
||||||
|
return nil, "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(cfg.SSLKey); os.IsNotExist(err) {
|
||||||
|
return nil, "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return basic TLS config - cert/key will be loaded by ListenAndServeTLS
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
return tlsConfig, cfg.SSLCert, cfg.SSLKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option 2: Auto TLS (Let's Encrypt)
|
||||||
|
if cfg.AutoTLS {
|
||||||
|
tlsConfig, err := setupAutoTLS(cfg.AutoTLSDomains, cfg.AutoTLSEmail, cfg.AutoTLSCacheDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to setup AutoTLS: %w", err)
|
||||||
|
}
|
||||||
|
return tlsConfig, "", "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option 3: Self-signed certificate
|
||||||
|
if cfg.SelfSignedSSL {
|
||||||
|
host := cfg.Host
|
||||||
|
if host == "" || host == "0.0.0.0" {
|
||||||
|
host = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM, keyPEM, err := generateSelfSignedCert(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certFile, keyFile, err := saveCertToTempFiles(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to save self-signed certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
return tlsConfig, certFile, keyFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, "", "", nil
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
// Package common provides nullable SQL types with automatic casting and conversion methods.
|
// Package spectypes provides nullable SQL types with automatic casting and conversion methods.
|
||||||
package common
|
package spectypes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package common
|
package spectypes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
@@ -465,7 +465,7 @@ func processRequest(ctx context.Context) {
|
|||||||
|
|
||||||
1. **Check collector is running:**
|
1. **Check collector is running:**
|
||||||
```bash
|
```bash
|
||||||
docker-compose ps
|
podman compose ps
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Verify endpoint:**
|
2. **Verify endpoint:**
|
||||||
@@ -476,7 +476,7 @@ func processRequest(ctx context.Context) {
|
|||||||
|
|
||||||
3. **Check logs:**
|
3. **Check logs:**
|
||||||
```bash
|
```bash
|
||||||
docker-compose logs otel-collector
|
podman compose logs otel-collector
|
||||||
```
|
```
|
||||||
|
|
||||||
### Disable Tracing
|
### Disable Tracing
|
||||||
|
|||||||
@@ -14,33 +14,33 @@ NC='\033[0m' # No Color
|
|||||||
|
|
||||||
echo -e "${GREEN}=== ResolveSpec Integration Tests ===${NC}\n"
|
echo -e "${GREEN}=== ResolveSpec Integration Tests ===${NC}\n"
|
||||||
|
|
||||||
# Check if docker-compose is available
|
# Check if podman compose is available
|
||||||
if ! command -v docker-compose &> /dev/null; then
|
if ! command -v podman &> /dev/null; then
|
||||||
echo -e "${RED}Error: docker-compose is not installed${NC}"
|
echo -e "${RED}Error: podman is not installed${NC}"
|
||||||
echo "Please install docker-compose or run PostgreSQL manually"
|
echo "Please install podman or run PostgreSQL manually"
|
||||||
echo "See INTEGRATION_TESTS.md for details"
|
echo "See INTEGRATION_TESTS.md for details"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Clean up any existing containers and networks from previous runs
|
# Clean up any existing containers and networks from previous runs
|
||||||
echo -e "${YELLOW}Cleaning up existing containers and networks...${NC}"
|
echo -e "${YELLOW}Cleaning up existing containers and networks...${NC}"
|
||||||
docker-compose down -v 2>/dev/null || true
|
podman compose down -v 2>/dev/null || true
|
||||||
|
|
||||||
# Start PostgreSQL
|
# Start PostgreSQL
|
||||||
echo -e "${YELLOW}Starting PostgreSQL...${NC}"
|
echo -e "${YELLOW}Starting PostgreSQL...${NC}"
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
|
|
||||||
# Wait for PostgreSQL to be ready
|
# Wait for PostgreSQL to be ready
|
||||||
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
|
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
|
||||||
max_attempts=30
|
max_attempts=30
|
||||||
attempt=0
|
attempt=0
|
||||||
|
|
||||||
while ! docker-compose exec -T postgres-test pg_isready -U postgres > /dev/null 2>&1; do
|
while ! podman compose exec -T postgres-test pg_isready -U postgres > /dev/null 2>&1; do
|
||||||
attempt=$((attempt + 1))
|
attempt=$((attempt + 1))
|
||||||
if [ $attempt -ge $max_attempts ]; then
|
if [ $attempt -ge $max_attempts ]; then
|
||||||
echo -e "${RED}Error: PostgreSQL failed to start after ${max_attempts} seconds${NC}"
|
echo -e "${RED}Error: PostgreSQL failed to start after ${max_attempts} seconds${NC}"
|
||||||
docker-compose logs postgres-test
|
podman compose logs postgres-test
|
||||||
docker-compose down
|
podman compose down
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
sleep 1
|
sleep 1
|
||||||
@@ -51,8 +51,8 @@ echo -e "\n${GREEN}PostgreSQL is ready!${NC}\n"
|
|||||||
|
|
||||||
# Create test databases
|
# Create test databases
|
||||||
echo -e "${YELLOW}Creating test databases...${NC}"
|
echo -e "${YELLOW}Creating test databases...${NC}"
|
||||||
docker-compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE resolvespec_test;" 2>/dev/null || echo " resolvespec_test already exists"
|
podman compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE resolvespec_test;" 2>/dev/null || echo " resolvespec_test already exists"
|
||||||
docker-compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE restheadspec_test;" 2>/dev/null || echo " restheadspec_test already exists"
|
podman compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE restheadspec_test;" 2>/dev/null || echo " restheadspec_test already exists"
|
||||||
echo -e "${GREEN}Test databases ready!${NC}\n"
|
echo -e "${GREEN}Test databases ready!${NC}\n"
|
||||||
|
|
||||||
# Determine which tests to run
|
# Determine which tests to run
|
||||||
@@ -79,6 +79,6 @@ fi
|
|||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
echo -e "\n${YELLOW}Stopping PostgreSQL...${NC}"
|
echo -e "\n${YELLOW}Stopping PostgreSQL...${NC}"
|
||||||
docker-compose down
|
podman compose down
|
||||||
|
|
||||||
exit $EXIT_CODE
|
exit $EXIT_CODE
|
||||||
|
|||||||
@@ -19,14 +19,14 @@ Integration tests validate the full functionality of both `pkg/resolvespec` and
|
|||||||
|
|
||||||
- Go 1.19 or later
|
- Go 1.19 or later
|
||||||
- PostgreSQL 12 or later
|
- PostgreSQL 12 or later
|
||||||
- Docker and Docker Compose (optional, for easy setup)
|
- Podman and Podman Compose (optional, for easy setup)
|
||||||
|
|
||||||
## Quick Start with Docker
|
## Quick Start with Podman
|
||||||
|
|
||||||
### 1. Start PostgreSQL with Docker Compose
|
### 1. Start PostgreSQL with Podman Compose
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
```
|
```
|
||||||
|
|
||||||
This starts a PostgreSQL container with the following default settings:
|
This starts a PostgreSQL container with the following default settings:
|
||||||
@@ -52,7 +52,7 @@ go test -tags=integration ./pkg/restheadspec -v
|
|||||||
### 3. Stop PostgreSQL
|
### 3. Stop PostgreSQL
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose down
|
podman compose down
|
||||||
```
|
```
|
||||||
|
|
||||||
## Manual PostgreSQL Setup
|
## Manual PostgreSQL Setup
|
||||||
@@ -161,7 +161,7 @@ If you see "connection refused" errors:
|
|||||||
|
|
||||||
1. Check that PostgreSQL is running:
|
1. Check that PostgreSQL is running:
|
||||||
```bash
|
```bash
|
||||||
docker-compose ps
|
podman compose ps
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Verify connection parameters:
|
2. Verify connection parameters:
|
||||||
@@ -194,10 +194,10 @@ Each test automatically cleans up its data using `TRUNCATE`. If you need a fresh
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Stop and remove containers (removes data)
|
# Stop and remove containers (removes data)
|
||||||
docker-compose down -v
|
podman compose down -v
|
||||||
|
|
||||||
# Restart
|
# Restart
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
```
|
```
|
||||||
|
|
||||||
## CI/CD Integration
|
## CI/CD Integration
|
||||||
|
|||||||
@@ -119,13 +119,13 @@ Integration tests require a PostgreSQL database and use the `// +build integrati
|
|||||||
- PostgreSQL 12+ installed and running
|
- PostgreSQL 12+ installed and running
|
||||||
- Create test databases manually (see below)
|
- Create test databases manually (see below)
|
||||||
|
|
||||||
### Setup with Docker
|
### Setup with Podman
|
||||||
|
|
||||||
1. **Start PostgreSQL**:
|
1. **Start PostgreSQL**:
|
||||||
```bash
|
```bash
|
||||||
make docker-up
|
make docker-up
|
||||||
# or
|
# or
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Run Tests**:
|
2. **Run Tests**:
|
||||||
@@ -141,10 +141,10 @@ Integration tests require a PostgreSQL database and use the `// +build integrati
|
|||||||
```bash
|
```bash
|
||||||
make docker-down
|
make docker-down
|
||||||
# or
|
# or
|
||||||
docker-compose down
|
podman compose down
|
||||||
```
|
```
|
||||||
|
|
||||||
### Setup without Docker
|
### Setup without Podman
|
||||||
|
|
||||||
1. **Create Databases**:
|
1. **Create Databases**:
|
||||||
```sql
|
```sql
|
||||||
@@ -289,8 +289,8 @@ go test -tags=integration ./pkg/resolvespec -v
|
|||||||
**Problem**: "connection refused" or "database does not exist"
|
**Problem**: "connection refused" or "database does not exist"
|
||||||
|
|
||||||
**Solutions**:
|
**Solutions**:
|
||||||
1. Check PostgreSQL is running: `docker-compose ps`
|
1. Check PostgreSQL is running: `podman compose ps`
|
||||||
2. Verify databases exist: `docker-compose exec postgres-test psql -U postgres -l`
|
2. Verify databases exist: `podman compose exec postgres-test psql -U postgres -l`
|
||||||
3. Check environment variable: `echo $TEST_DATABASE_URL`
|
3. Check environment variable: `echo $TEST_DATABASE_URL`
|
||||||
4. Recreate databases: `make clean && make docker-up`
|
4. Recreate databases: `make clean && make docker-up`
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user