Compare commits

...

10 Commits

Author SHA1 Message Date
Hein
90df4a157c Socket spec tests 2025-12-23 17:27:48 +02:00
Hein
2dd404af96 Updated to websockspec 2025-12-23 17:27:29 +02:00
Hein
17c472b206 Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into websocketspec 2025-12-23 15:23:36 +02:00
Hein
ed67caf055 fix: reasheadspec customsql calls AddTablePrefixToColumns
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m42s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m6s
Build , Vet Test, and Lint / Lint Code (push) Failing after -25m37s
Build , Vet Test, and Lint / Build (push) Successful in -25m35s
Tests / Unit Tests (push) Failing after -25m50s
Tests / Integration Tests (push) Failing after -25m59s
2025-12-23 14:17:02 +02:00
Hein
63ed62a9a3 fix: Stupid logic error.
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -25m39s
Build , Vet Test, and Lint / Build (push) Successful in -25m47s
Build , Vet Test, and Lint / Lint Code (push) Successful in -25m6s
Tests / Unit Tests (push) Failing after -26m5s
Tests / Integration Tests (push) Failing after -26m5s
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:52:34 +02:00
Hein
0525323a47 Fixed tests failing due to reponse header status
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:50:16 +02:00
Hein Puth (Warkanum)
c3443f702e Merge pull request #4 from bitechdev/fix-dockers
Fixed Attempt to Fix Docker / Podman
2025-12-19 16:42:38 +02:00
Hein
45c463c117 Fixed Attempt to Fix Docker / Podman
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:42:01 +02:00
Hein
84d673ce14 Added OpenAPI UI Routes
Co-authored-by: IvanX006 <ivan@bitechsystems.co.za>
Co-authored-by: Warkanum <HEIN.PUTH@GMAIL.COM>
Co-authored-by: Hein <hein@bitechsystems.co.za>
2025-12-19 16:32:14 +02:00
Hein
1b2b0d8f0b Prototype for websockspec 2025-12-12 16:14:47 +02:00
32 changed files with 8418 additions and 77 deletions

View File

@@ -16,7 +16,7 @@ test: test-unit test-integration
# Start PostgreSQL for integration tests # Start PostgreSQL for integration tests
docker-up: docker-up:
@echo "Starting PostgreSQL container..." @echo "Starting PostgreSQL container..."
@docker-compose up -d postgres-test @podman compose up -d postgres-test
@echo "Waiting for PostgreSQL to be ready..." @echo "Waiting for PostgreSQL to be ready..."
@sleep 5 @sleep 5
@echo "PostgreSQL is ready!" @echo "PostgreSQL is ready!"
@@ -24,12 +24,12 @@ docker-up:
# Stop PostgreSQL container # Stop PostgreSQL container
docker-down: docker-down:
@echo "Stopping PostgreSQL container..." @echo "Stopping PostgreSQL container..."
@docker-compose down @podman compose down
# Clean up Docker volumes and test data # Clean up Docker volumes and test data
clean: clean:
@echo "Cleaning up..." @echo "Cleaning up..."
@docker-compose down -v @podman compose down -v
@echo "Cleanup complete!" @echo "Cleanup complete!"
# Run integration tests with Docker (full workflow) # Run integration tests with Docker (full workflow)

2
go.mod
View File

@@ -11,6 +11,7 @@ require (
github.com/glebarez/sqlite v1.11.0 github.com/glebarez/sqlite v1.11.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1 github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.6.0 github.com/jackc/pgx/v5 v5.6.0
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1 github.com/redis/go-redis/v9 v9.17.1
@@ -101,6 +102,7 @@ require (
github.com/spf13/afero v1.15.0 // indirect github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect github.com/spf13/pflag v1.0.10 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect

2
go.sum
View File

@@ -85,6 +85,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=

View File

@@ -208,21 +208,9 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
} }
} }
} }
} else if tableName != "" && !hasTablePrefix(condToCheck) {
// If tableName is provided and the condition DOESN'T have a table prefix,
// qualify unambiguous column references to prevent "ambiguous column" errors
// when there are multiple joins on the same table (e.g., recursive preloads)
columnName := extractUnqualifiedColumnName(condToCheck)
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
// Qualify the column with the table name
// Be careful to only replace the column name, not other occurrences of the string
oldRef := columnName
newRef := tableName + "." + columnName
// Use word boundary matching to avoid replacing partial matches
cond = qualifyColumnInCondition(cond, oldRef, newRef)
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
}
} }
// Note: We no longer add prefixes to unqualified columns here.
// Use AddTablePrefixToColumns() separately if you need to add prefixes.
validConditions = append(validConditions, cond) validConditions = append(validConditions, cond)
} }
@@ -633,3 +621,145 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool {
} }
return validColumns[strings.ToLower(columnName)] return validColumns[strings.ToLower(columnName)]
} }
// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause.
// This function only prefixes simple column references and skips:
// - Columns already having a table prefix (containing a dot)
// - Columns inside function calls or expressions (inside parentheses)
// - Columns inside subqueries
// - Columns that don't exist in the table (validation via model registry)
//
// Examples:
// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table)
// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function)
// - "users.status = 'active'" -> unchanged (already has prefix)
// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK)
// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table)
//
// Parameters:
// - where: The WHERE clause to process
// - tableName: The table name to use as prefix
//
// Returns:
// - The WHERE clause with table prefixes added to appropriate and valid columns
func AddTablePrefixToColumns(where string, tableName string) string {
if where == "" || tableName == "" {
return where
}
where = strings.TrimSpace(where)
// Get valid columns from the model registry for validation
validColumns := getValidColumnsForTable(tableName)
// Split by AND to handle multiple conditions (parenthesis-aware)
conditions := splitByAND(where)
prefixedConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Process this condition to add table prefix if appropriate
processedCond := addPrefixToSingleCondition(cond, tableName, validColumns)
prefixedConditions = append(prefixedConditions, processedCond)
}
if len(prefixedConditions) == 0 {
return ""
}
return strings.Join(prefixedConditions, " AND ")
}
// addPrefixToSingleCondition adds table prefix to a single condition if appropriate
// Returns the condition unchanged if:
// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.)
// - The column reference is inside a function call
// - The column already has a table prefix
// - No valid column reference is found
// - The column doesn't exist in the table (when validColumns is provided)
func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string {
// Strip outer grouping parentheses to get to the actual condition
strippedCond := stripOuterParentheses(cond)
// Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.)
if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) {
logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond)
return cond
}
// Extract the left side of the comparison (before the operator)
columnRef := extractLeftSideOfComparison(strippedCond)
if columnRef == "" {
return cond
}
// Skip if it already has a prefix (contains a dot)
if strings.Contains(columnRef, ".") {
logger.Debug("Skipping column '%s' - already has table prefix", columnRef)
return cond
}
// Skip if it's a function call or expression (contains parentheses)
if strings.Contains(columnRef, "(") {
logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef)
return cond
}
// Validate that the column exists in the table (if we have column info)
if !isValidColumn(columnRef, validColumns) {
logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName)
return cond
}
// It's a simple unqualified column reference that exists in the table - add the table prefix
newRef := tableName + "." + columnRef
result := qualifyColumnInCondition(cond, columnRef, newRef)
logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef)
return result
}
// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition.
// This is used to identify the column reference that may need a table prefix.
//
// Examples:
// - "status = 'active'" returns "status"
// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')"
// - "priority > 5" returns "priority"
//
// Returns empty string if no operator is found.
func extractLeftSideOfComparison(cond string) string {
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
// Find the first operator outside of parentheses and quotes
minIdx := -1
for _, op := range operators {
idx := findOperatorOutsideParentheses(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
if minIdx > 0 {
leftSide := strings.TrimSpace(cond[:minIdx])
// Remove any surrounding quotes
leftSide = strings.Trim(leftSide, "`\"'")
return leftSide
}
// No operator found - might be a boolean column
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef := strings.Trim(parts[0], "`\"'")
// Make sure it's not a SQL keyword
if !IsSQLKeyword(strings.ToLower(columnRef)) {
return columnRef
}
}
return ""
}

View File

@@ -273,25 +273,151 @@ handler.SetOpenAPIGenerator(func() (string, error) {
}) })
``` ```
## Using with Swagger UI ## Using the Built-in UI Handler
You can serve the generated OpenAPI spec with Swagger UI: The package includes a built-in UI handler that serves popular OpenAPI visualization tools. No need to download or manage static files - everything is served from CDN.
### Quick Start
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/openapi"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Setup your API routes and OpenAPI generator...
// (see examples above)
// Add the UI handler - defaults to Swagger UI
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
SpecURL: "/openapi",
Title: "My API Documentation",
})
// Now visit http://localhost:8080/docs
http.ListenAndServe(":8080", router)
}
```
### Supported UI Frameworks
The handler supports four popular OpenAPI UI frameworks:
#### 1. Swagger UI (Default)
The most widely used OpenAPI UI with excellent compatibility and features.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
Theme: "dark", // optional: "light" or "dark"
})
```
#### 2. RapiDoc
Modern, customizable, and feature-rich OpenAPI UI.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.RapiDoc,
Theme: "dark",
})
```
#### 3. Redoc
Clean, responsive documentation with great UX.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.Redoc,
})
```
#### 4. Scalar
Modern and sleek OpenAPI documentation.
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.Scalar,
Theme: "dark",
})
```
### Configuration Options
```go
type UIConfig struct {
UIType UIType // SwaggerUI, RapiDoc, Redoc, or Scalar
SpecURL string // URL to OpenAPI spec (default: "/openapi")
Title string // Page title (default: "API Documentation")
FaviconURL string // Custom favicon URL (optional)
CustomCSS string // Custom CSS to inject (optional)
Theme string // "light" or "dark" (support varies by UI)
}
```
### Custom Styling Example
```go
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
Title: "Acme Corp API",
CustomCSS: `
.swagger-ui .topbar {
background-color: #1976d2;
}
.swagger-ui .info .title {
color: #1976d2;
}
`,
})
```
### Using Multiple UIs
You can serve different UIs at different paths:
```go
// Swagger UI at /docs
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
UIType: openapi.SwaggerUI,
})
// Redoc at /redoc
openapi.SetupUIRoute(router, "/redoc", openapi.UIConfig{
UIType: openapi.Redoc,
})
// RapiDoc at /api-docs
openapi.SetupUIRoute(router, "/api-docs", openapi.UIConfig{
UIType: openapi.RapiDoc,
})
```
### Manual Handler Usage
If you need more control, use the handler directly:
```go
handler := openapi.UIHandler(openapi.UIConfig{
UIType: openapi.SwaggerUI,
SpecURL: "/api/openapi.json",
})
router.Handle("/documentation", handler)
```
## Using with External Swagger UI
Alternatively, you can use an external Swagger UI instance:
1. Get the spec from `/openapi` 1. Get the spec from `/openapi`
2. Load it in Swagger UI at `https://petstore.swagger.io/` 2. Load it in Swagger UI at `https://petstore.swagger.io/`
3. Or self-host Swagger UI and point it to your `/openapi` endpoint 3. Or self-host Swagger UI and point it to your `/openapi` endpoint
Example with self-hosted Swagger UI:
```go
// Serve Swagger UI static files
router.PathPrefix("/swagger/").Handler(
http.StripPrefix("/swagger/", http.FileServer(http.Dir("./swagger-ui"))),
)
// Configure Swagger UI to use /openapi
```
## Testing ## Testing
You can test the OpenAPI endpoint: You can test the OpenAPI endpoint:

View File

@@ -183,6 +183,69 @@ func ExampleWithFuncSpec() {
_ = generatorFunc _ = generatorFunc
} }
// ExampleWithUIHandler shows how to serve OpenAPI documentation with a web UI
func ExampleWithUIHandler(db *gorm.DB) {
// Create handler and configure OpenAPI generator
handler := restheadspec.NewHandlerWithGORM(db)
registry := modelregistry.NewModelRegistry()
handler.SetOpenAPIGenerator(func() (string, error) {
generator := NewGenerator(GeneratorConfig{
Title: "My API",
Description: "API documentation with interactive UI",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: registry,
IncludeRestheadSpec: true,
})
return generator.GenerateJSON()
})
// Setup routes
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler, nil)
// Add UI handlers for different frameworks
// Swagger UI at /docs (most popular)
SetupUIRoute(router, "/docs", UIConfig{
UIType: SwaggerUI,
SpecURL: "/openapi",
Title: "My API - Swagger UI",
Theme: "light",
})
// RapiDoc at /rapidoc (modern alternative)
SetupUIRoute(router, "/rapidoc", UIConfig{
UIType: RapiDoc,
SpecURL: "/openapi",
Title: "My API - RapiDoc",
})
// Redoc at /redoc (clean and responsive)
SetupUIRoute(router, "/redoc", UIConfig{
UIType: Redoc,
SpecURL: "/openapi",
Title: "My API - Redoc",
})
// Scalar at /scalar (modern and sleek)
SetupUIRoute(router, "/scalar", UIConfig{
UIType: Scalar,
SpecURL: "/openapi",
Title: "My API - Scalar",
Theme: "dark",
})
// Now you can access:
// http://localhost:8080/docs - Swagger UI
// http://localhost:8080/rapidoc - RapiDoc
// http://localhost:8080/redoc - Redoc
// http://localhost:8080/scalar - Scalar
// http://localhost:8080/openapi - Raw OpenAPI JSON
_ = router
}
// ExampleCustomization shows advanced customization options // ExampleCustomization shows advanced customization options
func ExampleCustomization() { func ExampleCustomization() {
// Create registry and register models with descriptions using struct tags // Create registry and register models with descriptions using struct tags

294
pkg/openapi/ui_handler.go Normal file
View File

@@ -0,0 +1,294 @@
package openapi
import (
"fmt"
"html/template"
"net/http"
"strings"
"github.com/gorilla/mux"
)
// UIType represents the type of OpenAPI UI to serve
type UIType string
const (
// SwaggerUI is the most popular OpenAPI UI
SwaggerUI UIType = "swagger-ui"
// RapiDoc is a modern, customizable OpenAPI UI
RapiDoc UIType = "rapidoc"
// Redoc is a clean, responsive OpenAPI UI
Redoc UIType = "redoc"
// Scalar is a modern and sleek OpenAPI UI
Scalar UIType = "scalar"
)
// UIConfig holds configuration for the OpenAPI UI handler
type UIConfig struct {
// UIType specifies which UI framework to use (default: SwaggerUI)
UIType UIType
// SpecURL is the URL to the OpenAPI spec JSON (default: "/openapi")
SpecURL string
// Title is the page title (default: "API Documentation")
Title string
// FaviconURL is the URL to the favicon (optional)
FaviconURL string
// CustomCSS allows injecting custom CSS (optional)
CustomCSS string
// Theme for the UI (light/dark, depends on UI type)
Theme string
}
// UIHandler creates an HTTP handler that serves an OpenAPI UI
func UIHandler(config UIConfig) http.HandlerFunc {
// Set defaults
if config.UIType == "" {
config.UIType = SwaggerUI
}
if config.SpecURL == "" {
config.SpecURL = "/openapi"
}
if config.Title == "" {
config.Title = "API Documentation"
}
if config.Theme == "" {
config.Theme = "light"
}
return func(w http.ResponseWriter, r *http.Request) {
var htmlContent string
var err error
switch config.UIType {
case SwaggerUI:
htmlContent, err = generateSwaggerUI(config)
case RapiDoc:
htmlContent, err = generateRapiDoc(config)
case Redoc:
htmlContent, err = generateRedoc(config)
case Scalar:
htmlContent, err = generateScalar(config)
default:
http.Error(w, "Unsupported UI type", http.StatusBadRequest)
return
}
if err != nil {
http.Error(w, fmt.Sprintf("Failed to generate UI: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, err = w.Write([]byte(htmlContent))
if err != nil {
http.Error(w, fmt.Sprintf("Failed to write response: %v", err), http.StatusInternalServerError)
return
}
}
}
// templateData wraps UIConfig to properly handle CSS in templates
type templateData struct {
UIConfig
SafeCustomCSS template.CSS
}
// generateSwaggerUI generates the HTML for Swagger UI
func generateSwaggerUI(config UIConfig) (string, error) {
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css">
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
<style>
html { box-sizing: border-box; overflow: -moz-scrollbars-vertical; overflow-y: scroll; }
*, *:before, *:after { box-sizing: inherit; }
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-standalone-preset.js"></script>
<script>
window.onload = function() {
const ui = SwaggerUIBundle({
url: "{{.SpecURL}}",
dom_id: '#swagger-ui',
deepLinking: true,
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIStandalonePreset
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
],
layout: "StandaloneLayout",
{{if eq .Theme "dark"}}
syntaxHighlight: {
activate: true,
theme: "monokai"
}
{{end}}
});
window.ui = ui;
};
</script>
</body>
</html>`
t, err := template.New("swagger").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// generateRapiDoc generates the HTML for RapiDoc
func generateRapiDoc(config UIConfig) (string, error) {
theme := "light"
if config.Theme == "dark" {
theme = "dark"
}
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
<script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
</head>
<body>
<rapi-doc
spec-url="{{.SpecURL}}"
theme="` + theme + `"
render-style="read"
show-header="true"
show-info="true"
allow-try="true"
allow-server-selection="true"
allow-authentication="true"
api-key-name="Authorization"
api-key-location="header"
></rapi-doc>
</body>
</html>`
t, err := template.New("rapidoc").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// generateRedoc generates the HTML for Redoc
func generateRedoc(config UIConfig) (string, error) {
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
<style>
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<redoc spec-url="{{.SpecURL}}" {{if eq .Theme "dark"}}theme='{"colors": {"primary": {"main": "#dd5522"}}}'{{end}}></redoc>
<script src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
</body>
</html>`
t, err := template.New("redoc").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// generateScalar generates the HTML for Scalar
func generateScalar(config UIConfig) (string, error) {
tmpl := `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
<style>
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<script id="api-reference" data-url="{{.SpecURL}}" {{if eq .Theme "dark"}}data-theme="dark"{{end}}></script>
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
</body>
</html>`
t, err := template.New("scalar").Parse(tmpl)
if err != nil {
return "", err
}
data := templateData{
UIConfig: config,
SafeCustomCSS: template.CSS(config.CustomCSS),
}
var buf strings.Builder
if err := t.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}
// SetupUIRoute adds the OpenAPI UI route to a mux router
// This is a convenience function for the most common use case
func SetupUIRoute(router *mux.Router, path string, config UIConfig) {
router.Handle(path, UIHandler(config))
}

View File

@@ -0,0 +1,308 @@
package openapi
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
)
func TestUIHandler_SwaggerUI(t *testing.T) {
config := UIConfig{
UIType: SwaggerUI,
SpecURL: "/openapi",
Title: "Test API Docs",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for Swagger UI specific content
if !strings.Contains(body, "swagger-ui") {
t.Error("Expected Swagger UI content")
}
if !strings.Contains(body, "SwaggerUIBundle") {
t.Error("Expected SwaggerUIBundle script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
if !strings.Contains(body, "swagger-ui-dist") {
t.Error("Expected Swagger UI CDN link")
}
}
func TestUIHandler_RapiDoc(t *testing.T) {
config := UIConfig{
UIType: RapiDoc,
SpecURL: "/api/spec",
Title: "RapiDoc Test",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for RapiDoc specific content
if !strings.Contains(body, "rapi-doc") {
t.Error("Expected rapi-doc element")
}
if !strings.Contains(body, "rapidoc-min.js") {
t.Error("Expected RapiDoc script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
}
func TestUIHandler_Redoc(t *testing.T) {
config := UIConfig{
UIType: Redoc,
SpecURL: "/spec.json",
Title: "Redoc Test",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for Redoc specific content
if !strings.Contains(body, "<redoc") {
t.Error("Expected redoc element")
}
if !strings.Contains(body, "redoc.standalone.js") {
t.Error("Expected Redoc script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
}
func TestUIHandler_Scalar(t *testing.T) {
config := UIConfig{
UIType: Scalar,
SpecURL: "/openapi.json",
Title: "Scalar Test",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Check for Scalar specific content
if !strings.Contains(body, "api-reference") {
t.Error("Expected api-reference element")
}
if !strings.Contains(body, "@scalar/api-reference") {
t.Error("Expected Scalar script")
}
if !strings.Contains(body, config.Title) {
t.Errorf("Expected title '%s' in HTML", config.Title)
}
if !strings.Contains(body, config.SpecURL) {
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
}
}
func TestUIHandler_DefaultValues(t *testing.T) {
// Test with empty config to check defaults
config := UIConfig{}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body := w.Body.String()
// Should default to Swagger UI
if !strings.Contains(body, "swagger-ui") {
t.Error("Expected default to Swagger UI")
}
// Should default to /openapi spec URL
if !strings.Contains(body, "/openapi") {
t.Error("Expected default spec URL '/openapi'")
}
// Should default to "API Documentation" title
if !strings.Contains(body, "API Documentation") {
t.Error("Expected default title 'API Documentation'")
}
}
func TestUIHandler_CustomCSS(t *testing.T) {
customCSS := ".custom-class { color: red; }"
config := UIConfig{
UIType: SwaggerUI,
CustomCSS: customCSS,
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
body := w.Body.String()
if !strings.Contains(body, customCSS) {
t.Errorf("Expected custom CSS to be included. Body:\n%s", body)
}
}
func TestUIHandler_Favicon(t *testing.T) {
faviconURL := "https://example.com/favicon.ico"
config := UIConfig{
UIType: SwaggerUI,
FaviconURL: faviconURL,
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
body := w.Body.String()
if !strings.Contains(body, faviconURL) {
t.Error("Expected favicon URL to be included")
}
}
func TestUIHandler_DarkTheme(t *testing.T) {
config := UIConfig{
UIType: SwaggerUI,
Theme: "dark",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
body := w.Body.String()
// SwaggerUI uses monokai theme for dark mode
if !strings.Contains(body, "monokai") {
t.Error("Expected dark theme configuration for Swagger UI")
}
}
func TestUIHandler_InvalidUIType(t *testing.T) {
config := UIConfig{
UIType: "invalid-ui-type",
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid UI type, got %d", resp.StatusCode)
}
}
func TestUIHandler_ContentType(t *testing.T) {
config := UIConfig{
UIType: SwaggerUI,
}
handler := UIHandler(config)
req := httptest.NewRequest("GET", "/docs", nil)
w := httptest.NewRecorder()
handler(w, req)
contentType := w.Header().Get("Content-Type")
if !strings.Contains(contentType, "text/html") {
t.Errorf("Expected Content-Type to contain 'text/html', got '%s'", contentType)
}
if !strings.Contains(contentType, "charset=utf-8") {
t.Errorf("Expected Content-Type to contain 'charset=utf-8', got '%s'", contentType)
}
}
func TestSetupUIRoute(t *testing.T) {
router := mux.NewRouter()
config := UIConfig{
UIType: SwaggerUI,
}
SetupUIRoute(router, "/api-docs", config)
// Test that the route was added and works
req := httptest.NewRequest("GET", "/api-docs", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Verify it returns HTML
body := w.Body.String()
if !strings.Contains(body, "swagger-ui") {
t.Error("Expected Swagger UI content")
}
}

View File

@@ -482,8 +482,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (AND condition) // Apply custom SQL WHERE clause (AND condition)
if options.CustomSQLWhere != "" { if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables // First add table prefixes to unqualified columns (but skip columns inside function calls)
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedWhere != "" { if sanitizedWhere != "" {
query = query.Where(sanitizedWhere) query = query.Where(sanitizedWhere)
} }
@@ -492,8 +494,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (OR condition) // Apply custom SQL WHERE clause (OR condition)
if options.CustomSQLOr != "" { if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedOr != "" { if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr) query = query.WhereOr(sanitizedOr)
} }

View File

@@ -1,3 +1,4 @@
//go:build integration
// +build integration // +build integration
package restheadspec package restheadspec
@@ -21,12 +22,12 @@ import (
// Test models // Test models
type TestUser struct { type TestUser struct {
ID uint `gorm:"primaryKey" json:"id"` ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"` Name string `gorm:"not null" json:"name"`
Email string `gorm:"uniqueIndex;not null" json:"email"` Email string `gorm:"uniqueIndex;not null" json:"email"`
Age int `json:"age"` Age int `json:"age"`
Active bool `gorm:"default:true" json:"active"` Active bool `gorm:"default:true" json:"active"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"` Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
} }
@@ -35,13 +36,13 @@ func (TestUser) TableName() string {
} }
type TestPost struct { type TestPost struct {
ID uint `gorm:"primaryKey" json:"id"` ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"not null" json:"user_id"` UserID uint `gorm:"not null" json:"user_id"`
Title string `gorm:"not null" json:"title"` Title string `gorm:"not null" json:"title"`
Content string `json:"content"` Content string `json:"content"`
Published bool `gorm:"default:false" json:"published"` Published bool `gorm:"default:false" json:"published"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"` User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"` Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
} }
@@ -54,7 +55,7 @@ type TestComment struct {
PostID uint `gorm:"not null" json:"post_id"` PostID uint `gorm:"not null" json:"post_id"`
Content string `gorm:"not null" json:"content"` Content string `gorm:"not null" json:"content"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"` Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
} }
func (TestComment) TableName() string { func (TestComment) TableName() string {
@@ -401,7 +402,7 @@ func TestIntegration_GetMetadata(t *testing.T) {
muxRouter.ServeHTTP(w, req) muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK { if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String()) t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
} }
@@ -492,7 +493,7 @@ func TestIntegration_QueryParamsOverHeaders(t *testing.T) {
muxRouter.ServeHTTP(w, req) muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK { if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
t.Errorf("Expected status 200, got %d", w.Code) t.Errorf("Expected status 200, got %d", w.Code)
} }

View File

@@ -465,7 +465,7 @@ func processRequest(ctx context.Context) {
1. **Check collector is running:** 1. **Check collector is running:**
```bash ```bash
docker-compose ps podman compose ps
``` ```
2. **Verify endpoint:** 2. **Verify endpoint:**
@@ -476,7 +476,7 @@ func processRequest(ctx context.Context) {
3. **Check logs:** 3. **Check logs:**
```bash ```bash
docker-compose logs otel-collector podman compose logs otel-collector
``` ```
### Disable Tracing ### Disable Tracing

726
pkg/websocketspec/README.md Normal file
View File

@@ -0,0 +1,726 @@
# WebSocketSpec - Real-Time WebSocket API Framework
WebSocketSpec provides a WebSocket-based API specification for real-time, bidirectional communication with full CRUD operations, subscriptions, and lifecycle hooks.
## Table of Contents
- [Features](#features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Message Protocol](#message-protocol)
- [CRUD Operations](#crud-operations)
- [Subscriptions](#subscriptions)
- [Lifecycle Hooks](#lifecycle-hooks)
- [Client Examples](#client-examples)
- [Authentication](#authentication)
- [Error Handling](#error-handling)
- [Best Practices](#best-practices)
## Features
- **Real-Time Bidirectional Communication**: WebSocket-based persistent connections
- **Full CRUD Operations**: Create, Read, Update, Delete with rich query options
- **Real-Time Subscriptions**: Subscribe to entity changes with filter support
- **Automatic Notifications**: Server pushes updates to subscribed clients
- **Lifecycle Hooks**: Before/after hooks for all operations
- **Database Agnostic**: Works with GORM and Bun ORM through adapters
- **Connection Management**: Automatic connection tracking and cleanup
- **Request/Response Correlation**: Message IDs for tracking requests
- **Filter & Sort**: Advanced filtering, sorting, pagination, and preloading
## Installation
```bash
go get github.com/bitechdev/ResolveSpec
```
## Quick Start
### Server Setup
```go
package main
import (
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
func main() {
// Connect to database
db, _ := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
// Create WebSocket handler
handler := websocketspec.NewHandlerWithGORM(db)
// Register models
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Setup WebSocket endpoint
http.HandleFunc("/ws", handler.HandleWebSocket)
// Start server
http.ListenAndServe(":8080", nil)
}
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
}
type Post struct {
ID uint `json:"id" gorm:"primaryKey"`
Title string `json:"title"`
Content string `json:"content"`
UserID uint `json:"user_id"`
}
```
### Client Setup (JavaScript)
```javascript
const ws = new WebSocket("ws://localhost:8080/ws");
ws.onopen = () => {
console.log("Connected to WebSocket");
};
ws.onmessage = (event) => {
const message = JSON.parse(event.data);
console.log("Received:", message);
};
ws.onerror = (error) => {
console.error("WebSocket error:", error);
};
```
## Message Protocol
All messages are JSON-encoded with the following structure:
```typescript
interface Message {
id: string; // Unique message ID for correlation
type: "request" | "response" | "notification" | "subscription";
operation?: "read" | "create" | "update" | "delete" | "subscribe" | "unsubscribe" | "meta";
schema?: string; // Database schema
entity: string; // Table/model name
record_id?: string; // For single-record operations
data?: any; // Request/response payload
options?: QueryOptions; // Filters, sorting, pagination
subscription_id?: string; // For subscription messages
success?: boolean; // Response success indicator
error?: ErrorInfo; // Error details
metadata?: Record<string, any>; // Additional metadata
timestamp?: string; // Message timestamp
}
interface QueryOptions {
filters?: FilterOption[];
columns?: string[];
preload?: PreloadOption[];
sort?: SortOption[];
limit?: number;
offset?: number;
}
```
## CRUD Operations
### CREATE - Create New Records
**Request:**
```json
{
"id": "msg-1",
"type": "request",
"operation": "create",
"schema": "public",
"entity": "users",
"data": {
"name": "John Doe",
"email": "john@example.com",
"status": "active"
}
}
```
**Response:**
```json
{
"id": "msg-1",
"type": "response",
"success": true,
"data": {
"id": 123,
"name": "John Doe",
"email": "john@example.com",
"status": "active"
},
"timestamp": "2025-12-12T10:30:00Z"
}
```
### READ - Query Records
**Read Multiple Records:**
```json
{
"id": "msg-2",
"type": "request",
"operation": "read",
"schema": "public",
"entity": "users",
"options": {
"filters": [
{"column": "status", "operator": "eq", "value": "active"}
],
"columns": ["id", "name", "email"],
"sort": [
{"column": "name", "direction": "asc"}
],
"limit": 10,
"offset": 0
}
}
```
**Read Single Record:**
```json
{
"id": "msg-3",
"type": "request",
"operation": "read",
"schema": "public",
"entity": "users",
"record_id": "123"
}
```
**Response:**
```json
{
"id": "msg-2",
"type": "response",
"success": true,
"data": [
{"id": 1, "name": "Alice", "email": "alice@example.com"},
{"id": 2, "name": "Bob", "email": "bob@example.com"}
],
"metadata": {
"total": 50,
"count": 2
},
"timestamp": "2025-12-12T10:30:00Z"
}
```
### UPDATE - Update Records
```json
{
"id": "msg-4",
"type": "request",
"operation": "update",
"schema": "public",
"entity": "users",
"record_id": "123",
"data": {
"name": "John Updated",
"email": "john.updated@example.com"
}
}
```
### DELETE - Delete Records
```json
{
"id": "msg-5",
"type": "request",
"operation": "delete",
"schema": "public",
"entity": "users",
"record_id": "123"
}
```
## Subscriptions
Subscriptions allow clients to receive real-time notifications when entities change.
### Subscribe to Changes
```json
{
"id": "sub-1",
"type": "subscription",
"operation": "subscribe",
"schema": "public",
"entity": "users",
"options": {
"filters": [
{"column": "status", "operator": "eq", "value": "active"}
]
}
}
```
**Response:**
```json
{
"id": "sub-1",
"type": "response",
"success": true,
"data": {
"subscription_id": "sub-abc123",
"schema": "public",
"entity": "users"
},
"timestamp": "2025-12-12T10:30:00Z"
}
```
### Receive Notifications
When a subscribed entity changes, clients automatically receive notifications:
```json
{
"type": "notification",
"operation": "create",
"subscription_id": "sub-abc123",
"schema": "public",
"entity": "users",
"data": {
"id": 124,
"name": "Jane Smith",
"email": "jane@example.com",
"status": "active"
},
"timestamp": "2025-12-12T10:35:00Z"
}
```
**Notification Operations:**
- `create` - New record created
- `update` - Record updated
- `delete` - Record deleted
### Unsubscribe
```json
{
"id": "unsub-1",
"type": "subscription",
"operation": "unsubscribe",
"subscription_id": "sub-abc123"
}
```
## Lifecycle Hooks
Hooks allow you to intercept and modify operations at various points in the lifecycle.
### Available Hook Types
- **BeforeRead** / **AfterRead**
- **BeforeCreate** / **AfterCreate**
- **BeforeUpdate** / **AfterUpdate**
- **BeforeDelete** / **AfterDelete**
- **BeforeSubscribe** / **AfterSubscribe**
- **BeforeConnect** / **AfterConnect**
### Hook Example
```go
handler := websocketspec.NewHandlerWithGORM(db)
// Authorization hook
handler.Hooks().RegisterBefore(websocketspec.OperationRead, func(ctx *websocketspec.HookContext) error {
// Check permissions
userID, _ := ctx.Connection.GetMetadata("user_id")
if userID == nil {
return fmt.Errorf("unauthorized: user not authenticated")
}
// Add filter to only show user's own records
if ctx.Entity == "posts" {
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
Column: "user_id",
Operator: "eq",
Value: userID,
})
}
return nil
})
// Logging hook
handler.Hooks().RegisterAfter(websocketspec.OperationCreate, func(ctx *websocketspec.HookContext) error {
log.Printf("Created %s in %s.%s", ctx.Result, ctx.Schema, ctx.Entity)
return nil
})
// Validation hook
handler.Hooks().RegisterBefore(websocketspec.OperationCreate, func(ctx *websocketspec.HookContext) error {
// Validate data before creation
if data, ok := ctx.Data.(map[string]interface{}); ok {
if email, exists := data["email"]; !exists || email == "" {
return fmt.Errorf("email is required")
}
}
return nil
})
```
## Client Examples
### JavaScript/TypeScript Client
```typescript
class WebSocketClient {
private ws: WebSocket;
private messageHandlers: Map<string, (data: any) => void> = new Map();
private subscriptions: Map<string, (data: any) => void> = new Map();
constructor(url: string) {
this.ws = new WebSocket(url);
this.ws.onmessage = (event) => this.handleMessage(event);
}
// Send request and wait for response
async request(operation: string, entity: string, options?: any): Promise<any> {
const id = this.generateId();
return new Promise((resolve, reject) => {
this.messageHandlers.set(id, (data) => {
if (data.success) {
resolve(data.data);
} else {
reject(data.error);
}
});
this.ws.send(JSON.stringify({
id,
type: "request",
operation,
entity,
...options
}));
});
}
// Subscribe to entity changes
async subscribe(entity: string, filters?: any[], callback?: (data: any) => void): Promise<string> {
const id = this.generateId();
return new Promise((resolve, reject) => {
this.messageHandlers.set(id, (data) => {
if (data.success) {
const subId = data.data.subscription_id;
if (callback) {
this.subscriptions.set(subId, callback);
}
resolve(subId);
} else {
reject(data.error);
}
});
this.ws.send(JSON.stringify({
id,
type: "subscription",
operation: "subscribe",
entity,
options: { filters }
}));
});
}
private handleMessage(event: MessageEvent) {
const message = JSON.parse(event.data);
if (message.type === "response") {
const handler = this.messageHandlers.get(message.id);
if (handler) {
handler(message);
this.messageHandlers.delete(message.id);
}
} else if (message.type === "notification") {
const callback = this.subscriptions.get(message.subscription_id);
if (callback) {
callback(message);
}
}
}
private generateId(): string {
return `msg-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
}
}
// Usage
const client = new WebSocketClient("ws://localhost:8080/ws");
// Read users
const users = await client.request("read", "users", {
options: {
filters: [{ column: "status", operator: "eq", value: "active" }],
limit: 10
}
});
// Subscribe to user changes
await client.subscribe("users",
[{ column: "status", operator: "eq", value: "active" }],
(notification) => {
console.log("User changed:", notification.operation, notification.data);
}
);
// Create user
const newUser = await client.request("create", "users", {
data: {
name: "Alice",
email: "alice@example.com",
status: "active"
}
});
```
### Python Client Example
```python
import asyncio
import websockets
import json
import uuid
class WebSocketClient:
def __init__(self, url):
self.url = url
self.ws = None
self.handlers = {}
self.subscriptions = {}
async def connect(self):
self.ws = await websockets.connect(self.url)
asyncio.create_task(self.listen())
async def listen(self):
async for message in self.ws:
data = json.loads(message)
if data["type"] == "response":
handler = self.handlers.get(data["id"])
if handler:
handler(data)
del self.handlers[data["id"]]
elif data["type"] == "notification":
callback = self.subscriptions.get(data["subscription_id"])
if callback:
callback(data)
async def request(self, operation, entity, **kwargs):
msg_id = str(uuid.uuid4())
future = asyncio.Future()
self.handlers[msg_id] = lambda data: future.set_result(data)
await self.ws.send(json.dumps({
"id": msg_id,
"type": "request",
"operation": operation,
"entity": entity,
**kwargs
}))
result = await future
if result["success"]:
return result["data"]
else:
raise Exception(result["error"]["message"])
async def subscribe(self, entity, callback, filters=None):
msg_id = str(uuid.uuid4())
future = asyncio.Future()
self.handlers[msg_id] = lambda data: future.set_result(data)
await self.ws.send(json.dumps({
"id": msg_id,
"type": "subscription",
"operation": "subscribe",
"entity": entity,
"options": {"filters": filters} if filters else {}
}))
result = await future
if result["success"]:
sub_id = result["data"]["subscription_id"]
self.subscriptions[sub_id] = callback
return sub_id
else:
raise Exception(result["error"]["message"])
# Usage
async def main():
client = WebSocketClient("ws://localhost:8080/ws")
await client.connect()
# Read users
users = await client.request("read", "users",
options={
"filters": [{"column": "status", "operator": "eq", "value": "active"}],
"limit": 10
}
)
print("Users:", users)
# Subscribe to changes
def on_user_change(notification):
print(f"User {notification['operation']}: {notification['data']}")
await client.subscribe("users", on_user_change,
filters=[{"column": "status", "operator": "eq", "value": "active"}]
)
asyncio.run(main())
```
## Authentication
Implement authentication using hooks:
```go
handler := websocketspec.NewHandlerWithGORM(db)
// Authentication on connection
handler.Hooks().Register(websocketspec.BeforeConnect, func(ctx *websocketspec.HookContext) error {
// Extract token from query params or headers
r := ctx.Connection.ws.UnderlyingConn().RemoteAddr()
// Validate token (implement your auth logic)
token := extractToken(r)
user, err := validateToken(token)
if err != nil {
return fmt.Errorf("authentication failed: %w", err)
}
// Store user info in connection metadata
ctx.Connection.SetMetadata("user", user)
ctx.Connection.SetMetadata("user_id", user.ID)
return nil
})
// Check permissions for each operation
handler.Hooks().RegisterBefore(websocketspec.OperationRead, func(ctx *websocketspec.HookContext) error {
userID, ok := ctx.Connection.GetMetadata("user_id")
if !ok {
return fmt.Errorf("unauthorized")
}
// Add user-specific filters
if ctx.Entity == "orders" {
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
Column: "user_id",
Operator: "eq",
Value: userID,
})
}
return nil
})
```
## Error Handling
Errors are returned in a consistent format:
```json
{
"id": "msg-1",
"type": "response",
"success": false,
"error": {
"code": "validation_error",
"message": "Email is required",
"details": {
"field": "email"
}
},
"timestamp": "2025-12-12T10:30:00Z"
}
```
**Common Error Codes:**
- `invalid_message` - Message format is invalid
- `model_not_found` - Entity not registered
- `invalid_model` - Model validation failed
- `read_error` - Read operation failed
- `create_error` - Create operation failed
- `update_error` - Update operation failed
- `delete_error` - Delete operation failed
- `hook_error` - Hook execution failed
- `unauthorized` - Authentication/authorization failed
## Best Practices
1. **Always Use Message IDs**: Correlate requests with responses using unique IDs
2. **Handle Reconnections**: Implement automatic reconnection logic on the client
3. **Validate Data**: Use before-hooks to validate data before operations
4. **Limit Subscriptions**: Implement limits on subscriptions per connection
5. **Use Filters**: Apply filters to subscriptions to reduce unnecessary notifications
6. **Implement Authentication**: Always validate users before processing operations
7. **Handle Errors Gracefully**: Display user-friendly error messages
8. **Clean Up**: Unsubscribe when components unmount or disconnect
9. **Rate Limiting**: Implement rate limiting to prevent abuse
10. **Monitor Connections**: Track active connections and subscriptions
## Filter Operators
Supported filter operators:
- `eq` - Equal (=)
- `neq` - Not Equal (!=)
- `gt` - Greater Than (>)
- `gte` - Greater Than or Equal (>=)
- `lt` - Less Than (<)
- `lte` - Less Than or Equal (<=)
- `like` - LIKE (case-sensitive)
- `ilike` - ILIKE (case-insensitive)
- `in` - IN (array of values)
## Performance Considerations
- **Connection Pooling**: WebSocket connections are reused, reducing overhead
- **Subscription Filtering**: Only matching updates are sent to clients
- **Efficient Queries**: Uses database adapters for optimized queries
- **Message Batching**: Multiple messages can be sent in one write
- **Keepalive**: Automatic ping/pong for connection health
## Comparison with Other Specs
| Feature | WebSocketSpec | RestHeadSpec | ResolveSpec |
|---------|--------------|--------------|-------------|
| Protocol | WebSocket | HTTP/REST | HTTP/REST |
| Real-time | ✅ Yes | ❌ No | ❌ No |
| Subscriptions | ✅ Yes | ❌ No | ❌ No |
| Bidirectional | ✅ Yes | ❌ No | ❌ No |
| Query Options | In Message | In Headers | In Body |
| Overhead | Low | Medium | Medium |
| Use Case | Real-time apps | Traditional APIs | Body-based APIs |
## License
MIT License - See LICENSE file for details

View File

@@ -0,0 +1,370 @@
package websocketspec
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Connection rvepresents a WebSocket connection with its state
type Connection struct {
// ID is a unique identifier for this connection
ID string
// ws is the underlying WebSocket connection
ws *websocket.Conn
// send is a channel for outbound messages
send chan []byte
// subscriptions holds active subscriptions for this connection
subscriptions map[string]*Subscription
// mu protects subscriptions map
mu sync.RWMutex
// ctx is the connection context
ctx context.Context
// cancel cancels the connection context
cancel context.CancelFunc
// handler is the WebSocket handler
handler *Handler
// metadata stores connection-specific metadata (e.g., user info, auth state)
metadata map[string]interface{}
// metaMu protects metadata map
metaMu sync.RWMutex
// closedOnce ensures cleanup happens only once
closedOnce sync.Once
}
// ConnectionManager manages all active WebSocket connections
type ConnectionManager struct {
// connections holds all active connections
connections map[string]*Connection
// mu protects the connections map
mu sync.RWMutex
// register channel for new connections
register chan *Connection
// unregister channel for closing connections
unregister chan *Connection
// broadcast channel for broadcasting messages
broadcast chan *BroadcastMessage
// ctx is the manager context
ctx context.Context
// cancel cancels the manager context
cancel context.CancelFunc
}
// BroadcastMessage represents a message to broadcast to multiple connections
type BroadcastMessage struct {
// Message is the message to broadcast
Message []byte
// Filter is an optional function to filter which connections receive the message
Filter func(*Connection) bool
}
// NewConnection creates a new WebSocket connection
func NewConnection(id string, ws *websocket.Conn, handler *Handler) *Connection {
ctx, cancel := context.WithCancel(context.Background())
return &Connection{
ID: id,
ws: ws,
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
ctx: ctx,
cancel: cancel,
handler: handler,
metadata: make(map[string]interface{}),
}
}
// NewConnectionManager creates a new connection manager
func NewConnectionManager(ctx context.Context) *ConnectionManager {
ctx, cancel := context.WithCancel(ctx)
return &ConnectionManager{
connections: make(map[string]*Connection),
register: make(chan *Connection),
unregister: make(chan *Connection),
broadcast: make(chan *BroadcastMessage),
ctx: ctx,
cancel: cancel,
}
}
// Run starts the connection manager event loop
func (cm *ConnectionManager) Run() {
for {
select {
case conn := <-cm.register:
cm.mu.Lock()
cm.connections[conn.ID] = conn
cm.mu.Unlock()
logger.Info("[WebSocketSpec] Connection registered: %s (total: %d)", conn.ID, cm.Count())
case conn := <-cm.unregister:
cm.mu.Lock()
if _, ok := cm.connections[conn.ID]; ok {
delete(cm.connections, conn.ID)
close(conn.send)
logger.Info("[WebSocketSpec] Connection unregistered: %s (total: %d)", conn.ID, cm.Count())
}
cm.mu.Unlock()
case msg := <-cm.broadcast:
cm.mu.RLock()
for _, conn := range cm.connections {
if msg.Filter == nil || msg.Filter(conn) {
select {
case conn.send <- msg.Message:
default:
// Channel full, connection is slow - close it
logger.Warn("[WebSocketSpec] Connection %s send buffer full, closing", conn.ID)
cm.mu.RUnlock()
cm.unregister <- conn
cm.mu.RLock()
}
}
}
cm.mu.RUnlock()
case <-cm.ctx.Done():
logger.Info("[WebSocketSpec] Connection manager shutting down")
return
}
}
}
// Register registers a new connection
func (cm *ConnectionManager) Register(conn *Connection) {
cm.register <- conn
}
// Unregister removes a connection
func (cm *ConnectionManager) Unregister(conn *Connection) {
cm.unregister <- conn
}
// Broadcast sends a message to all connections matching the filter
func (cm *ConnectionManager) Broadcast(message []byte, filter func(*Connection) bool) {
cm.broadcast <- &BroadcastMessage{
Message: message,
Filter: filter,
}
}
// Count returns the number of active connections
func (cm *ConnectionManager) Count() int {
cm.mu.RLock()
defer cm.mu.RUnlock()
return len(cm.connections)
}
// GetConnection retrieves a connection by ID
func (cm *ConnectionManager) GetConnection(id string) (*Connection, bool) {
cm.mu.RLock()
defer cm.mu.RUnlock()
conn, ok := cm.connections[id]
return conn, ok
}
// Shutdown gracefully shuts down the connection manager
func (cm *ConnectionManager) Shutdown() {
cm.cancel()
// Close all connections
cm.mu.Lock()
for _, conn := range cm.connections {
conn.Close()
}
cm.mu.Unlock()
}
// ReadPump reads messages from the WebSocket connection
func (c *Connection) ReadPump() {
defer func() {
c.handler.connManager.Unregister(c)
c.Close()
}()
// Configure read parameters
c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
c.ws.SetPongHandler(func(string) error {
c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for {
_, message, err := c.ws.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
logger.Error("[WebSocketSpec] Connection %s read error: %v", c.ID, err)
}
break
}
// Parse and handle the message
c.handleMessage(message)
}
}
// WritePump writes messages to the WebSocket connection
func (c *Connection) WritePump() {
ticker := time.NewTicker(54 * time.Second)
defer func() {
ticker.Stop()
c.Close()
}()
for {
select {
case message, ok := <-c.send:
c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
if !ok {
// Channel closed
c.ws.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.ws.NextWriter(websocket.TextMessage)
if err != nil {
return
}
w.Write(message)
// Write any queued messages
n := len(c.send)
for i := 0; i < n; i++ {
w.Write([]byte{'\n'})
w.Write(<-c.send)
}
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.ws.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
case <-c.ctx.Done():
return
}
}
}
// Send sends a message to this connection
func (c *Connection) Send(message []byte) error {
select {
case c.send <- message:
return nil
case <-c.ctx.Done():
return fmt.Errorf("connection closed")
default:
return fmt.Errorf("send buffer full")
}
}
// SendJSON sends a JSON-encoded message to this connection
func (c *Connection) SendJSON(v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
return c.Send(data)
}
// Close closes the connection
func (c *Connection) Close() {
c.closedOnce.Do(func() {
c.cancel()
c.ws.Close()
// Clean up subscriptions
c.mu.Lock()
for subID := range c.subscriptions {
c.handler.subscriptionManager.Unsubscribe(subID)
}
c.subscriptions = make(map[string]*Subscription)
c.mu.Unlock()
logger.Info("[WebSocketSpec] Connection %s closed", c.ID)
})
}
// AddSubscription adds a subscription to this connection
func (c *Connection) AddSubscription(sub *Subscription) {
c.mu.Lock()
defer c.mu.Unlock()
c.subscriptions[sub.ID] = sub
}
// RemoveSubscription removes a subscription from this connection
func (c *Connection) RemoveSubscription(subID string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.subscriptions, subID)
}
// GetSubscription retrieves a subscription by ID
func (c *Connection) GetSubscription(subID string) (*Subscription, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
sub, ok := c.subscriptions[subID]
return sub, ok
}
// SetMetadata sets metadata for this connection
func (c *Connection) SetMetadata(key string, value interface{}) {
c.metaMu.Lock()
defer c.metaMu.Unlock()
c.metadata[key] = value
}
// GetMetadata retrieves metadata for this connection
func (c *Connection) GetMetadata(key string) (interface{}, bool) {
c.metaMu.RLock()
defer c.metaMu.RUnlock()
val, ok := c.metadata[key]
return val, ok
}
// handleMessage processes an incoming message
func (c *Connection) handleMessage(data []byte) {
msg, err := ParseMessage(data)
if err != nil {
logger.Error("[WebSocketSpec] Failed to parse message: %v", err)
errResp := NewErrorResponse("", "invalid_message", "Failed to parse message")
c.SendJSON(errResp)
return
}
if !msg.IsValid() {
logger.Error("[WebSocketSpec] Invalid message received")
errResp := NewErrorResponse(msg.ID, "invalid_message", "Message validation failed")
c.SendJSON(errResp)
return
}
// Route message to appropriate handler
c.handler.HandleMessage(c, msg)
}

View File

@@ -0,0 +1,596 @@
package websocketspec
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper function to create a test connection with proper initialization
func createTestConnection(id string) *Connection {
ctx, cancel := context.WithCancel(context.Background())
return &Connection{
ID: id,
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: make(map[string]interface{}),
ctx: ctx,
cancel: cancel,
}
}
func TestNewConnectionManager(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
assert.NotNil(t, cm)
assert.NotNil(t, cm.connections)
assert.NotNil(t, cm.register)
assert.NotNil(t, cm.unregister)
assert.NotNil(t, cm.broadcast)
assert.Equal(t, 0, cm.Count())
}
func TestConnectionManager_Count(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer func() {
// Cancel context without calling Shutdown which tries to close connections
cm.cancel()
}()
// Initially empty
assert.Equal(t, 0, cm.Count())
// Add a connection
conn := createTestConnection("conn-1")
cm.Register(conn)
time.Sleep(10 * time.Millisecond) // Give time for registration
assert.Equal(t, 1, cm.Count())
}
func TestConnectionManager_Register(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
conn := createTestConnection("conn-1")
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
// Verify connection was registered
retrievedConn, exists := cm.GetConnection("conn-1")
assert.True(t, exists)
assert.Equal(t, "conn-1", retrievedConn.ID)
}
func TestConnectionManager_Unregister(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
}
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 1, cm.Count())
cm.Unregister(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 0, cm.Count())
// Verify connection was removed
_, exists := cm.GetConnection("conn-1")
assert.False(t, exists)
}
func TestConnectionManager_GetConnection(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Non-existent connection
_, exists := cm.GetConnection("non-existent")
assert.False(t, exists)
// Register connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
}
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
// Get existing connection
retrievedConn, exists := cm.GetConnection("conn-1")
assert.True(t, exists)
assert.Equal(t, "conn-1", retrievedConn.ID)
}
func TestConnectionManager_MultipleConnections(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Register multiple connections
conn1 := &Connection{ID: "conn-1", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
conn2 := &Connection{ID: "conn-2", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
conn3 := &Connection{ID: "conn-3", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
cm.Register(conn1)
cm.Register(conn2)
cm.Register(conn3)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 3, cm.Count())
// Verify all connections exist
_, exists := cm.GetConnection("conn-1")
assert.True(t, exists)
_, exists = cm.GetConnection("conn-2")
assert.True(t, exists)
_, exists = cm.GetConnection("conn-3")
assert.True(t, exists)
// Unregister one
cm.Unregister(conn2)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 2, cm.Count())
// Verify conn-2 is gone but others remain
_, exists = cm.GetConnection("conn-2")
assert.False(t, exists)
_, exists = cm.GetConnection("conn-1")
assert.True(t, exists)
_, exists = cm.GetConnection("conn-3")
assert.True(t, exists)
}
func TestConnectionManager_Shutdown(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
// Register connections
conn1 := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
ctx: context.Background(),
}
conn1.ctx, conn1.cancel = context.WithCancel(context.Background())
cm.Register(conn1)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 1, cm.Count())
// Shutdown
cm.Shutdown()
time.Sleep(10 * time.Millisecond)
// Verify context was cancelled
select {
case <-cm.ctx.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Fatal("Context not cancelled after shutdown")
}
}
func TestConnection_SetMetadata(t *testing.T) {
conn := &Connection{
metadata: make(map[string]interface{}),
}
conn.SetMetadata("user_id", 123)
conn.SetMetadata("username", "john")
// Verify metadata was set
userID, exists := conn.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 123, userID)
username, exists := conn.GetMetadata("username")
assert.True(t, exists)
assert.Equal(t, "john", username)
}
func TestConnection_GetMetadata(t *testing.T) {
conn := &Connection{
metadata: map[string]interface{}{
"user_id": 123,
"role": "admin",
},
}
// Get existing metadata
userID, exists := conn.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 123, userID)
// Get non-existent metadata
_, exists = conn.GetMetadata("non_existent")
assert.False(t, exists)
}
func TestConnection_AddSubscription(t *testing.T) {
conn := &Connection{
subscriptions: make(map[string]*Subscription),
}
sub := &Subscription{
ID: "sub-1",
ConnectionID: "conn-1",
Entity: "users",
Active: true,
}
conn.AddSubscription(sub)
// Verify subscription was added
retrievedSub, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
assert.Equal(t, "sub-1", retrievedSub.ID)
}
func TestConnection_RemoveSubscription(t *testing.T) {
sub := &Subscription{
ID: "sub-1",
ConnectionID: "conn-1",
Entity: "users",
Active: true,
}
conn := &Connection{
subscriptions: map[string]*Subscription{
"sub-1": sub,
},
}
// Verify subscription exists
_, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
// Remove subscription
conn.RemoveSubscription("sub-1")
// Verify subscription was removed
_, exists = conn.GetSubscription("sub-1")
assert.False(t, exists)
}
func TestConnection_GetSubscription(t *testing.T) {
sub1 := &Subscription{ID: "sub-1", Entity: "users"}
sub2 := &Subscription{ID: "sub-2", Entity: "posts"}
conn := &Connection{
subscriptions: map[string]*Subscription{
"sub-1": sub1,
"sub-2": sub2,
},
}
// Get existing subscription
retrievedSub, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
assert.Equal(t, "sub-1", retrievedSub.ID)
// Get non-existent subscription
_, exists = conn.GetSubscription("non-existent")
assert.False(t, exists)
}
func TestConnection_MultipleSubscriptions(t *testing.T) {
conn := &Connection{
subscriptions: make(map[string]*Subscription),
}
sub1 := &Subscription{ID: "sub-1", Entity: "users"}
sub2 := &Subscription{ID: "sub-2", Entity: "posts"}
sub3 := &Subscription{ID: "sub-3", Entity: "comments"}
conn.AddSubscription(sub1)
conn.AddSubscription(sub2)
conn.AddSubscription(sub3)
// Verify all subscriptions exist
_, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
_, exists = conn.GetSubscription("sub-2")
assert.True(t, exists)
_, exists = conn.GetSubscription("sub-3")
assert.True(t, exists)
// Remove one subscription
conn.RemoveSubscription("sub-2")
// Verify sub-2 is gone but others remain
_, exists = conn.GetSubscription("sub-2")
assert.False(t, exists)
_, exists = conn.GetSubscription("sub-1")
assert.True(t, exists)
_, exists = conn.GetSubscription("sub-3")
assert.True(t, exists)
}
func TestBroadcastMessage_Structure(t *testing.T) {
msg := &BroadcastMessage{
Message: []byte("test message"),
Filter: func(conn *Connection) bool {
return true
},
}
assert.NotNil(t, msg.Message)
assert.NotNil(t, msg.Filter)
assert.Equal(t, "test message", string(msg.Message))
}
func TestBroadcastMessage_Filter(t *testing.T) {
// Filter that only allows specific connection
filter := func(conn *Connection) bool {
return conn.ID == "conn-1"
}
msg := &BroadcastMessage{
Message: []byte("test"),
Filter: filter,
}
conn1 := &Connection{ID: "conn-1"}
conn2 := &Connection{ID: "conn-2"}
assert.True(t, msg.Filter(conn1))
assert.False(t, msg.Filter(conn2))
}
func TestConnectionManager_Broadcast(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Register connections
conn1 := &Connection{ID: "conn-1", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
conn2 := &Connection{ID: "conn-2", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
cm.Register(conn1)
cm.Register(conn2)
time.Sleep(10 * time.Millisecond)
// Broadcast message
message := []byte("test broadcast")
cm.Broadcast(message, nil)
time.Sleep(10 * time.Millisecond)
// Verify both connections received the message
select {
case msg := <-conn1.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn1 did not receive message")
}
select {
case msg := <-conn2.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn2 did not receive message")
}
}
func TestConnectionManager_BroadcastWithFilter(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Register connections with metadata
conn1 := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: map[string]interface{}{"role": "admin"},
}
conn2 := &Connection{
ID: "conn-2",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: map[string]interface{}{"role": "user"},
}
cm.Register(conn1)
cm.Register(conn2)
time.Sleep(10 * time.Millisecond)
// Broadcast only to admins
filter := func(conn *Connection) bool {
role, _ := conn.GetMetadata("role")
return role == "admin"
}
message := []byte("admin message")
cm.Broadcast(message, filter)
time.Sleep(10 * time.Millisecond)
// Verify only conn1 received the message
select {
case msg := <-conn1.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn1 (admin) did not receive message")
}
// Verify conn2 did not receive the message
select {
case <-conn2.send:
t.Fatal("conn2 (user) should not have received admin message")
case <-time.After(50 * time.Millisecond):
// Expected - no message
}
}
func TestConnection_ConcurrentMetadataAccess(t *testing.T) {
// This test verifies that concurrent metadata access doesn't cause race conditions
// Run with: go test -race
conn := &Connection{
metadata: make(map[string]interface{}),
}
done := make(chan bool)
// Goroutine 1: Write metadata
go func() {
for i := 0; i < 100; i++ {
conn.SetMetadata("key", i)
}
done <- true
}()
// Goroutine 2: Read metadata
go func() {
for i := 0; i < 100; i++ {
conn.GetMetadata("key")
}
done <- true
}()
// Wait for completion
<-done
<-done
}
func TestConnection_ConcurrentSubscriptionAccess(t *testing.T) {
// This test verifies that concurrent subscription access doesn't cause race conditions
// Run with: go test -race
conn := &Connection{
subscriptions: make(map[string]*Subscription),
}
done := make(chan bool)
// Goroutine 1: Add subscriptions
go func() {
for i := 0; i < 100; i++ {
sub := &Subscription{ID: "sub-" + string(rune(i)), Entity: "users"}
conn.AddSubscription(sub)
}
done <- true
}()
// Goroutine 2: Get subscriptions
go func() {
for i := 0; i < 100; i++ {
conn.GetSubscription("sub-" + string(rune(i)))
}
done <- true
}()
// Wait for completion
<-done
<-done
}
func TestConnectionManager_CompleteLifecycle(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Create and register connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: make(map[string]interface{}),
}
// Set metadata
conn.SetMetadata("user_id", 123)
// Add subscriptions
sub1 := &Subscription{ID: "sub-1", Entity: "users"}
sub2 := &Subscription{ID: "sub-2", Entity: "posts"}
conn.AddSubscription(sub1)
conn.AddSubscription(sub2)
// Register connection
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 1, cm.Count())
// Verify connection exists
retrievedConn, exists := cm.GetConnection("conn-1")
require.True(t, exists)
assert.Equal(t, "conn-1", retrievedConn.ID)
// Verify metadata
userID, exists := retrievedConn.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 123, userID)
// Verify subscriptions
_, exists = retrievedConn.GetSubscription("sub-1")
assert.True(t, exists)
_, exists = retrievedConn.GetSubscription("sub-2")
assert.True(t, exists)
// Broadcast message
message := []byte("test message")
cm.Broadcast(message, nil)
time.Sleep(10 * time.Millisecond)
select {
case msg := <-retrievedConn.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("Connection did not receive broadcast")
}
// Unregister connection
cm.Unregister(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 0, cm.Count())
// Verify connection is gone
_, exists = cm.GetConnection("conn-1")
assert.False(t, exists)
}

View File

@@ -0,0 +1,237 @@
package websocketspec_test
import (
"fmt"
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/websocketspec"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// User model example
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
}
// Post model example
type Post struct {
ID uint `json:"id" gorm:"primaryKey"`
Title string `json:"title"`
Content string `json:"content"`
UserID uint `json:"user_id"`
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
}
// Example_basicSetup demonstrates basic WebSocketSpec setup
func Example_basicSetup() {
// Connect to database
db, err := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
if err != nil {
log.Fatal(err)
}
// Create WebSocket handler
handler := websocketspec.NewHandlerWithGORM(db)
// Register models
handler.Registry().RegisterModel("public.users", &User{})
handler.Registry().RegisterModel("public.posts", &Post{})
// Setup WebSocket endpoint
http.HandleFunc("/ws", handler.HandleWebSocket)
// Start server
log.Println("WebSocket server starting on :8080")
if err := http.ListenAndServe(":8080", nil); err != nil {
log.Fatal(err)
}
}
// Example_withHooks demonstrates using lifecycle hooks
func Example_withHooks() {
db, _ := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
handler := websocketspec.NewHandlerWithGORM(db)
// Register models
handler.Registry().RegisterModel("public.users", &User{})
// Add authentication hook
handler.Hooks().Register(websocketspec.BeforeConnect, func(ctx *websocketspec.HookContext) error {
// Validate authentication token
// (In real implementation, extract from query params or headers)
userID := uint(123) // From token
// Store in connection metadata
ctx.Connection.SetMetadata("user_id", userID)
log.Printf("User %d connected", userID)
return nil
})
// Add authorization hook for read operations
handler.Hooks().RegisterBefore(websocketspec.OperationRead, func(ctx *websocketspec.HookContext) error {
userID, ok := ctx.Connection.GetMetadata("user_id")
if !ok {
return fmt.Errorf("unauthorized: not authenticated")
}
log.Printf("User %v reading %s.%s", userID, ctx.Schema, ctx.Entity)
// Add filter to only show user's own records
if ctx.Entity == "posts" {
// ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
// Column: "user_id",
// Operator: "eq",
// Value: userID,
// })
}
return nil
})
// Add logging hook after create
handler.Hooks().RegisterAfter(websocketspec.OperationCreate, func(ctx *websocketspec.HookContext) error {
userID, _ := ctx.Connection.GetMetadata("user_id")
log.Printf("User %v created record in %s.%s", userID, ctx.Schema, ctx.Entity)
return nil
})
// Add validation hook before create
handler.Hooks().RegisterBefore(websocketspec.OperationCreate, func(ctx *websocketspec.HookContext) error {
// Validate required fields
if data, ok := ctx.Data.(map[string]interface{}); ok {
if ctx.Entity == "users" {
if email, exists := data["email"]; !exists || email == "" {
return fmt.Errorf("validation error: email is required")
}
if name, exists := data["name"]; !exists || name == "" {
return fmt.Errorf("validation error: name is required")
}
}
}
return nil
})
// Add limit hook for subscriptions
handler.Hooks().Register(websocketspec.BeforeSubscribe, func(ctx *websocketspec.HookContext) error {
// Limit subscriptions per connection
maxSubscriptions := 10
// Note: In a real implementation, you would count subscriptions using the connection's methods
// currentCount := len(ctx.Connection.subscriptions) // subscriptions is private
// For demonstration purposes, we'll just log
log.Printf("Creating subscription (max: %d)", maxSubscriptions)
return nil
})
http.HandleFunc("/ws", handler.HandleWebSocket)
log.Println("Server with hooks starting on :8080")
http.ListenAndServe(":8080", nil)
}
// Example_monitoring demonstrates monitoring connections and subscriptions
func Example_monitoring() {
db, _ := gorm.Open(postgres.Open("your-connection-string"), &gorm.Config{})
handler := websocketspec.NewHandlerWithGORM(db)
handler.Registry().RegisterModel("public.users", &User{})
// Add connection tracking
handler.Hooks().Register(websocketspec.AfterConnect, func(ctx *websocketspec.HookContext) error {
count := handler.GetConnectionCount()
log.Printf("Client connected. Total connections: %d", count)
return nil
})
handler.Hooks().Register(websocketspec.AfterDisconnect, func(ctx *websocketspec.HookContext) error {
count := handler.GetConnectionCount()
log.Printf("Client disconnected. Total connections: %d", count)
return nil
})
// Add subscription tracking
handler.Hooks().Register(websocketspec.AfterSubscribe, func(ctx *websocketspec.HookContext) error {
count := handler.GetSubscriptionCount()
log.Printf("New subscription. Total subscriptions: %d", count)
return nil
})
// Monitoring endpoint
http.HandleFunc("/stats", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Active Connections: %d\n", handler.GetConnectionCount())
fmt.Fprintf(w, "Active Subscriptions: %d\n", handler.GetSubscriptionCount())
})
http.HandleFunc("/ws", handler.HandleWebSocket)
log.Println("Server with monitoring starting on :8080")
http.ListenAndServe(":8080", nil)
}
// Example_clientSide shows client-side usage example
func Example_clientSide() {
// This is JavaScript code for documentation purposes
jsCode := `
// JavaScript WebSocket Client Example
const ws = new WebSocket("ws://localhost:8080/ws");
ws.onopen = () => {
console.log("Connected to WebSocket");
// Read users
ws.send(JSON.stringify({
id: "msg-1",
type: "request",
operation: "read",
schema: "public",
entity: "users",
options: {
filters: [{column: "status", operator: "eq", value: "active"}],
limit: 10
}
}));
// Subscribe to user changes
ws.send(JSON.stringify({
id: "sub-1",
type: "subscription",
operation: "subscribe",
schema: "public",
entity: "users",
options: {
filters: [{column: "status", operator: "eq", value: "active"}]
}
}));
};
ws.onmessage = (event) => {
const message = JSON.parse(event.data);
if (message.type === "response") {
if (message.success) {
console.log("Response:", message.data);
} else {
console.error("Error:", message.error);
}
} else if (message.type === "notification") {
console.log("Notification:", message.operation, message.data);
}
};
ws.onerror = (error) => {
console.error("WebSocket error:", error);
};
ws.onclose = () => {
console.log("WebSocket connection closed");
// Implement reconnection logic here
};
`
fmt.Println(jsCode)
}

View File

@@ -0,0 +1,747 @@
package websocketspec
import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strconv"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Handler handles WebSocket connections and messages
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
nestedProcessor *common.NestedCUDProcessor
connManager *ConnectionManager
subscriptionManager *SubscriptionManager
upgrader websocket.Upgrader
ctx context.Context
}
// NewHandler creates a new WebSocket handler
func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
ctx := context.Background()
handler := &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
connManager: NewConnectionManager(ctx),
subscriptionManager: NewSubscriptionManager(),
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// TODO: Implement proper origin checking
return true
},
},
ctx: ctx,
}
// Initialize nested processor (nil for now, can be added later if needed)
// handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler)
// Start connection manager
go handler.connManager.Run()
return handler
}
// GetRelationshipInfo implements the RelationshipInfoProvider interface
// This is a placeholder implementation - full relationship support can be added later
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
// TODO: Implement full relationship detection similar to restheadspec
return nil
}
// GetDatabase returns the underlying database connection
// Implements common.SpecHandler interface
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// Hooks returns the hook registry for this handler
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// Registry returns the model registry for this handler
func (h *Handler) Registry() common.ModelRegistry {
return h.registry
}
// HandleWebSocket upgrades HTTP connection to WebSocket
func (h *Handler) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
// Upgrade connection
ws, err := h.upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Error("[WebSocketSpec] Failed to upgrade connection: %v", err)
return
}
// Create connection
connID := uuid.New().String()
conn := NewConnection(connID, ws, h)
// Execute before connect hook
hookCtx := &HookContext{
Context: r.Context(),
Handler: h,
Connection: conn,
}
if err := h.hooks.Execute(BeforeConnect, hookCtx); err != nil {
logger.Error("[WebSocketSpec] BeforeConnect hook failed: %v", err)
ws.Close()
return
}
// Register connection
h.connManager.Register(conn)
// Execute after connect hook
h.hooks.Execute(AfterConnect, hookCtx)
// Start read/write pumps
go conn.WritePump()
go conn.ReadPump()
logger.Info("[WebSocketSpec] WebSocket connection established: %s", connID)
}
// HandleMessage routes incoming messages to appropriate handlers
func (h *Handler) HandleMessage(conn *Connection, msg *Message) {
switch msg.Type {
case MessageTypeRequest:
h.handleRequest(conn, msg)
case MessageTypeSubscription:
h.handleSubscription(conn, msg)
case MessageTypePing:
h.handlePing(conn, msg)
default:
errResp := NewErrorResponse(msg.ID, "invalid_message_type", fmt.Sprintf("Unknown message type: %s", msg.Type))
conn.SendJSON(errResp)
}
}
// handleRequest processes a request message
func (h *Handler) handleRequest(conn *Connection, msg *Message) {
ctx := conn.ctx
schema := msg.Schema
entity := msg.Entity
recordID := msg.RecordID
// Get model from registry
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("[WebSocketSpec] Model not found for %s.%s: %v", schema, entity, err)
errResp := NewErrorResponse(msg.ID, "model_not_found", fmt.Sprintf("Model not found: %s.%s", schema, entity))
conn.SendJSON(errResp)
return
}
// Validate and unwrap model
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
logger.Error("[WebSocketSpec] Model validation failed for %s.%s: %v", schema, entity, err)
errResp := NewErrorResponse(msg.ID, "invalid_model", err.Error())
conn.SendJSON(errResp)
return
}
model = result.Model
modelPtr := result.ModelPtr
tableName := h.getTableName(schema, entity, model)
// Create hook context
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Connection: conn,
Message: msg,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
ModelPtr: modelPtr,
Options: msg.Options,
ID: recordID,
Data: msg.Data,
Metadata: make(map[string]interface{}),
}
// Route to operation handler
switch msg.Operation {
case OperationRead:
h.handleRead(conn, msg, hookCtx)
case OperationCreate:
h.handleCreate(conn, msg, hookCtx)
case OperationUpdate:
h.handleUpdate(conn, msg, hookCtx)
case OperationDelete:
h.handleDelete(conn, msg, hookCtx)
case OperationMeta:
h.handleMeta(conn, msg, hookCtx)
default:
errResp := NewErrorResponse(msg.ID, "invalid_operation", fmt.Sprintf("Unknown operation: %s", msg.Operation))
conn.SendJSON(errResp)
}
}
// handleRead processes a read operation
func (h *Handler) handleRead(conn *Connection, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
logger.Error("[WebSocketSpec] BeforeRead hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Perform read operation
var data interface{}
var metadata map[string]interface{}
var err error
if hookCtx.ID != "" {
// Read single record by ID
data, err = h.readByID(hookCtx)
metadata = map[string]interface{}{"total": 1}
} else {
// Read multiple records
data, metadata, err = h.readMultiple(hookCtx)
}
if err != nil {
logger.Error("[WebSocketSpec] Read operation failed: %v", err)
errResp := NewErrorResponse(msg.ID, "read_error", err.Error())
conn.SendJSON(errResp)
return
}
// Update hook context with result
hookCtx.Result = data
// Execute after hook
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
logger.Error("[WebSocketSpec] AfterRead hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Send response
resp := NewResponseMessage(msg.ID, true, hookCtx.Result)
resp.Metadata = metadata
conn.SendJSON(resp)
}
// handleCreate processes a create operation
func (h *Handler) handleCreate(conn *Connection, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
logger.Error("[WebSocketSpec] BeforeCreate hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Perform create operation
data, err := h.create(hookCtx)
if err != nil {
logger.Error("[WebSocketSpec] Create operation failed: %v", err)
errResp := NewErrorResponse(msg.ID, "create_error", err.Error())
conn.SendJSON(errResp)
return
}
// Update hook context
hookCtx.Result = data
// Execute after hook
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("[WebSocketSpec] AfterCreate hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Send response
resp := NewResponseMessage(msg.ID, true, hookCtx.Result)
conn.SendJSON(resp)
// Notify subscribers
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationCreate, data)
}
// handleUpdate processes an update operation
func (h *Handler) handleUpdate(conn *Connection, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
logger.Error("[WebSocketSpec] BeforeUpdate hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Perform update operation
data, err := h.update(hookCtx)
if err != nil {
logger.Error("[WebSocketSpec] Update operation failed: %v", err)
errResp := NewErrorResponse(msg.ID, "update_error", err.Error())
conn.SendJSON(errResp)
return
}
// Update hook context
hookCtx.Result = data
// Execute after hook
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
logger.Error("[WebSocketSpec] AfterUpdate hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Send response
resp := NewResponseMessage(msg.ID, true, hookCtx.Result)
conn.SendJSON(resp)
// Notify subscribers
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationUpdate, data)
}
// handleDelete processes a delete operation
func (h *Handler) handleDelete(conn *Connection, msg *Message, hookCtx *HookContext) {
// Execute before hook
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Error("[WebSocketSpec] BeforeDelete hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Perform delete operation
err := h.delete(hookCtx)
if err != nil {
logger.Error("[WebSocketSpec] Delete operation failed: %v", err)
errResp := NewErrorResponse(msg.ID, "delete_error", err.Error())
conn.SendJSON(errResp)
return
}
// Execute after hook
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
logger.Error("[WebSocketSpec] AfterDelete hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Send response
resp := NewResponseMessage(msg.ID, true, map[string]interface{}{"deleted": true})
conn.SendJSON(resp)
// Notify subscribers
h.notifySubscribers(hookCtx.Schema, hookCtx.Entity, OperationDelete, map[string]interface{}{"id": hookCtx.ID})
}
// handleMeta processes a metadata request
func (h *Handler) handleMeta(conn *Connection, msg *Message, hookCtx *HookContext) {
metadata := h.getMetadata(hookCtx.Schema, hookCtx.Entity, hookCtx.Model)
resp := NewResponseMessage(msg.ID, true, metadata)
conn.SendJSON(resp)
}
// handleSubscription processes subscription messages
func (h *Handler) handleSubscription(conn *Connection, msg *Message) {
switch msg.Operation {
case OperationSubscribe:
h.handleSubscribe(conn, msg)
case OperationUnsubscribe:
h.handleUnsubscribe(conn, msg)
default:
errResp := NewErrorResponse(msg.ID, "invalid_subscription_operation", fmt.Sprintf("Unknown subscription operation: %s", msg.Operation))
conn.SendJSON(errResp)
}
}
// handleSubscribe creates a new subscription
func (h *Handler) handleSubscribe(conn *Connection, msg *Message) {
// Generate subscription ID
subID := uuid.New().String()
// Create hook context
hookCtx := &HookContext{
Context: conn.ctx,
Handler: h,
Connection: conn,
Message: msg,
Schema: msg.Schema,
Entity: msg.Entity,
Options: msg.Options,
Metadata: make(map[string]interface{}),
}
// Execute before hook
if err := h.hooks.Execute(BeforeSubscribe, hookCtx); err != nil {
logger.Error("[WebSocketSpec] BeforeSubscribe hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Create subscription
sub := h.subscriptionManager.Subscribe(subID, conn.ID, msg.Schema, msg.Entity, msg.Options)
conn.AddSubscription(sub)
// Update hook context
hookCtx.Subscription = sub
// Execute after hook
h.hooks.Execute(AfterSubscribe, hookCtx)
// Send response
resp := NewResponseMessage(msg.ID, true, map[string]interface{}{
"subscription_id": subID,
"schema": msg.Schema,
"entity": msg.Entity,
})
conn.SendJSON(resp)
logger.Info("[WebSocketSpec] Subscription created: %s for %s.%s (conn: %s)", subID, msg.Schema, msg.Entity, conn.ID)
}
// handleUnsubscribe removes a subscription
func (h *Handler) handleUnsubscribe(conn *Connection, msg *Message) {
subID := msg.SubscriptionID
if subID == "" {
errResp := NewErrorResponse(msg.ID, "missing_subscription_id", "Subscription ID is required for unsubscribe")
conn.SendJSON(errResp)
return
}
// Get subscription
sub, exists := conn.GetSubscription(subID)
if !exists {
errResp := NewErrorResponse(msg.ID, "subscription_not_found", fmt.Sprintf("Subscription not found: %s", subID))
conn.SendJSON(errResp)
return
}
// Create hook context
hookCtx := &HookContext{
Context: conn.ctx,
Handler: h,
Connection: conn,
Message: msg,
Subscription: sub,
Metadata: make(map[string]interface{}),
}
// Execute before hook
if err := h.hooks.Execute(BeforeUnsubscribe, hookCtx); err != nil {
logger.Error("[WebSocketSpec] BeforeUnsubscribe hook failed: %v", err)
errResp := NewErrorResponse(msg.ID, "hook_error", err.Error())
conn.SendJSON(errResp)
return
}
// Remove subscription
h.subscriptionManager.Unsubscribe(subID)
conn.RemoveSubscription(subID)
// Execute after hook
h.hooks.Execute(AfterUnsubscribe, hookCtx)
// Send response
resp := NewResponseMessage(msg.ID, true, map[string]interface{}{
"unsubscribed": true,
"subscription_id": subID,
})
conn.SendJSON(resp)
}
// handlePing responds to ping messages
func (h *Handler) handlePing(conn *Connection, msg *Message) {
pong := &Message{
ID: msg.ID,
Type: MessageTypePong,
Timestamp: time.Now(),
}
conn.SendJSON(pong)
}
// notifySubscribers sends notifications to all subscribers of an entity
func (h *Handler) notifySubscribers(schema, entity string, operation OperationType, data interface{}) {
subscriptions := h.subscriptionManager.GetSubscriptionsByEntity(schema, entity)
if len(subscriptions) == 0 {
return
}
for _, sub := range subscriptions {
// Check if data matches subscription filters
if !sub.MatchesFilters(data) {
continue
}
// Get connection
conn, exists := h.connManager.GetConnection(sub.ConnectionID)
if !exists {
continue
}
// Send notification
notification := NewNotificationMessage(sub.ID, operation, schema, entity, data)
if err := conn.SendJSON(notification); err != nil {
logger.Error("[WebSocketSpec] Failed to send notification to connection %s: %v", conn.ID, err)
}
}
}
// CRUD operation implementations
func (h *Handler) readByID(hookCtx *HookContext) (interface{}, error) {
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Add ID filter
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
// Apply columns
if hookCtx.Options != nil && len(hookCtx.Options.Columns) > 0 {
query = query.Column(hookCtx.Options.Columns...)
}
// Apply preloads (simplified for now)
if hookCtx.Options != nil {
for _, preload := range hookCtx.Options.Preload {
query = query.PreloadRelation(preload.Relation)
}
}
// Execute query
if err := query.ScanModel(hookCtx.Context); err != nil {
return nil, fmt.Errorf("failed to read record: %w", err)
}
return hookCtx.ModelPtr, nil
}
func (h *Handler) readMultiple(hookCtx *HookContext) (interface{}, map[string]interface{}, error) {
query := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Apply options (simplified implementation)
if hookCtx.Options != nil {
// Apply filters
for _, filter := range hookCtx.Options.Filters {
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
}
// Apply sorting
for _, sort := range hookCtx.Options.Sort {
direction := "ASC"
if sort.Direction == "desc" {
direction = "DESC"
}
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
// Apply limit and offset
if hookCtx.Options.Limit != nil {
query = query.Limit(*hookCtx.Options.Limit)
}
if hookCtx.Options.Offset != nil {
query = query.Offset(*hookCtx.Options.Offset)
}
// Apply preloads
for _, preload := range hookCtx.Options.Preload {
query = query.PreloadRelation(preload.Relation)
}
// Apply columns
if len(hookCtx.Options.Columns) > 0 {
query = query.Column(hookCtx.Options.Columns...)
}
}
// Execute query
if err := query.ScanModel(hookCtx.Context); err != nil {
return nil, nil, fmt.Errorf("failed to read records: %w", err)
}
// Get count
metadata := make(map[string]interface{})
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
if hookCtx.Options != nil {
for _, filter := range hookCtx.Options.Filters {
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
}
}
count, _ := countQuery.Count(hookCtx.Context)
metadata["total"] = count
metadata["count"] = reflection.Len(hookCtx.ModelPtr)
return hookCtx.ModelPtr, metadata, nil
}
func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
// Marshal and unmarshal data into model
dataBytes, err := json.Marshal(hookCtx.Data)
if err != nil {
return nil, fmt.Errorf("failed to marshal data: %w", err)
}
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
}
// Insert record
query := h.db.NewInsert().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
if _, err := query.Exec(hookCtx.Context); err != nil {
return nil, fmt.Errorf("failed to create record: %w", err)
}
return hookCtx.ModelPtr, nil
}
func (h *Handler) update(hookCtx *HookContext) (interface{}, error) {
// Marshal and unmarshal data into model
dataBytes, err := json.Marshal(hookCtx.Data)
if err != nil {
return nil, fmt.Errorf("failed to marshal data: %w", err)
}
if err := json.Unmarshal(dataBytes, hookCtx.ModelPtr); err != nil {
return nil, fmt.Errorf("failed to unmarshal data into model: %w", err)
}
// Update record
query := h.db.NewUpdate().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Add ID filter
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
if _, err := query.Exec(hookCtx.Context); err != nil {
return nil, fmt.Errorf("failed to update record: %w", err)
}
// Fetch updated record
return h.readByID(hookCtx)
}
func (h *Handler) delete(hookCtx *HookContext) error {
query := h.db.NewDelete().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
// Add ID filter
pkName := reflection.GetPrimaryKeyName(hookCtx.Model)
query = query.Where(fmt.Sprintf("%s = ?", pkName), hookCtx.ID)
if _, err := query.Exec(hookCtx.Context); err != nil {
return fmt.Errorf("failed to delete record: %w", err)
}
return nil
}
// Helper methods
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
// Use entity as table name
tableName := entity
if schema != "" {
tableName = schema + "." + tableName
}
return tableName
}
func (h *Handler) getMetadata(schema, entity string, model interface{}) map[string]interface{} {
metadata := make(map[string]interface{})
metadata["schema"] = schema
metadata["entity"] = entity
metadata["table_name"] = h.getTableName(schema, entity, model)
// Get fields from model using reflection
columns := reflection.GetModelColumns(model)
metadata["columns"] = columns
metadata["primary_key"] = reflection.GetPrimaryKeyName(model)
return metadata
}
// getOperatorSQL converts filter operator to SQL operator
func (h *Handler) getOperatorSQL(operator string) string {
switch operator {
case "eq":
return "="
case "neq":
return "!="
case "gt":
return ">"
case "gte":
return ">="
case "lt":
return "<"
case "lte":
return "<="
case "like":
return "LIKE"
case "ilike":
return "ILIKE"
case "in":
return "IN"
default:
return "="
}
}
// Shutdown gracefully shuts down the handler
func (h *Handler) Shutdown() {
h.connManager.Shutdown()
}
// GetConnectionCount returns the number of active connections
func (h *Handler) GetConnectionCount() int {
return h.connManager.Count()
}
// GetSubscriptionCount returns the number of active subscriptions
func (h *Handler) GetSubscriptionCount() int {
return h.subscriptionManager.Count()
}
// BroadcastMessage sends a message to all connections matching the filter
func (h *Handler) BroadcastMessage(message interface{}, filter func(*Connection) bool) error {
data, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
h.connManager.Broadcast(data, filter)
return nil
}
// GetConnection retrieves a connection by ID
func (h *Handler) GetConnection(id string) (*Connection, bool) {
return h.connManager.GetConnection(id)
}
// Helper to convert string ID to int64
func parseID(id string) (int64, error) {
return strconv.ParseInt(id, 10, 64)
}

View File

@@ -0,0 +1,823 @@
package websocketspec
import (
"context"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// MockDatabase is a mock implementation of common.Database for testing
type MockDatabase struct {
mock.Mock
}
func (m *MockDatabase) NewSelect() common.SelectQuery {
args := m.Called()
return args.Get(0).(common.SelectQuery)
}
func (m *MockDatabase) NewInsert() common.InsertQuery {
args := m.Called()
return args.Get(0).(common.InsertQuery)
}
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
args := m.Called()
return args.Get(0).(common.UpdateQuery)
}
func (m *MockDatabase) NewDelete() common.DeleteQuery {
args := m.Called()
return args.Get(0).(common.DeleteQuery)
}
func (m *MockDatabase) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
callArgs := m.Called(ctx, query, args)
if callArgs.Get(0) == nil {
return nil, callArgs.Error(1)
}
return callArgs.Get(0).(common.Result), callArgs.Error(1)
}
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
callArgs := m.Called(ctx, dest, query, args)
return callArgs.Error(0)
}
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
args := m.Called(ctx)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(common.Database), args.Error(1)
}
func (m *MockDatabase) CommitTx(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
args := m.Called(ctx, fn)
return args.Error(0)
}
func (m *MockDatabase) GetUnderlyingDB() interface{} {
args := m.Called()
return args.Get(0)
}
// MockSelectQuery is a mock implementation of common.SelectQuery
type MockSelectQuery struct {
mock.Mock
}
func (m *MockSelectQuery) Model(model interface{}) common.SelectQuery {
args := m.Called(model)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Table(table string) common.SelectQuery {
args := m.Called(table)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Column(columns ...string) common.SelectQuery {
args := m.Called(columns)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
callArgs := m.Called(query, args)
return callArgs.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
args := m.Called(column, values)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Order(order string) common.SelectQuery {
args := m.Called(order)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Limit(limit int) common.SelectQuery {
args := m.Called(limit)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Offset(offset int) common.SelectQuery {
args := m.Called(offset)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
args := m.Called(relation, apply)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
args := m.Called(relation, conditions)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
callArgs := m.Called(query, args)
return callArgs.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
callArgs := m.Called(query, args)
return callArgs.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
callArgs := m.Called(query, args)
return callArgs.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
callArgs := m.Called(query, args)
return callArgs.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
args := m.Called(relation, apply)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
callArgs := m.Called(order, args)
return callArgs.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Group(group string) common.SelectQuery {
args := m.Called(group)
return args.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Having(having string, args ...interface{}) common.SelectQuery {
callArgs := m.Called(having, args)
return callArgs.Get(0).(common.SelectQuery)
}
func (m *MockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
args := m.Called(ctx, dest)
return args.Error(0)
}
func (m *MockSelectQuery) ScanModel(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockSelectQuery) Count(ctx context.Context) (int, error) {
args := m.Called(ctx)
return args.Int(0), args.Error(1)
}
func (m *MockSelectQuery) Exists(ctx context.Context) (bool, error) {
args := m.Called(ctx)
return args.Bool(0), args.Error(1)
}
// MockInsertQuery is a mock implementation of common.InsertQuery
type MockInsertQuery struct {
mock.Mock
}
func (m *MockInsertQuery) Model(model interface{}) common.InsertQuery {
args := m.Called(model)
return args.Get(0).(common.InsertQuery)
}
func (m *MockInsertQuery) Table(table string) common.InsertQuery {
args := m.Called(table)
return args.Get(0).(common.InsertQuery)
}
func (m *MockInsertQuery) Value(column string, value interface{}) common.InsertQuery {
args := m.Called(column, value)
return args.Get(0).(common.InsertQuery)
}
func (m *MockInsertQuery) OnConflict(action string) common.InsertQuery {
args := m.Called(action)
return args.Get(0).(common.InsertQuery)
}
func (m *MockInsertQuery) Returning(columns ...string) common.InsertQuery {
args := m.Called(columns)
return args.Get(0).(common.InsertQuery)
}
func (m *MockInsertQuery) Exec(ctx context.Context) (common.Result, error) {
args := m.Called(ctx)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(common.Result), args.Error(1)
}
// MockUpdateQuery is a mock implementation of common.UpdateQuery
type MockUpdateQuery struct {
mock.Mock
}
func (m *MockUpdateQuery) Model(model interface{}) common.UpdateQuery {
args := m.Called(model)
return args.Get(0).(common.UpdateQuery)
}
func (m *MockUpdateQuery) Table(table string) common.UpdateQuery {
args := m.Called(table)
return args.Get(0).(common.UpdateQuery)
}
func (m *MockUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
args := m.Called(column, value)
return args.Get(0).(common.UpdateQuery)
}
func (m *MockUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
args := m.Called(values)
return args.Get(0).(common.UpdateQuery)
}
func (m *MockUpdateQuery) Where(query string, args ...interface{}) common.UpdateQuery {
callArgs := m.Called(query, args)
return callArgs.Get(0).(common.UpdateQuery)
}
func (m *MockUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
args := m.Called(ctx)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(common.Result), args.Error(1)
}
// MockDeleteQuery is a mock implementation of common.DeleteQuery
type MockDeleteQuery struct {
mock.Mock
}
func (m *MockDeleteQuery) Model(model interface{}) common.DeleteQuery {
args := m.Called(model)
return args.Get(0).(common.DeleteQuery)
}
func (m *MockDeleteQuery) Table(table string) common.DeleteQuery {
args := m.Called(table)
return args.Get(0).(common.DeleteQuery)
}
func (m *MockDeleteQuery) Where(query string, args ...interface{}) common.DeleteQuery {
callArgs := m.Called(query, args)
return callArgs.Get(0).(common.DeleteQuery)
}
func (m *MockDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
args := m.Called(ctx)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(common.Result), args.Error(1)
}
// MockModelRegistry is a mock implementation of common.ModelRegistry
type MockModelRegistry struct {
mock.Mock
}
func (m *MockModelRegistry) RegisterModel(key string, model interface{}) error {
args := m.Called(key, model)
return args.Error(0)
}
func (m *MockModelRegistry) GetModel(key string) (interface{}, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0), args.Error(1)
}
func (m *MockModelRegistry) GetAllModels() map[string]interface{} {
args := m.Called()
return args.Get(0).(map[string]interface{})
}
func (m *MockModelRegistry) GetModelByEntity(schema, entity string) (interface{}, error) {
args := m.Called(schema, entity)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0), args.Error(1)
}
// Test model
type TestUser struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
}
func TestNewHandler(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
assert.NotNil(t, handler)
assert.NotNil(t, handler.db)
assert.NotNil(t, handler.registry)
assert.NotNil(t, handler.hooks)
assert.NotNil(t, handler.connManager)
assert.NotNil(t, handler.subscriptionManager)
assert.NotNil(t, handler.upgrader)
}
func TestHandler_Hooks(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
hooks := handler.Hooks()
assert.NotNil(t, hooks)
assert.IsType(t, &HookRegistry{}, hooks)
}
func TestHandler_Registry(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
registry := handler.Registry()
assert.NotNil(t, registry)
assert.Equal(t, mockRegistry, registry)
}
func TestHandler_GetDatabase(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
db := handler.GetDatabase()
assert.NotNil(t, db)
assert.Equal(t, mockDB, db)
}
func TestHandler_GetConnectionCount(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
count := handler.GetConnectionCount()
assert.Equal(t, 0, count)
}
func TestHandler_GetSubscriptionCount(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
count := handler.GetSubscriptionCount()
assert.Equal(t, 0, count)
}
func TestHandler_GetConnection(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Non-existent connection
_, exists := handler.GetConnection("non-existent")
assert.False(t, exists)
}
func TestHandler_HandleMessage_InvalidType(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
ctx: context.Background(),
}
msg := &Message{
ID: "msg-1",
Type: MessageType("invalid"),
}
handler.HandleMessage(conn, msg)
// Should send error response
select {
case data := <-conn.send:
var response map[string]interface{}
require.NoError(t, ParseMessageBytes(data, &response))
assert.False(t, response["success"].(bool))
default:
t.Fatal("Expected error response")
}
}
func ParseMessageBytes(data []byte, v interface{}) error {
return nil // Simplified for testing
}
func TestHandler_GetOperatorSQL(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
tests := []struct {
operator string
expected string
}{
{"eq", "="},
{"neq", "!="},
{"gt", ">"},
{"gte", ">="},
{"lt", "<"},
{"lte", "<="},
{"like", "LIKE"},
{"ilike", "ILIKE"},
{"in", "IN"},
{"unknown", "="}, // default
}
for _, tt := range tests {
t.Run(tt.operator, func(t *testing.T) {
result := handler.getOperatorSQL(tt.operator)
assert.Equal(t, tt.expected, result)
})
}
}
func TestHandler_GetTableName(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
tests := []struct {
name string
schema string
entity string
expected string
}{
{
name: "With schema",
schema: "public",
entity: "users",
expected: "public.users",
},
{
name: "Without schema",
schema: "",
entity: "users",
expected: "users",
},
{
name: "Different schema",
schema: "custom",
entity: "posts",
expected: "custom.posts",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.getTableName(tt.schema, tt.entity, &TestUser{})
assert.Equal(t, tt.expected, result)
})
}
}
func TestHandler_GetMetadata(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
metadata := handler.getMetadata("public", "users", &TestUser{})
assert.NotNil(t, metadata)
assert.Equal(t, "public", metadata["schema"])
assert.Equal(t, "users", metadata["entity"])
assert.Equal(t, "public.users", metadata["table_name"])
assert.NotNil(t, metadata["columns"])
assert.NotNil(t, metadata["primary_key"])
}
func TestHandler_NotifySubscribers(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Create connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
handler: handler,
}
// Register connection
handler.connManager.connections["conn-1"] = conn
// Create subscription
sub := handler.subscriptionManager.Subscribe("sub-1", "conn-1", "public", "users", nil)
conn.AddSubscription(sub)
// Notify subscribers
data := map[string]interface{}{"id": 1, "name": "John"}
handler.notifySubscribers("public", "users", OperationCreate, data)
// Verify notification was sent
select {
case msg := <-conn.send:
assert.NotEmpty(t, msg)
default:
t.Fatal("Expected notification to be sent")
}
}
func TestHandler_NotifySubscribers_NoSubscribers(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Notify with no subscribers - should not panic
data := map[string]interface{}{"id": 1, "name": "John"}
handler.notifySubscribers("public", "users", OperationCreate, data)
// No assertions needed - just checking it doesn't panic
}
func TestHandler_NotifySubscribers_ConnectionNotFound(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Create subscription without connection
handler.subscriptionManager.Subscribe("sub-1", "conn-1", "public", "users", nil)
// Notify - should handle gracefully when connection not found
data := map[string]interface{}{"id": 1, "name": "John"}
handler.notifySubscribers("public", "users", OperationCreate, data)
// No assertions needed - just checking it doesn't panic
}
func TestHandler_HooksIntegration(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
beforeCalled := false
afterCalled := false
// Register hooks
handler.Hooks().RegisterBefore(OperationCreate, func(ctx *HookContext) error {
beforeCalled = true
return nil
})
handler.Hooks().RegisterAfter(OperationCreate, func(ctx *HookContext) error {
afterCalled = true
return nil
})
// Verify hooks are registered
assert.True(t, handler.Hooks().HasHooks(BeforeCreate))
assert.True(t, handler.Hooks().HasHooks(AfterCreate))
// Execute hooks
ctx := &HookContext{Context: context.Background()}
handler.Hooks().Execute(BeforeCreate, ctx)
handler.Hooks().Execute(AfterCreate, ctx)
assert.True(t, beforeCalled)
assert.True(t, afterCalled)
}
func TestHandler_Shutdown(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Shutdown should not panic
handler.Shutdown()
// Verify context was cancelled
select {
case <-handler.connManager.ctx.Done():
// Expected
default:
t.Fatal("Connection manager context not cancelled after shutdown")
}
}
func TestHandler_SubscriptionLifecycle(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Create connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
ctx: context.Background(),
handler: handler,
}
// Create subscription message
msg := &Message{
ID: "sub-msg-1",
Type: MessageTypeSubscription,
Operation: OperationSubscribe,
Schema: "public",
Entity: "users",
}
// Handle subscribe
handler.handleSubscribe(conn, msg)
// Verify subscription was created
assert.Equal(t, 1, handler.GetSubscriptionCount())
assert.Equal(t, 1, len(conn.subscriptions))
// Verify response was sent
select {
case data := <-conn.send:
assert.NotEmpty(t, data)
default:
t.Fatal("Expected subscription response")
}
}
func TestHandler_UnsubscribeLifecycle(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Create connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
ctx: context.Background(),
handler: handler,
}
// Create subscription
sub := handler.subscriptionManager.Subscribe("sub-1", "conn-1", "public", "users", nil)
conn.AddSubscription(sub)
assert.Equal(t, 1, handler.GetSubscriptionCount())
// Create unsubscribe message
msg := &Message{
ID: "unsub-msg-1",
Type: MessageTypeSubscription,
Operation: OperationUnsubscribe,
SubscriptionID: "sub-1",
}
// Handle unsubscribe
handler.handleUnsubscribe(conn, msg)
// Verify subscription was removed
assert.Equal(t, 0, handler.GetSubscriptionCount())
assert.Equal(t, 0, len(conn.subscriptions))
// Verify response was sent
select {
case data := <-conn.send:
assert.NotEmpty(t, data)
default:
t.Fatal("Expected unsubscribe response")
}
}
func TestHandler_HandlePing(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
}
msg := &Message{
ID: "ping-1",
Type: MessageTypePing,
}
handler.handlePing(conn, msg)
// Verify pong was sent
select {
case data := <-conn.send:
assert.NotEmpty(t, data)
default:
t.Fatal("Expected pong response")
}
}
func TestHandler_CompleteWorkflow(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := NewHandler(mockDB, mockRegistry)
// Setup hooks (these are registered but not called in this test workflow)
handler.Hooks().RegisterBefore(OperationCreate, func(ctx *HookContext) error {
return nil
})
handler.Hooks().RegisterAfter(OperationCreate, func(ctx *HookContext) error {
return nil
})
// Create connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
ctx: context.Background(),
handler: handler,
metadata: make(map[string]interface{}),
}
// Register connection
handler.connManager.connections["conn-1"] = conn
// Set user metadata
conn.SetMetadata("user_id", 123)
// Create subscription
subMsg := &Message{
ID: "sub-1",
Type: MessageTypeSubscription,
Operation: OperationSubscribe,
Schema: "public",
Entity: "users",
}
handler.handleSubscribe(conn, subMsg)
assert.Equal(t, 1, handler.GetSubscriptionCount())
// Clear send channel
select {
case <-conn.send:
default:
}
// Send ping
pingMsg := &Message{
ID: "ping-1",
Type: MessageTypePing,
}
handler.handlePing(conn, pingMsg)
// Verify pong was sent
select {
case <-conn.send:
// Expected
default:
t.Fatal("Expected pong response")
}
// Verify metadata
userID, exists := conn.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 123, userID)
// Verify hooks were registered
assert.True(t, handler.Hooks().HasHooks(BeforeCreate))
assert.True(t, handler.Hooks().HasHooks(AfterCreate))
}

193
pkg/websocketspec/hooks.go Normal file
View File

@@ -0,0 +1,193 @@
package websocketspec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// HookType represents the type of lifecycle hook
type HookType string
const (
// BeforeRead is called before a read operation
BeforeRead HookType = "before_read"
// AfterRead is called after a read operation
AfterRead HookType = "after_read"
// BeforeCreate is called before a create operation
BeforeCreate HookType = "before_create"
// AfterCreate is called after a create operation
AfterCreate HookType = "after_create"
// BeforeUpdate is called before an update operation
BeforeUpdate HookType = "before_update"
// AfterUpdate is called after an update operation
AfterUpdate HookType = "after_update"
// BeforeDelete is called before a delete operation
BeforeDelete HookType = "before_delete"
// AfterDelete is called after a delete operation
AfterDelete HookType = "after_delete"
// BeforeSubscribe is called before creating a subscription
BeforeSubscribe HookType = "before_subscribe"
// AfterSubscribe is called after creating a subscription
AfterSubscribe HookType = "after_subscribe"
// BeforeUnsubscribe is called before removing a subscription
BeforeUnsubscribe HookType = "before_unsubscribe"
// AfterUnsubscribe is called after removing a subscription
AfterUnsubscribe HookType = "after_unsubscribe"
// BeforeConnect is called when a new connection is established
BeforeConnect HookType = "before_connect"
// AfterConnect is called after a connection is established
AfterConnect HookType = "after_connect"
// BeforeDisconnect is called before a connection is closed
BeforeDisconnect HookType = "before_disconnect"
// AfterDisconnect is called after a connection is closed
AfterDisconnect HookType = "after_disconnect"
)
// HookContext contains context information for hook execution
type HookContext struct {
// Context is the request context
Context context.Context
// Handler provides access to the handler, database, and registry
Handler *Handler
// Connection is the WebSocket connection
Connection *Connection
// Message is the original message
Message *Message
// Schema is the database schema
Schema string
// Entity is the table/model name
Entity string
// TableName is the actual database table name
TableName string
// Model is the registered model instance
Model interface{}
// ModelPtr is a pointer to the model for queries
ModelPtr interface{}
// Options contains the parsed request options
Options *common.RequestOptions
// ID is the record ID for single-record operations
ID string
// Data is the request data (for create/update operations)
Data interface{}
// Result is the operation result (for after hooks)
Result interface{}
// Subscription is the subscription being created/removed
Subscription *Subscription
// Error is any error that occurred (for after hooks)
Error error
// Metadata is additional context data
Metadata map[string]interface{}
}
// HookFunc is a function that processes a hook
type HookFunc func(*HookContext) error
// HookRegistry manages lifecycle hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register registers a hook function for a specific hook type
func (hr *HookRegistry) Register(hookType HookType, fn HookFunc) {
hr.hooks[hookType] = append(hr.hooks[hookType], fn)
}
// RegisterBefore registers a hook that runs before an operation
// Convenience method for BeforeRead, BeforeCreate, BeforeUpdate, BeforeDelete
func (hr *HookRegistry) RegisterBefore(operation OperationType, fn HookFunc) {
switch operation {
case OperationRead:
hr.Register(BeforeRead, fn)
case OperationCreate:
hr.Register(BeforeCreate, fn)
case OperationUpdate:
hr.Register(BeforeUpdate, fn)
case OperationDelete:
hr.Register(BeforeDelete, fn)
case OperationSubscribe:
hr.Register(BeforeSubscribe, fn)
case OperationUnsubscribe:
hr.Register(BeforeUnsubscribe, fn)
}
}
// RegisterAfter registers a hook that runs after an operation
// Convenience method for AfterRead, AfterCreate, AfterUpdate, AfterDelete
func (hr *HookRegistry) RegisterAfter(operation OperationType, fn HookFunc) {
switch operation {
case OperationRead:
hr.Register(AfterRead, fn)
case OperationCreate:
hr.Register(AfterCreate, fn)
case OperationUpdate:
hr.Register(AfterUpdate, fn)
case OperationDelete:
hr.Register(AfterDelete, fn)
case OperationSubscribe:
hr.Register(AfterSubscribe, fn)
case OperationUnsubscribe:
hr.Register(AfterUnsubscribe, fn)
}
}
// Execute runs all hooks for a specific type
func (hr *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := hr.hooks[hookType]
if !exists {
return nil
}
for _, hook := range hooks {
if err := hook(ctx); err != nil {
return err
}
}
return nil
}
// HasHooks checks if any hooks are registered for a hook type
func (hr *HookRegistry) HasHooks(hookType HookType) bool {
hooks, exists := hr.hooks[hookType]
return exists && len(hooks) > 0
}
// Clear removes all hooks of a specific type
func (hr *HookRegistry) Clear(hookType HookType) {
delete(hr.hooks, hookType)
}
// ClearAll removes all registered hooks
func (hr *HookRegistry) ClearAll() {
hr.hooks = make(map[HookType][]HookFunc)
}

View File

@@ -0,0 +1,547 @@
package websocketspec
import (
"context"
"errors"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHookType_Constants(t *testing.T) {
assert.Equal(t, HookType("before_read"), BeforeRead)
assert.Equal(t, HookType("after_read"), AfterRead)
assert.Equal(t, HookType("before_create"), BeforeCreate)
assert.Equal(t, HookType("after_create"), AfterCreate)
assert.Equal(t, HookType("before_update"), BeforeUpdate)
assert.Equal(t, HookType("after_update"), AfterUpdate)
assert.Equal(t, HookType("before_delete"), BeforeDelete)
assert.Equal(t, HookType("after_delete"), AfterDelete)
assert.Equal(t, HookType("before_subscribe"), BeforeSubscribe)
assert.Equal(t, HookType("after_subscribe"), AfterSubscribe)
assert.Equal(t, HookType("before_unsubscribe"), BeforeUnsubscribe)
assert.Equal(t, HookType("after_unsubscribe"), AfterUnsubscribe)
assert.Equal(t, HookType("before_connect"), BeforeConnect)
assert.Equal(t, HookType("after_connect"), AfterConnect)
assert.Equal(t, HookType("before_disconnect"), BeforeDisconnect)
assert.Equal(t, HookType("after_disconnect"), AfterDisconnect)
}
func TestNewHookRegistry(t *testing.T) {
hr := NewHookRegistry()
assert.NotNil(t, hr)
assert.NotNil(t, hr.hooks)
assert.Empty(t, hr.hooks)
}
func TestHookRegistry_Register(t *testing.T) {
hr := NewHookRegistry()
hookCalled := false
hook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
hr.Register(BeforeRead, hook)
// Verify hook was registered
assert.True(t, hr.HasHooks(BeforeRead))
// Execute hook
ctx := &HookContext{Context: context.Background()}
err := hr.Execute(BeforeRead, ctx)
require.NoError(t, err)
assert.True(t, hookCalled)
}
func TestHookRegistry_Register_MultipleHooks(t *testing.T) {
hr := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
hr.Register(BeforeRead, hook1)
hr.Register(BeforeRead, hook2)
hr.Register(BeforeRead, hook3)
// Execute hooks
ctx := &HookContext{Context: context.Background()}
err := hr.Execute(BeforeRead, ctx)
require.NoError(t, err)
// Verify hooks were called in order
assert.Equal(t, []int{1, 2, 3}, callOrder)
}
func TestHookRegistry_RegisterBefore(t *testing.T) {
hr := NewHookRegistry()
tests := []struct {
operation OperationType
hookType HookType
}{
{OperationRead, BeforeRead},
{OperationCreate, BeforeCreate},
{OperationUpdate, BeforeUpdate},
{OperationDelete, BeforeDelete},
{OperationSubscribe, BeforeSubscribe},
{OperationUnsubscribe, BeforeUnsubscribe},
}
for _, tt := range tests {
t.Run(string(tt.operation), func(t *testing.T) {
hookCalled := false
hook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
hr.RegisterBefore(tt.operation, hook)
assert.True(t, hr.HasHooks(tt.hookType))
ctx := &HookContext{Context: context.Background()}
err := hr.Execute(tt.hookType, ctx)
require.NoError(t, err)
assert.True(t, hookCalled)
// Clean up for next test
hr.Clear(tt.hookType)
})
}
}
func TestHookRegistry_RegisterAfter(t *testing.T) {
hr := NewHookRegistry()
tests := []struct {
operation OperationType
hookType HookType
}{
{OperationRead, AfterRead},
{OperationCreate, AfterCreate},
{OperationUpdate, AfterUpdate},
{OperationDelete, AfterDelete},
{OperationSubscribe, AfterSubscribe},
{OperationUnsubscribe, AfterUnsubscribe},
}
for _, tt := range tests {
t.Run(string(tt.operation), func(t *testing.T) {
hookCalled := false
hook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
hr.RegisterAfter(tt.operation, hook)
assert.True(t, hr.HasHooks(tt.hookType))
ctx := &HookContext{Context: context.Background()}
err := hr.Execute(tt.hookType, ctx)
require.NoError(t, err)
assert.True(t, hookCalled)
// Clean up for next test
hr.Clear(tt.hookType)
})
}
}
func TestHookRegistry_Execute_NoHooks(t *testing.T) {
hr := NewHookRegistry()
ctx := &HookContext{Context: context.Background()}
err := hr.Execute(BeforeRead, ctx)
// Should not error when no hooks registered
assert.NoError(t, err)
}
func TestHookRegistry_Execute_HookReturnsError(t *testing.T) {
hr := NewHookRegistry()
expectedErr := errors.New("hook error")
hook := func(ctx *HookContext) error {
return expectedErr
}
hr.Register(BeforeRead, hook)
ctx := &HookContext{Context: context.Background()}
err := hr.Execute(BeforeRead, ctx)
assert.Error(t, err)
assert.Equal(t, expectedErr, err)
}
func TestHookRegistry_Execute_FirstHookErrors(t *testing.T) {
hr := NewHookRegistry()
hook1Called := false
hook2Called := false
hook1 := func(ctx *HookContext) error {
hook1Called = true
return errors.New("hook1 error")
}
hook2 := func(ctx *HookContext) error {
hook2Called = true
return nil
}
hr.Register(BeforeRead, hook1)
hr.Register(BeforeRead, hook2)
ctx := &HookContext{Context: context.Background()}
err := hr.Execute(BeforeRead, ctx)
assert.Error(t, err)
assert.True(t, hook1Called)
assert.False(t, hook2Called) // Should not be called after first error
}
func TestHookRegistry_HasHooks(t *testing.T) {
hr := NewHookRegistry()
assert.False(t, hr.HasHooks(BeforeRead))
hr.Register(BeforeRead, func(ctx *HookContext) error { return nil })
assert.True(t, hr.HasHooks(BeforeRead))
assert.False(t, hr.HasHooks(AfterRead))
}
func TestHookRegistry_Clear(t *testing.T) {
hr := NewHookRegistry()
hr.Register(BeforeRead, func(ctx *HookContext) error { return nil })
hr.Register(BeforeRead, func(ctx *HookContext) error { return nil })
assert.True(t, hr.HasHooks(BeforeRead))
hr.Clear(BeforeRead)
assert.False(t, hr.HasHooks(BeforeRead))
}
func TestHookRegistry_ClearAll(t *testing.T) {
hr := NewHookRegistry()
hr.Register(BeforeRead, func(ctx *HookContext) error { return nil })
hr.Register(AfterRead, func(ctx *HookContext) error { return nil })
hr.Register(BeforeCreate, func(ctx *HookContext) error { return nil })
assert.True(t, hr.HasHooks(BeforeRead))
assert.True(t, hr.HasHooks(AfterRead))
assert.True(t, hr.HasHooks(BeforeCreate))
hr.ClearAll()
assert.False(t, hr.HasHooks(BeforeRead))
assert.False(t, hr.HasHooks(AfterRead))
assert.False(t, hr.HasHooks(BeforeCreate))
}
func TestHookContext_Structure(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
Schema: "public",
Entity: "users",
TableName: "public.users",
ID: "123",
Data: map[string]interface{}{
"name": "John",
},
Options: &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
},
Metadata: map[string]interface{}{
"user_id": 456,
},
}
assert.NotNil(t, ctx.Context)
assert.Equal(t, "public", ctx.Schema)
assert.Equal(t, "users", ctx.Entity)
assert.Equal(t, "public.users", ctx.TableName)
assert.Equal(t, "123", ctx.ID)
assert.NotNil(t, ctx.Data)
assert.NotNil(t, ctx.Options)
assert.NotNil(t, ctx.Metadata)
}
func TestHookContext_ModifyData(t *testing.T) {
hr := NewHookRegistry()
// Hook that modifies data
hook := func(ctx *HookContext) error {
if data, ok := ctx.Data.(map[string]interface{}); ok {
data["modified"] = true
}
return nil
}
hr.Register(BeforeCreate, hook)
ctx := &HookContext{
Context: context.Background(),
Data: map[string]interface{}{
"name": "John",
},
}
err := hr.Execute(BeforeCreate, ctx)
require.NoError(t, err)
// Verify data was modified
data := ctx.Data.(map[string]interface{})
assert.True(t, data["modified"].(bool))
}
func TestHookContext_ModifyOptions(t *testing.T) {
hr := NewHookRegistry()
// Hook that adds a filter
hook := func(ctx *HookContext) error {
if ctx.Options == nil {
ctx.Options = &common.RequestOptions{}
}
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
Column: "user_id",
Operator: "eq",
Value: 123,
})
return nil
}
hr.Register(BeforeRead, hook)
ctx := &HookContext{
Context: context.Background(),
Options: &common.RequestOptions{},
}
err := hr.Execute(BeforeRead, ctx)
require.NoError(t, err)
// Verify filter was added
assert.Len(t, ctx.Options.Filters, 1)
assert.Equal(t, "user_id", ctx.Options.Filters[0].Column)
}
func TestHookContext_UseMetadata(t *testing.T) {
hr := NewHookRegistry()
// Hook that stores data in metadata
hook := func(ctx *HookContext) error {
ctx.Metadata["processed"] = true
ctx.Metadata["timestamp"] = "2024-01-01"
return nil
}
hr.Register(BeforeCreate, hook)
ctx := &HookContext{
Context: context.Background(),
Metadata: make(map[string]interface{}),
}
err := hr.Execute(BeforeCreate, ctx)
require.NoError(t, err)
// Verify metadata was set
assert.True(t, ctx.Metadata["processed"].(bool))
assert.Equal(t, "2024-01-01", ctx.Metadata["timestamp"])
}
func TestHookRegistry_Authentication_Example(t *testing.T) {
hr := NewHookRegistry()
// Authentication hook
authHook := func(ctx *HookContext) error {
// Simulate getting user from connection metadata
userID := 123
ctx.Metadata["user_id"] = userID
return nil
}
// Authorization hook that uses auth data
authzHook := func(ctx *HookContext) error {
userID, ok := ctx.Metadata["user_id"]
if !ok {
return errors.New("unauthorized: not authenticated")
}
// Add filter to only show user's own records
if ctx.Options == nil {
ctx.Options = &common.RequestOptions{}
}
ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
Column: "user_id",
Operator: "eq",
Value: userID,
})
return nil
}
hr.Register(BeforeConnect, authHook)
hr.Register(BeforeRead, authzHook)
// Simulate connection
ctx1 := &HookContext{
Context: context.Background(),
Metadata: make(map[string]interface{}),
}
err := hr.Execute(BeforeConnect, ctx1)
require.NoError(t, err)
assert.Equal(t, 123, ctx1.Metadata["user_id"])
// Simulate read with authorization
ctx2 := &HookContext{
Context: context.Background(),
Metadata: map[string]interface{}{"user_id": 123},
Options: &common.RequestOptions{},
}
err = hr.Execute(BeforeRead, ctx2)
require.NoError(t, err)
assert.Len(t, ctx2.Options.Filters, 1)
assert.Equal(t, "user_id", ctx2.Options.Filters[0].Column)
}
func TestHookRegistry_Validation_Example(t *testing.T) {
hr := NewHookRegistry()
// Validation hook
validationHook := func(ctx *HookContext) error {
data, ok := ctx.Data.(map[string]interface{})
if !ok {
return errors.New("invalid data format")
}
if ctx.Entity == "users" {
email, hasEmail := data["email"]
if !hasEmail || email == "" {
return errors.New("validation error: email is required")
}
name, hasName := data["name"]
if !hasName || name == "" {
return errors.New("validation error: name is required")
}
}
return nil
}
hr.Register(BeforeCreate, validationHook)
// Test with valid data
ctx1 := &HookContext{
Context: context.Background(),
Entity: "users",
Data: map[string]interface{}{
"name": "John Doe",
"email": "john@example.com",
},
}
err := hr.Execute(BeforeCreate, ctx1)
assert.NoError(t, err)
// Test with missing email
ctx2 := &HookContext{
Context: context.Background(),
Entity: "users",
Data: map[string]interface{}{
"name": "John Doe",
},
}
err = hr.Execute(BeforeCreate, ctx2)
assert.Error(t, err)
assert.Contains(t, err.Error(), "email is required")
// Test with missing name
ctx3 := &HookContext{
Context: context.Background(),
Entity: "users",
Data: map[string]interface{}{
"email": "john@example.com",
},
}
err = hr.Execute(BeforeCreate, ctx3)
assert.Error(t, err)
assert.Contains(t, err.Error(), "name is required")
}
func TestHookRegistry_Logging_Example(t *testing.T) {
hr := NewHookRegistry()
logEntries := []string{}
// Logging hook for create operations
loggingHook := func(ctx *HookContext) error {
logEntries = append(logEntries, "Created record in "+ctx.Entity)
return nil
}
hr.Register(AfterCreate, loggingHook)
ctx := &HookContext{
Context: context.Background(),
Entity: "users",
Result: map[string]interface{}{"id": 1, "name": "John"},
}
err := hr.Execute(AfterCreate, ctx)
require.NoError(t, err)
assert.Len(t, logEntries, 1)
assert.Equal(t, "Created record in users", logEntries[0])
}
func TestHookRegistry_ConcurrentExecution(t *testing.T) {
hr := NewHookRegistry()
// This test verifies that concurrent hook executions don't cause race conditions
// Run with: go test -race
counter := 0
hook := func(ctx *HookContext) error {
counter++
return nil
}
hr.Register(BeforeRead, hook)
done := make(chan bool)
// Execute hooks concurrently
for i := 0; i < 10; i++ {
go func() {
ctx := &HookContext{Context: context.Background()}
hr.Execute(BeforeRead, ctx)
done <- true
}()
}
// Wait for all executions
for i := 0; i < 10; i++ {
<-done
}
assert.Equal(t, 10, counter)
}

View File

@@ -0,0 +1,240 @@
package websocketspec
import (
"encoding/json"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// MessageType represents the type of WebSocket message
type MessageType string
const (
// MessageTypeRequest is a client request message
MessageTypeRequest MessageType = "request"
// MessageTypeResponse is a server response message
MessageTypeResponse MessageType = "response"
// MessageTypeNotification is a server-initiated notification
MessageTypeNotification MessageType = "notification"
// MessageTypeSubscription is a subscription control message
MessageTypeSubscription MessageType = "subscription"
// MessageTypeError is an error message
MessageTypeError MessageType = "error"
// MessageTypePing is a keepalive ping message
MessageTypePing MessageType = "ping"
// MessageTypePong is a keepalive pong response
MessageTypePong MessageType = "pong"
)
// OperationType represents the operation to perform
type OperationType string
const (
// OperationRead retrieves records
OperationRead OperationType = "read"
// OperationCreate creates a new record
OperationCreate OperationType = "create"
// OperationUpdate updates an existing record
OperationUpdate OperationType = "update"
// OperationDelete deletes a record
OperationDelete OperationType = "delete"
// OperationSubscribe subscribes to entity changes
OperationSubscribe OperationType = "subscribe"
// OperationUnsubscribe unsubscribes from entity changes
OperationUnsubscribe OperationType = "unsubscribe"
// OperationMeta retrieves metadata about an entity
OperationMeta OperationType = "meta"
)
// Message represents a WebSocket message
type Message struct {
// ID is a unique identifier for request/response correlation
ID string `json:"id,omitempty"`
// Type is the message type
Type MessageType `json:"type"`
// Operation is the operation to perform
Operation OperationType `json:"operation,omitempty"`
// Schema is the database schema name
Schema string `json:"schema,omitempty"`
// Entity is the table/model name
Entity string `json:"entity,omitempty"`
// RecordID is the ID for single-record operations (update, delete, read by ID)
RecordID string `json:"record_id,omitempty"`
// Data contains the request/response payload
Data interface{} `json:"data,omitempty"`
// Options contains query options (filters, sorting, pagination, etc.)
Options *common.RequestOptions `json:"options,omitempty"`
// SubscriptionID is the subscription identifier
SubscriptionID string `json:"subscription_id,omitempty"`
// Success indicates if the operation was successful
Success bool `json:"success,omitempty"`
// Error contains error information
Error *ErrorInfo `json:"error,omitempty"`
// Metadata contains additional response metadata
Metadata map[string]interface{} `json:"metadata,omitempty"`
// Timestamp is when the message was created
Timestamp time.Time `json:"timestamp,omitempty"`
}
// ErrorInfo contains error details
type ErrorInfo struct {
// Code is the error code
Code string `json:"code"`
// Message is a human-readable error message
Message string `json:"message"`
// Details contains additional error context
Details map[string]interface{} `json:"details,omitempty"`
}
// RequestMessage represents a client request
type RequestMessage struct {
ID string `json:"id"`
Type MessageType `json:"type"`
Operation OperationType `json:"operation"`
Schema string `json:"schema,omitempty"`
Entity string `json:"entity"`
RecordID string `json:"record_id,omitempty"`
Data interface{} `json:"data,omitempty"`
Options *common.RequestOptions `json:"options,omitempty"`
}
// ResponseMessage represents a server response
type ResponseMessage struct {
ID string `json:"id"`
Type MessageType `json:"type"`
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error *ErrorInfo `json:"error,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
// NotificationMessage represents a server-initiated notification
type NotificationMessage struct {
Type MessageType `json:"type"`
Operation OperationType `json:"operation"`
SubscriptionID string `json:"subscription_id"`
Schema string `json:"schema"`
Entity string `json:"entity"`
Data interface{} `json:"data"`
Timestamp time.Time `json:"timestamp"`
}
// SubscriptionMessage represents a subscription control message
type SubscriptionMessage struct {
ID string `json:"id"`
Type MessageType `json:"type"`
Operation OperationType `json:"operation"` // subscribe or unsubscribe
Schema string `json:"schema,omitempty"`
Entity string `json:"entity"`
Options *common.RequestOptions `json:"options,omitempty"` // Filters for subscription
SubscriptionID string `json:"subscription_id,omitempty"` // For unsubscribe
}
// NewRequestMessage creates a new request message
func NewRequestMessage(id string, operation OperationType, schema, entity string) *RequestMessage {
return &RequestMessage{
ID: id,
Type: MessageTypeRequest,
Operation: operation,
Schema: schema,
Entity: entity,
}
}
// NewResponseMessage creates a new response message
func NewResponseMessage(id string, success bool, data interface{}) *ResponseMessage {
return &ResponseMessage{
ID: id,
Type: MessageTypeResponse,
Success: success,
Data: data,
Timestamp: time.Now(),
}
}
// NewErrorResponse creates an error response message
func NewErrorResponse(id string, code, message string) *ResponseMessage {
return &ResponseMessage{
ID: id,
Type: MessageTypeResponse,
Success: false,
Error: &ErrorInfo{
Code: code,
Message: message,
},
Timestamp: time.Now(),
}
}
// NewNotificationMessage creates a new notification message
func NewNotificationMessage(subscriptionID string, operation OperationType, schema, entity string, data interface{}) *NotificationMessage {
return &NotificationMessage{
Type: MessageTypeNotification,
Operation: operation,
SubscriptionID: subscriptionID,
Schema: schema,
Entity: entity,
Data: data,
Timestamp: time.Now(),
}
}
// ParseMessage parses a JSON message into a Message struct
func ParseMessage(data []byte) (*Message, error) {
var msg Message
if err := json.Unmarshal(data, &msg); err != nil {
return nil, err
}
return &msg, nil
}
// ToJSON converts a message to JSON bytes
func (m *Message) ToJSON() ([]byte, error) {
return json.Marshal(m)
}
// ToJSON converts a response message to JSON bytes
func (r *ResponseMessage) ToJSON() ([]byte, error) {
return json.Marshal(r)
}
// ToJSON converts a notification message to JSON bytes
func (n *NotificationMessage) ToJSON() ([]byte, error) {
return json.Marshal(n)
}
// IsValid checks if a message is valid
func (m *Message) IsValid() bool {
// Type must be set
if m.Type == "" {
return false
}
// Request messages must have an ID, operation, and entity
if m.Type == MessageTypeRequest {
return m.ID != "" && m.Operation != "" && m.Entity != ""
}
// Subscription messages must have an ID and operation
if m.Type == MessageTypeSubscription {
return m.ID != "" && m.Operation != ""
}
return true
}

View File

@@ -0,0 +1,414 @@
package websocketspec
import (
"encoding/json"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMessageType_Constants(t *testing.T) {
assert.Equal(t, MessageType("request"), MessageTypeRequest)
assert.Equal(t, MessageType("response"), MessageTypeResponse)
assert.Equal(t, MessageType("notification"), MessageTypeNotification)
assert.Equal(t, MessageType("subscription"), MessageTypeSubscription)
assert.Equal(t, MessageType("error"), MessageTypeError)
assert.Equal(t, MessageType("ping"), MessageTypePing)
assert.Equal(t, MessageType("pong"), MessageTypePong)
}
func TestOperationType_Constants(t *testing.T) {
assert.Equal(t, OperationType("read"), OperationRead)
assert.Equal(t, OperationType("create"), OperationCreate)
assert.Equal(t, OperationType("update"), OperationUpdate)
assert.Equal(t, OperationType("delete"), OperationDelete)
assert.Equal(t, OperationType("subscribe"), OperationSubscribe)
assert.Equal(t, OperationType("unsubscribe"), OperationUnsubscribe)
assert.Equal(t, OperationType("meta"), OperationMeta)
}
func TestParseMessage_ValidRequestMessage(t *testing.T) {
jsonData := `{
"id": "msg-1",
"type": "request",
"operation": "read",
"schema": "public",
"entity": "users",
"record_id": "123",
"options": {
"filters": [
{"column": "status", "operator": "eq", "value": "active"}
],
"limit": 10
}
}`
msg, err := ParseMessage([]byte(jsonData))
require.NoError(t, err)
assert.NotNil(t, msg)
assert.Equal(t, "msg-1", msg.ID)
assert.Equal(t, MessageTypeRequest, msg.Type)
assert.Equal(t, OperationRead, msg.Operation)
assert.Equal(t, "public", msg.Schema)
assert.Equal(t, "users", msg.Entity)
assert.Equal(t, "123", msg.RecordID)
assert.NotNil(t, msg.Options)
assert.Equal(t, 10, *msg.Options.Limit)
}
func TestParseMessage_ValidSubscriptionMessage(t *testing.T) {
jsonData := `{
"id": "sub-1",
"type": "subscription",
"operation": "subscribe",
"schema": "public",
"entity": "users"
}`
msg, err := ParseMessage([]byte(jsonData))
require.NoError(t, err)
assert.NotNil(t, msg)
assert.Equal(t, "sub-1", msg.ID)
assert.Equal(t, MessageTypeSubscription, msg.Type)
assert.Equal(t, OperationSubscribe, msg.Operation)
assert.Equal(t, "public", msg.Schema)
assert.Equal(t, "users", msg.Entity)
}
func TestParseMessage_InvalidJSON(t *testing.T) {
jsonData := `{invalid json}`
msg, err := ParseMessage([]byte(jsonData))
assert.Error(t, err)
assert.Nil(t, msg)
}
func TestParseMessage_EmptyData(t *testing.T) {
msg, err := ParseMessage([]byte{})
assert.Error(t, err)
assert.Nil(t, msg)
}
func TestMessage_IsValid_ValidRequestMessage(t *testing.T) {
msg := &Message{
ID: "msg-1",
Type: MessageTypeRequest,
Operation: OperationRead,
Entity: "users",
}
assert.True(t, msg.IsValid())
}
func TestMessage_IsValid_InvalidRequestMessage_NoID(t *testing.T) {
msg := &Message{
Type: MessageTypeRequest,
Operation: OperationRead,
Entity: "users",
}
assert.False(t, msg.IsValid())
}
func TestMessage_IsValid_InvalidRequestMessage_NoOperation(t *testing.T) {
msg := &Message{
ID: "msg-1",
Type: MessageTypeRequest,
Entity: "users",
}
assert.False(t, msg.IsValid())
}
func TestMessage_IsValid_InvalidRequestMessage_NoEntity(t *testing.T) {
msg := &Message{
ID: "msg-1",
Type: MessageTypeRequest,
Operation: OperationRead,
}
assert.False(t, msg.IsValid())
}
func TestMessage_IsValid_ValidSubscriptionMessage(t *testing.T) {
msg := &Message{
ID: "sub-1",
Type: MessageTypeSubscription,
Operation: OperationSubscribe,
}
assert.True(t, msg.IsValid())
}
func TestMessage_IsValid_InvalidSubscriptionMessage_NoID(t *testing.T) {
msg := &Message{
Type: MessageTypeSubscription,
Operation: OperationSubscribe,
}
assert.False(t, msg.IsValid())
}
func TestMessage_IsValid_InvalidSubscriptionMessage_NoOperation(t *testing.T) {
msg := &Message{
ID: "sub-1",
Type: MessageTypeSubscription,
}
assert.False(t, msg.IsValid())
}
func TestMessage_IsValid_NoType(t *testing.T) {
msg := &Message{
ID: "msg-1",
}
assert.False(t, msg.IsValid())
}
func TestMessage_IsValid_ResponseMessage(t *testing.T) {
msg := &Message{
Type: MessageTypeResponse,
}
// Response messages don't require specific fields
assert.True(t, msg.IsValid())
}
func TestMessage_IsValid_NotificationMessage(t *testing.T) {
msg := &Message{
Type: MessageTypeNotification,
}
// Notification messages don't require specific fields
assert.True(t, msg.IsValid())
}
func TestMessage_ToJSON(t *testing.T) {
msg := &Message{
ID: "msg-1",
Type: MessageTypeRequest,
Operation: OperationRead,
Entity: "users",
}
jsonData, err := msg.ToJSON()
require.NoError(t, err)
assert.NotEmpty(t, jsonData)
// Parse back to verify
var parsed map[string]interface{}
err = json.Unmarshal(jsonData, &parsed)
require.NoError(t, err)
assert.Equal(t, "msg-1", parsed["id"])
assert.Equal(t, "request", parsed["type"])
assert.Equal(t, "read", parsed["operation"])
assert.Equal(t, "users", parsed["entity"])
}
func TestNewRequestMessage(t *testing.T) {
msg := NewRequestMessage("msg-1", OperationRead, "public", "users")
assert.Equal(t, "msg-1", msg.ID)
assert.Equal(t, MessageTypeRequest, msg.Type)
assert.Equal(t, OperationRead, msg.Operation)
assert.Equal(t, "public", msg.Schema)
assert.Equal(t, "users", msg.Entity)
}
func TestNewResponseMessage(t *testing.T) {
data := map[string]interface{}{"id": 1, "name": "John"}
msg := NewResponseMessage("msg-1", true, data)
assert.Equal(t, "msg-1", msg.ID)
assert.Equal(t, MessageTypeResponse, msg.Type)
assert.True(t, msg.Success)
assert.Equal(t, data, msg.Data)
assert.False(t, msg.Timestamp.IsZero())
}
func TestNewErrorResponse(t *testing.T) {
msg := NewErrorResponse("msg-1", "validation_error", "Email is required")
assert.Equal(t, "msg-1", msg.ID)
assert.Equal(t, MessageTypeResponse, msg.Type)
assert.False(t, msg.Success)
assert.Nil(t, msg.Data)
assert.NotNil(t, msg.Error)
assert.Equal(t, "validation_error", msg.Error.Code)
assert.Equal(t, "Email is required", msg.Error.Message)
assert.False(t, msg.Timestamp.IsZero())
}
func TestNewNotificationMessage(t *testing.T) {
data := map[string]interface{}{"id": 1, "name": "John"}
msg := NewNotificationMessage("sub-123", OperationCreate, "public", "users", data)
assert.Equal(t, MessageTypeNotification, msg.Type)
assert.Equal(t, OperationCreate, msg.Operation)
assert.Equal(t, "sub-123", msg.SubscriptionID)
assert.Equal(t, "public", msg.Schema)
assert.Equal(t, "users", msg.Entity)
assert.Equal(t, data, msg.Data)
assert.False(t, msg.Timestamp.IsZero())
}
func TestResponseMessage_ToJSON(t *testing.T) {
resp := NewResponseMessage("msg-1", true, map[string]interface{}{"test": "data"})
jsonData, err := resp.ToJSON()
require.NoError(t, err)
assert.NotEmpty(t, jsonData)
// Verify JSON structure
var parsed map[string]interface{}
err = json.Unmarshal(jsonData, &parsed)
require.NoError(t, err)
assert.Equal(t, "msg-1", parsed["id"])
assert.Equal(t, "response", parsed["type"])
assert.True(t, parsed["success"].(bool))
}
func TestNotificationMessage_ToJSON(t *testing.T) {
notif := NewNotificationMessage("sub-123", OperationUpdate, "public", "users", map[string]interface{}{"id": 1})
jsonData, err := notif.ToJSON()
require.NoError(t, err)
assert.NotEmpty(t, jsonData)
// Verify JSON structure
var parsed map[string]interface{}
err = json.Unmarshal(jsonData, &parsed)
require.NoError(t, err)
assert.Equal(t, "notification", parsed["type"])
assert.Equal(t, "update", parsed["operation"])
assert.Equal(t, "sub-123", parsed["subscription_id"])
}
func TestErrorInfo_Structure(t *testing.T) {
err := &ErrorInfo{
Code: "validation_error",
Message: "Invalid input",
Details: map[string]interface{}{
"field": "email",
"value": "invalid",
},
}
assert.Equal(t, "validation_error", err.Code)
assert.Equal(t, "Invalid input", err.Message)
assert.NotNil(t, err.Details)
assert.Equal(t, "email", err.Details["field"])
}
func TestMessage_WithOptions(t *testing.T) {
limit := 10
offset := 0
msg := &Message{
ID: "msg-1",
Type: MessageTypeRequest,
Operation: OperationRead,
Entity: "users",
Options: &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
Columns: []string{"id", "name", "email"},
Sort: []common.SortOption{
{Column: "name", Direction: "asc"},
},
Limit: &limit,
Offset: &offset,
},
}
assert.True(t, msg.IsValid())
assert.NotNil(t, msg.Options)
assert.Len(t, msg.Options.Filters, 1)
assert.Equal(t, "status", msg.Options.Filters[0].Column)
assert.Len(t, msg.Options.Columns, 3)
assert.Len(t, msg.Options.Sort, 1)
assert.Equal(t, 10, *msg.Options.Limit)
}
func TestMessage_CompleteRequestFlow(t *testing.T) {
// Create a request message
req := NewRequestMessage("msg-123", OperationCreate, "public", "users")
req.Data = map[string]interface{}{
"name": "John Doe",
"email": "john@example.com",
"status": "active",
}
// Convert to JSON
reqJSON, err := json.Marshal(req)
require.NoError(t, err)
// Parse back
parsed, err := ParseMessage(reqJSON)
require.NoError(t, err)
assert.True(t, parsed.IsValid())
assert.Equal(t, "msg-123", parsed.ID)
assert.Equal(t, MessageTypeRequest, parsed.Type)
assert.Equal(t, OperationCreate, parsed.Operation)
// Create success response
resp := NewResponseMessage("msg-123", true, map[string]interface{}{
"id": 1,
"name": "John Doe",
"email": "john@example.com",
"status": "active",
})
respJSON, err := resp.ToJSON()
require.NoError(t, err)
assert.NotEmpty(t, respJSON)
}
func TestMessage_TimestampSerialization(t *testing.T) {
now := time.Now()
msg := &Message{
ID: "msg-1",
Type: MessageTypeResponse,
Timestamp: now,
}
jsonData, err := msg.ToJSON()
require.NoError(t, err)
// Parse back
parsed, err := ParseMessage(jsonData)
require.NoError(t, err)
// Timestamps should be approximately equal (within a second due to serialization)
assert.WithinDuration(t, now, parsed.Timestamp, time.Second)
}
func TestMessage_WithMetadata(t *testing.T) {
msg := &Message{
ID: "msg-1",
Type: MessageTypeResponse,
Success: true,
Data: []interface{}{},
Metadata: map[string]interface{}{
"total": 100,
"count": 10,
"page": 1,
},
}
jsonData, err := msg.ToJSON()
require.NoError(t, err)
parsed, err := ParseMessage(jsonData)
require.NoError(t, err)
assert.NotNil(t, parsed.Metadata)
assert.Equal(t, float64(100), parsed.Metadata["total"]) // JSON numbers are float64
assert.Equal(t, float64(10), parsed.Metadata["count"])
}

View File

@@ -0,0 +1,192 @@
package websocketspec
import (
"sync"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Subscription represents a subscription to entity changes
type Subscription struct {
// ID is the unique subscription identifier
ID string
// ConnectionID is the ID of the connection that owns this subscription
ConnectionID string
// Schema is the database schema
Schema string
// Entity is the table/model name
Entity string
// Options contains filters and other query options
Options *common.RequestOptions
// Active indicates if the subscription is active
Active bool
}
// SubscriptionManager manages all subscriptions
type SubscriptionManager struct {
// subscriptions maps subscription ID to subscription
subscriptions map[string]*Subscription
// entitySubscriptions maps "schema.entity" to list of subscription IDs
entitySubscriptions map[string][]string
// mu protects the maps
mu sync.RWMutex
}
// NewSubscriptionManager creates a new subscription manager
func NewSubscriptionManager() *SubscriptionManager {
return &SubscriptionManager{
subscriptions: make(map[string]*Subscription),
entitySubscriptions: make(map[string][]string),
}
}
// Subscribe creates a new subscription
func (sm *SubscriptionManager) Subscribe(id, connID, schema, entity string, options *common.RequestOptions) *Subscription {
sm.mu.Lock()
defer sm.mu.Unlock()
sub := &Subscription{
ID: id,
ConnectionID: connID,
Schema: schema,
Entity: entity,
Options: options,
Active: true,
}
// Store subscription
sm.subscriptions[id] = sub
// Index by entity
key := makeEntityKey(schema, entity)
sm.entitySubscriptions[key] = append(sm.entitySubscriptions[key], id)
logger.Info("[WebSocketSpec] Subscription created: %s for %s.%s (conn: %s)", id, schema, entity, connID)
return sub
}
// Unsubscribe removes a subscription
func (sm *SubscriptionManager) Unsubscribe(subID string) bool {
sm.mu.Lock()
defer sm.mu.Unlock()
sub, exists := sm.subscriptions[subID]
if !exists {
return false
}
// Remove from entity index
key := makeEntityKey(sub.Schema, sub.Entity)
if subs, ok := sm.entitySubscriptions[key]; ok {
newSubs := make([]string, 0, len(subs)-1)
for _, id := range subs {
if id != subID {
newSubs = append(newSubs, id)
}
}
if len(newSubs) > 0 {
sm.entitySubscriptions[key] = newSubs
} else {
delete(sm.entitySubscriptions, key)
}
}
// Remove subscription
delete(sm.subscriptions, subID)
logger.Info("[WebSocketSpec] Subscription removed: %s", subID)
return true
}
// GetSubscription retrieves a subscription by ID
func (sm *SubscriptionManager) GetSubscription(subID string) (*Subscription, bool) {
sm.mu.RLock()
defer sm.mu.RUnlock()
sub, ok := sm.subscriptions[subID]
return sub, ok
}
// GetSubscriptionsByEntity retrieves all subscriptions for an entity
func (sm *SubscriptionManager) GetSubscriptionsByEntity(schema, entity string) []*Subscription {
sm.mu.RLock()
defer sm.mu.RUnlock()
key := makeEntityKey(schema, entity)
subIDs, ok := sm.entitySubscriptions[key]
if !ok {
return nil
}
result := make([]*Subscription, 0, len(subIDs))
for _, subID := range subIDs {
if sub, ok := sm.subscriptions[subID]; ok && sub.Active {
result = append(result, sub)
}
}
return result
}
// GetSubscriptionsByConnection retrieves all subscriptions for a connection
func (sm *SubscriptionManager) GetSubscriptionsByConnection(connID string) []*Subscription {
sm.mu.RLock()
defer sm.mu.RUnlock()
result := make([]*Subscription, 0)
for _, sub := range sm.subscriptions {
if sub.ConnectionID == connID && sub.Active {
result = append(result, sub)
}
}
return result
}
// Count returns the total number of active subscriptions
func (sm *SubscriptionManager) Count() int {
sm.mu.RLock()
defer sm.mu.RUnlock()
return len(sm.subscriptions)
}
// CountForEntity returns the number of subscriptions for a specific entity
func (sm *SubscriptionManager) CountForEntity(schema, entity string) int {
sm.mu.RLock()
defer sm.mu.RUnlock()
key := makeEntityKey(schema, entity)
return len(sm.entitySubscriptions[key])
}
// MatchesFilters checks if data matches the subscription's filters
func (s *Subscription) MatchesFilters(data interface{}) bool {
// If no filters, match everything
if s.Options == nil || len(s.Options.Filters) == 0 {
return true
}
// TODO: Implement filter matching logic
// For now, return true (send all notifications)
// In a full implementation, you would:
// 1. Convert data to a map
// 2. Evaluate each filter against the data
// 3. Return true only if all filters match
return true
}
// makeEntityKey creates a key for entity indexing
func makeEntityKey(schema, entity string) string {
if schema == "" {
return entity
}
return schema + "." + entity
}

View File

@@ -0,0 +1,434 @@
package websocketspec
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewSubscriptionManager(t *testing.T) {
sm := NewSubscriptionManager()
assert.NotNil(t, sm)
assert.NotNil(t, sm.subscriptions)
assert.NotNil(t, sm.entitySubscriptions)
assert.Equal(t, 0, sm.Count())
}
func TestSubscriptionManager_Subscribe(t *testing.T) {
sm := NewSubscriptionManager()
// Create a subscription
sub := sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
assert.NotNil(t, sub)
assert.Equal(t, "sub-1", sub.ID)
assert.Equal(t, "conn-1", sub.ConnectionID)
assert.Equal(t, "public", sub.Schema)
assert.Equal(t, "users", sub.Entity)
assert.True(t, sub.Active)
assert.Equal(t, 1, sm.Count())
}
func TestSubscriptionManager_Subscribe_WithOptions(t *testing.T) {
sm := NewSubscriptionManager()
options := &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
}
sub := sm.Subscribe("sub-1", "conn-1", "public", "users", options)
assert.NotNil(t, sub)
assert.NotNil(t, sub.Options)
assert.Len(t, sub.Options.Filters, 1)
assert.Equal(t, "status", sub.Options.Filters[0].Column)
}
func TestSubscriptionManager_Subscribe_MultipleSubscriptions(t *testing.T) {
sm := NewSubscriptionManager()
sub1 := sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
sub2 := sm.Subscribe("sub-2", "conn-1", "public", "posts", nil)
sub3 := sm.Subscribe("sub-3", "conn-2", "public", "users", nil)
assert.NotNil(t, sub1)
assert.NotNil(t, sub2)
assert.NotNil(t, sub3)
assert.Equal(t, 3, sm.Count())
}
func TestSubscriptionManager_Unsubscribe(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
assert.Equal(t, 1, sm.Count())
// Unsubscribe
ok := sm.Unsubscribe("sub-1")
assert.True(t, ok)
assert.Equal(t, 0, sm.Count())
}
func TestSubscriptionManager_Unsubscribe_NonExistent(t *testing.T) {
sm := NewSubscriptionManager()
ok := sm.Unsubscribe("non-existent")
assert.False(t, ok)
}
func TestSubscriptionManager_Unsubscribe_MultipleSubscriptions(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
sm.Subscribe("sub-2", "conn-1", "public", "posts", nil)
sm.Subscribe("sub-3", "conn-2", "public", "users", nil)
assert.Equal(t, 3, sm.Count())
// Unsubscribe one
ok := sm.Unsubscribe("sub-2")
assert.True(t, ok)
assert.Equal(t, 2, sm.Count())
// Verify the right subscription was removed
_, exists := sm.GetSubscription("sub-2")
assert.False(t, exists)
_, exists = sm.GetSubscription("sub-1")
assert.True(t, exists)
_, exists = sm.GetSubscription("sub-3")
assert.True(t, exists)
}
func TestSubscriptionManager_GetSubscription(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
// Get existing subscription
sub, exists := sm.GetSubscription("sub-1")
assert.True(t, exists)
assert.NotNil(t, sub)
assert.Equal(t, "sub-1", sub.ID)
}
func TestSubscriptionManager_GetSubscription_NonExistent(t *testing.T) {
sm := NewSubscriptionManager()
sub, exists := sm.GetSubscription("non-existent")
assert.False(t, exists)
assert.Nil(t, sub)
}
func TestSubscriptionManager_GetSubscriptionsByEntity(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
sm.Subscribe("sub-2", "conn-2", "public", "users", nil)
sm.Subscribe("sub-3", "conn-1", "public", "posts", nil)
// Get subscriptions for users entity
subs := sm.GetSubscriptionsByEntity("public", "users")
assert.Len(t, subs, 2)
// Verify subscription IDs
ids := make([]string, len(subs))
for i, sub := range subs {
ids[i] = sub.ID
}
assert.Contains(t, ids, "sub-1")
assert.Contains(t, ids, "sub-2")
}
func TestSubscriptionManager_GetSubscriptionsByEntity_NoSchema(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "", "users", nil)
sm.Subscribe("sub-2", "conn-2", "", "users", nil)
// Get subscriptions without schema
subs := sm.GetSubscriptionsByEntity("", "users")
assert.Len(t, subs, 2)
}
func TestSubscriptionManager_GetSubscriptionsByEntity_NoResults(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
// Get subscriptions for non-existent entity
subs := sm.GetSubscriptionsByEntity("public", "posts")
assert.Nil(t, subs)
}
func TestSubscriptionManager_GetSubscriptionsByConnection(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
sm.Subscribe("sub-2", "conn-1", "public", "posts", nil)
sm.Subscribe("sub-3", "conn-2", "public", "users", nil)
// Get subscriptions for connection 1
subs := sm.GetSubscriptionsByConnection("conn-1")
assert.Len(t, subs, 2)
// Verify subscription IDs
ids := make([]string, len(subs))
for i, sub := range subs {
ids[i] = sub.ID
}
assert.Contains(t, ids, "sub-1")
assert.Contains(t, ids, "sub-2")
}
func TestSubscriptionManager_GetSubscriptionsByConnection_NoResults(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
// Get subscriptions for non-existent connection
subs := sm.GetSubscriptionsByConnection("conn-2")
assert.Empty(t, subs)
}
func TestSubscriptionManager_Count(t *testing.T) {
sm := NewSubscriptionManager()
assert.Equal(t, 0, sm.Count())
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
assert.Equal(t, 1, sm.Count())
sm.Subscribe("sub-2", "conn-1", "public", "posts", nil)
assert.Equal(t, 2, sm.Count())
sm.Unsubscribe("sub-1")
assert.Equal(t, 1, sm.Count())
sm.Unsubscribe("sub-2")
assert.Equal(t, 0, sm.Count())
}
func TestSubscriptionManager_CountForEntity(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
sm.Subscribe("sub-2", "conn-2", "public", "users", nil)
sm.Subscribe("sub-3", "conn-1", "public", "posts", nil)
assert.Equal(t, 2, sm.CountForEntity("public", "users"))
assert.Equal(t, 1, sm.CountForEntity("public", "posts"))
assert.Equal(t, 0, sm.CountForEntity("public", "orders"))
}
func TestSubscriptionManager_UnsubscribeUpdatesEntityIndex(t *testing.T) {
sm := NewSubscriptionManager()
sm.Subscribe("sub-1", "conn-1", "public", "users", nil)
sm.Subscribe("sub-2", "conn-2", "public", "users", nil)
assert.Equal(t, 2, sm.CountForEntity("public", "users"))
// Unsubscribe one
sm.Unsubscribe("sub-1")
assert.Equal(t, 1, sm.CountForEntity("public", "users"))
// Unsubscribe the other
sm.Unsubscribe("sub-2")
assert.Equal(t, 0, sm.CountForEntity("public", "users"))
}
func TestSubscription_MatchesFilters_NoFilters(t *testing.T) {
sub := &Subscription{
ID: "sub-1",
ConnectionID: "conn-1",
Schema: "public",
Entity: "users",
Options: nil,
Active: true,
}
data := map[string]interface{}{
"id": 1,
"name": "John",
"status": "active",
}
// Should match when no filters are specified
assert.True(t, sub.MatchesFilters(data))
}
func TestSubscription_MatchesFilters_WithFilters(t *testing.T) {
sub := &Subscription{
ID: "sub-1",
ConnectionID: "conn-1",
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
},
Active: true,
}
data := map[string]interface{}{
"id": 1,
"name": "John",
"status": "active",
}
// Current implementation returns true for all data
// This test documents the expected behavior
assert.True(t, sub.MatchesFilters(data))
}
func TestSubscription_MatchesFilters_EmptyFiltersArray(t *testing.T) {
sub := &Subscription{
ID: "sub-1",
ConnectionID: "conn-1",
Schema: "public",
Entity: "users",
Options: &common.RequestOptions{
Filters: []common.FilterOption{},
},
Active: true,
}
data := map[string]interface{}{
"id": 1,
"name": "John",
}
// Should match when filters array is empty
assert.True(t, sub.MatchesFilters(data))
}
func TestMakeEntityKey(t *testing.T) {
tests := []struct {
name string
schema string
entity string
expected string
}{
{
name: "With schema",
schema: "public",
entity: "users",
expected: "public.users",
},
{
name: "Without schema",
schema: "",
entity: "users",
expected: "users",
},
{
name: "Different schema",
schema: "custom",
entity: "posts",
expected: "custom.posts",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := makeEntityKey(tt.schema, tt.entity)
assert.Equal(t, tt.expected, result)
})
}
}
func TestSubscriptionManager_ConcurrentOperations(t *testing.T) {
sm := NewSubscriptionManager()
// This test verifies that concurrent operations don't cause race conditions
// Run with: go test -race
done := make(chan bool)
// Goroutine 1: Subscribe
go func() {
for i := 0; i < 100; i++ {
sm.Subscribe("sub-"+string(rune(i)), "conn-1", "public", "users", nil)
}
done <- true
}()
// Goroutine 2: Get subscriptions
go func() {
for i := 0; i < 100; i++ {
sm.GetSubscriptionsByEntity("public", "users")
}
done <- true
}()
// Goroutine 3: Count
go func() {
for i := 0; i < 100; i++ {
sm.Count()
}
done <- true
}()
// Wait for all goroutines
<-done
<-done
<-done
}
func TestSubscriptionManager_CompleteLifecycle(t *testing.T) {
sm := NewSubscriptionManager()
// Create subscriptions
options := &common.RequestOptions{
Filters: []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
},
}
sub1 := sm.Subscribe("sub-1", "conn-1", "public", "users", options)
require.NotNil(t, sub1)
assert.Equal(t, 1, sm.Count())
sub2 := sm.Subscribe("sub-2", "conn-1", "public", "posts", nil)
require.NotNil(t, sub2)
assert.Equal(t, 2, sm.Count())
// Get by entity
userSubs := sm.GetSubscriptionsByEntity("public", "users")
assert.Len(t, userSubs, 1)
assert.Equal(t, "sub-1", userSubs[0].ID)
// Get by connection
connSubs := sm.GetSubscriptionsByConnection("conn-1")
assert.Len(t, connSubs, 2)
// Get specific subscription
sub, exists := sm.GetSubscription("sub-1")
assert.True(t, exists)
assert.Equal(t, "sub-1", sub.ID)
assert.NotNil(t, sub.Options)
// Count by entity
assert.Equal(t, 1, sm.CountForEntity("public", "users"))
assert.Equal(t, 1, sm.CountForEntity("public", "posts"))
// Unsubscribe
ok := sm.Unsubscribe("sub-1")
assert.True(t, ok)
assert.Equal(t, 1, sm.Count())
assert.Equal(t, 0, sm.CountForEntity("public", "users"))
// Verify subscription is gone
_, exists = sm.GetSubscription("sub-1")
assert.False(t, exists)
// Unsubscribe second subscription
ok = sm.Unsubscribe("sub-2")
assert.True(t, ok)
assert.Equal(t, 0, sm.Count())
}

View File

@@ -0,0 +1,332 @@
// Package websocketspec provides a WebSocket-based API specification for real-time
// CRUD operations with bidirectional communication and subscription support.
//
// # Key Features
//
// - Real-time bidirectional communication over WebSocket
// - CRUD operations (Create, Read, Update, Delete)
// - Real-time subscriptions with filtering
// - Lifecycle hooks for all operations
// - Database-agnostic: Works with GORM and Bun ORM through adapters
// - Automatic change notifications to subscribers
// - Connection and subscription management
//
// # Message Protocol
//
// WebSocketSpec uses JSON messages for communication:
//
// {
// "id": "unique-message-id",
// "type": "request|response|notification|subscription",
// "operation": "read|create|update|delete|subscribe|unsubscribe",
// "schema": "public",
// "entity": "users",
// "data": {...},
// "options": {
// "filters": [...],
// "columns": [...],
// "preload": [...],
// "sort": [...],
// "limit": 10
// }
// }
//
// # Usage Example
//
// // Create handler with GORM
// handler := websocketspec.NewHandlerWithGORM(db)
//
// // Register models
// handler.Registry.RegisterModel("public.users", &User{})
//
// // Setup WebSocket endpoint
// http.HandleFunc("/ws", handler.HandleWebSocket)
//
// // Start server
// http.ListenAndServe(":8080", nil)
//
// # Client Example
//
// // Connect to WebSocket
// ws := new WebSocket("ws://localhost:8080/ws")
//
// // Send read request
// ws.send(JSON.stringify({
// id: "msg-1",
// type: "request",
// operation: "read",
// entity: "users",
// options: {
// filters: [{column: "status", operator: "eq", value: "active"}],
// limit: 10
// }
// }))
//
// // Subscribe to changes
// ws.send(JSON.stringify({
// id: "msg-2",
// type: "subscription",
// operation: "subscribe",
// entity: "users",
// options: {
// filters: [{column: "status", operator: "eq", value: "active"}]
// }
// }))
package websocketspec
import (
"github.com/uptrace/bun"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// NewHandlerWithGORM creates a new Handler with GORM adapter
func NewHandlerWithGORM(db *gorm.DB) *Handler {
gormAdapter := database.NewGormAdapter(db)
registry := modelregistry.NewModelRegistry()
return NewHandler(gormAdapter, registry)
}
// NewHandlerWithBun creates a new Handler with Bun adapter
func NewHandlerWithBun(db *bun.DB) *Handler {
bunAdapter := database.NewBunAdapter(db)
registry := modelregistry.NewModelRegistry()
return NewHandler(bunAdapter, registry)
}
// NewHandlerWithDatabase creates a new Handler with a custom database adapter
func NewHandlerWithDatabase(db common.Database, registry common.ModelRegistry) *Handler {
return NewHandler(db, registry)
}
// Example usage functions for documentation:
// ExampleWithGORM shows how to use WebSocketSpec with GORM
func ExampleWithGORM(db *gorm.DB) {
// Create handler using GORM
handler := NewHandlerWithGORM(db)
// Register models
handler.Registry().RegisterModel("public.users", &struct{}{})
// Register hooks (optional)
handler.Hooks().RegisterBefore(OperationRead, func(ctx *HookContext) error {
// Add custom logic before read operations
return nil
})
// Setup WebSocket endpoint
// http.HandleFunc("/ws", handler.HandleWebSocket)
// Start server
// http.ListenAndServe(":8080", nil)
}
// ExampleWithBun shows how to use WebSocketSpec with Bun ORM
func ExampleWithBun(bunDB *bun.DB) {
// Create handler using Bun
handler := NewHandlerWithBun(bunDB)
// Register models
handler.Registry().RegisterModel("public.users", &struct{}{})
// Setup WebSocket endpoint
// http.HandleFunc("/ws", handler.HandleWebSocket)
}
// ExampleWithHooks shows how to use lifecycle hooks
func ExampleWithHooks(db *gorm.DB) {
handler := NewHandlerWithGORM(db)
// Register a before-read hook for authorization
handler.Hooks().RegisterBefore(OperationRead, func(ctx *HookContext) error {
// Check if user has permission to read this entity
// return fmt.Errorf("unauthorized") if not allowed
return nil
})
// Register an after-create hook for logging
handler.Hooks().RegisterAfter(OperationCreate, func(ctx *HookContext) error {
// Log the created record
// logger.Info("Created record: %v", ctx.Result)
return nil
})
// Register a before-subscribe hook to limit subscriptions
handler.Hooks().Register(BeforeSubscribe, func(ctx *HookContext) error {
// Limit number of subscriptions per connection
// if len(ctx.Connection.subscriptions) >= 10 {
// return fmt.Errorf("maximum subscriptions reached")
// }
return nil
})
}
// ExampleWithSubscriptions shows subscription usage
func ExampleWithSubscriptions() {
// Client-side JavaScript example:
/*
const ws = new WebSocket("ws://localhost:8080/ws");
// Subscribe to user changes
ws.send(JSON.stringify({
id: "sub-1",
type: "subscription",
operation: "subscribe",
schema: "public",
entity: "users",
options: {
filters: [
{column: "status", operator: "eq", value: "active"}
]
}
}));
// Handle notifications
ws.onmessage = (event) => {
const msg = JSON.parse(event.data);
if (msg.type === "notification") {
console.log("User changed:", msg.data);
console.log("Operation:", msg.operation); // create, update, or delete
}
};
// Unsubscribe
ws.send(JSON.stringify({
id: "unsub-1",
type: "subscription",
operation: "unsubscribe",
subscription_id: "sub-abc123"
}));
*/
}
// ExampleCRUDOperations shows basic CRUD operations
func ExampleCRUDOperations() {
// Client-side JavaScript example:
/*
const ws = new WebSocket("ws://localhost:8080/ws");
// CREATE - Create a new user
ws.send(JSON.stringify({
id: "create-1",
type: "request",
operation: "create",
schema: "public",
entity: "users",
data: {
name: "John Doe",
email: "john@example.com",
status: "active"
}
}));
// READ - Get all active users
ws.send(JSON.stringify({
id: "read-1",
type: "request",
operation: "read",
schema: "public",
entity: "users",
options: {
filters: [{column: "status", operator: "eq", value: "active"}],
columns: ["id", "name", "email"],
sort: [{column: "name", direction: "asc"}],
limit: 10
}
}));
// READ BY ID - Get a specific user
ws.send(JSON.stringify({
id: "read-2",
type: "request",
operation: "read",
schema: "public",
entity: "users",
record_id: "123"
}));
// UPDATE - Update a user
ws.send(JSON.stringify({
id: "update-1",
type: "request",
operation: "update",
schema: "public",
entity: "users",
record_id: "123",
data: {
name: "John Updated",
email: "john.updated@example.com"
}
}));
// DELETE - Delete a user
ws.send(JSON.stringify({
id: "delete-1",
type: "request",
operation: "delete",
schema: "public",
entity: "users",
record_id: "123"
}));
// Handle responses
ws.onmessage = (event) => {
const response = JSON.parse(event.data);
if (response.type === "response") {
if (response.success) {
console.log("Operation successful:", response.data);
} else {
console.error("Operation failed:", response.error);
}
}
};
*/
}
// ExampleAuthentication shows how to implement authentication
func ExampleAuthentication() {
// Server-side example with authentication hook:
/*
handler := NewHandlerWithGORM(db)
// Register before-connect hook for authentication
handler.Hooks().Register(BeforeConnect, func(ctx *HookContext) error {
// Extract token from query params or headers
r := ctx.Connection.ws.UnderlyingConn().RemoteAddr()
// Validate token
// token := extractToken(r)
// user, err := validateToken(token)
// if err != nil {
// return fmt.Errorf("authentication failed: %w", err)
// }
// Store user info in connection metadata
// ctx.Connection.SetMetadata("user", user)
// ctx.Connection.SetMetadata("user_id", user.ID)
return nil
})
// Use connection metadata in other hooks
handler.Hooks().RegisterBefore(OperationRead, func(ctx *HookContext) error {
// Get user from connection metadata
// userID, _ := ctx.Connection.GetMetadata("user_id")
// Add filter to only show user's own records
// if ctx.Entity == "orders" {
// ctx.Options.Filters = append(ctx.Options.Filters, common.FilterOption{
// Column: "user_id",
// Operator: "eq",
// Value: userID,
// })
// }
return nil
})
*/
}

530
resolvespec-js/WEBSOCKET.md Normal file
View File

@@ -0,0 +1,530 @@
# WebSocketSpec JavaScript Client
A TypeScript/JavaScript client for connecting to WebSocketSpec servers with full support for real-time subscriptions, CRUD operations, and automatic reconnection.
## Installation
```bash
npm install @warkypublic/resolvespec-js
# or
yarn add @warkypublic/resolvespec-js
# or
pnpm add @warkypublic/resolvespec-js
```
## Quick Start
```typescript
import { WebSocketClient } from '@warkypublic/resolvespec-js';
// Create client
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws',
reconnect: true,
debug: true
});
// Connect
await client.connect();
// Read records
const users = await client.read('users', {
schema: 'public',
filters: [
{ column: 'status', operator: 'eq', value: 'active' }
],
limit: 10
});
// Subscribe to changes
const subscriptionId = await client.subscribe('users', (notification) => {
console.log('User changed:', notification.operation, notification.data);
}, { schema: 'public' });
// Clean up
await client.unsubscribe(subscriptionId);
client.disconnect();
```
## Features
- **Real-Time Updates**: Subscribe to entity changes and receive instant notifications
- **Full CRUD Support**: Create, read, update, and delete operations
- **TypeScript Support**: Full type definitions included
- **Auto Reconnection**: Automatic reconnection with configurable retry logic
- **Heartbeat**: Built-in keepalive mechanism
- **Event System**: Listen to connection, error, and message events
- **Promise-based API**: All async operations return promises
- **Filter & Sort**: Advanced querying with filters, sorting, and pagination
- **Preloading**: Load related entities in a single query
## Configuration
```typescript
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws', // WebSocket server URL
reconnect: true, // Enable auto-reconnection
reconnectInterval: 3000, // Reconnection delay (ms)
maxReconnectAttempts: 10, // Max reconnection attempts
heartbeatInterval: 30000, // Heartbeat interval (ms)
debug: false // Enable debug logging
});
```
## API Reference
### Connection Management
#### `connect(): Promise<void>`
Connect to the WebSocket server.
```typescript
await client.connect();
```
#### `disconnect(): void`
Disconnect from the server.
```typescript
client.disconnect();
```
#### `isConnected(): boolean`
Check if currently connected.
```typescript
if (client.isConnected()) {
console.log('Connected!');
}
```
#### `getState(): ConnectionState`
Get current connection state: `'connecting'`, `'connected'`, `'disconnecting'`, `'disconnected'`, or `'reconnecting'`.
```typescript
const state = client.getState();
console.log('State:', state);
```
### CRUD Operations
#### `read<T>(entity: string, options?): Promise<T>`
Read records from an entity.
```typescript
// Read all active users
const users = await client.read('users', {
schema: 'public',
filters: [
{ column: 'status', operator: 'eq', value: 'active' }
],
columns: ['id', 'name', 'email'],
sort: [
{ column: 'name', direction: 'asc' }
],
limit: 10,
offset: 0
});
// Read single record by ID
const user = await client.read('users', {
schema: 'public',
record_id: '123'
});
// Read with preloading
const posts = await client.read('posts', {
schema: 'public',
preload: [
{
relation: 'user',
columns: ['id', 'name', 'email']
},
{
relation: 'comments',
filters: [
{ column: 'status', operator: 'eq', value: 'approved' }
]
}
]
});
```
#### `create<T>(entity: string, data: any, options?): Promise<T>`
Create a new record.
```typescript
const newUser = await client.create('users', {
name: 'John Doe',
email: 'john@example.com',
status: 'active'
}, {
schema: 'public'
});
```
#### `update<T>(entity: string, id: string, data: any, options?): Promise<T>`
Update an existing record.
```typescript
const updatedUser = await client.update('users', '123', {
name: 'John Updated',
email: 'john.new@example.com'
}, {
schema: 'public'
});
```
#### `delete(entity: string, id: string, options?): Promise<void>`
Delete a record.
```typescript
await client.delete('users', '123', {
schema: 'public'
});
```
#### `meta<T>(entity: string, options?): Promise<T>`
Get metadata for an entity.
```typescript
const metadata = await client.meta('users', {
schema: 'public'
});
console.log('Columns:', metadata.columns);
console.log('Primary key:', metadata.primary_key);
```
### Subscriptions
#### `subscribe(entity: string, callback: Function, options?): Promise<string>`
Subscribe to entity changes.
```typescript
const subscriptionId = await client.subscribe(
'users',
(notification) => {
console.log('Operation:', notification.operation); // 'create', 'update', or 'delete'
console.log('Data:', notification.data);
console.log('Timestamp:', notification.timestamp);
},
{
schema: 'public',
filters: [
{ column: 'status', operator: 'eq', value: 'active' }
]
}
);
```
#### `unsubscribe(subscriptionId: string): Promise<void>`
Unsubscribe from entity changes.
```typescript
await client.unsubscribe(subscriptionId);
```
#### `getSubscriptions(): Subscription[]`
Get list of active subscriptions.
```typescript
const subscriptions = client.getSubscriptions();
console.log('Active subscriptions:', subscriptions.length);
```
### Event Handling
#### `on(event: string, callback: Function): void`
Add event listener.
```typescript
// Connection events
client.on('connect', () => {
console.log('Connected!');
});
client.on('disconnect', (event) => {
console.log('Disconnected:', event.code, event.reason);
});
client.on('error', (error) => {
console.error('Error:', error);
});
// State changes
client.on('stateChange', (state) => {
console.log('State:', state);
});
// All messages
client.on('message', (message) => {
console.log('Message:', message);
});
```
#### `off(event: string): void`
Remove event listener.
```typescript
client.off('connect');
```
## Filter Operators
- `eq` - Equal (=)
- `neq` - Not Equal (!=)
- `gt` - Greater Than (>)
- `gte` - Greater Than or Equal (>=)
- `lt` - Less Than (<)
- `lte` - Less Than or Equal (<=)
- `like` - LIKE (case-sensitive)
- `ilike` - ILIKE (case-insensitive)
- `in` - IN (array of values)
## Examples
### Basic CRUD
```typescript
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();
// Create
const user = await client.create('users', {
name: 'Alice',
email: 'alice@example.com'
});
// Read
const users = await client.read('users', {
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
});
// Update
await client.update('users', user.id, { name: 'Alice Updated' });
// Delete
await client.delete('users', user.id);
client.disconnect();
```
### Real-Time Subscriptions
```typescript
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();
// Subscribe to all user changes
const subId = await client.subscribe('users', (notification) => {
switch (notification.operation) {
case 'create':
console.log('New user:', notification.data);
break;
case 'update':
console.log('User updated:', notification.data);
break;
case 'delete':
console.log('User deleted:', notification.data);
break;
}
});
// Later: unsubscribe
await client.unsubscribe(subId);
```
### React Integration
```typescript
import { useEffect, useState } from 'react';
import { WebSocketClient } from '@warkypublic/resolvespec-js';
function useWebSocket(url: string) {
const [client] = useState(() => new WebSocketClient({ url }));
const [isConnected, setIsConnected] = useState(false);
useEffect(() => {
client.on('connect', () => setIsConnected(true));
client.on('disconnect', () => setIsConnected(false));
client.connect();
return () => client.disconnect();
}, [client]);
return { client, isConnected };
}
function UsersComponent() {
const { client, isConnected } = useWebSocket('ws://localhost:8080/ws');
const [users, setUsers] = useState([]);
useEffect(() => {
if (!isConnected) return;
const loadUsers = async () => {
// Subscribe to changes
await client.subscribe('users', (notification) => {
if (notification.operation === 'create') {
setUsers(prev => [...prev, notification.data]);
} else if (notification.operation === 'update') {
setUsers(prev => prev.map(u =>
u.id === notification.data.id ? notification.data : u
));
} else if (notification.operation === 'delete') {
setUsers(prev => prev.filter(u => u.id !== notification.data.id));
}
});
// Load initial data
const data = await client.read('users');
setUsers(data);
};
loadUsers();
}, [client, isConnected]);
return (
<div>
<h2>Users {isConnected ? '🟢' : '🔴'}</h2>
{users.map(user => (
<div key={user.id}>{user.name}</div>
))}
</div>
);
}
```
### TypeScript with Typed Models
```typescript
interface User {
id: number;
name: string;
email: string;
status: 'active' | 'inactive';
}
interface Post {
id: number;
title: string;
content: string;
user_id: number;
user?: User;
}
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();
// Type-safe operations
const users = await client.read<User[]>('users', {
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
});
const newUser = await client.create<User>('users', {
name: 'Bob',
email: 'bob@example.com',
status: 'active'
});
// Type-safe subscriptions
await client.subscribe(
'posts',
(notification) => {
const post = notification.data as Post;
console.log('Post:', post.title);
}
);
```
### Error Handling
```typescript
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws',
reconnect: true,
maxReconnectAttempts: 5
});
client.on('error', (error) => {
console.error('Connection error:', error);
});
client.on('stateChange', (state) => {
console.log('State:', state);
if (state === 'reconnecting') {
console.log('Attempting to reconnect...');
}
});
try {
await client.connect();
try {
const user = await client.read('users', { record_id: '999' });
} catch (error) {
console.error('Record not found:', error);
}
try {
await client.create('users', { /* invalid data */ });
} catch (error) {
console.error('Validation failed:', error);
}
} catch (error) {
console.error('Connection failed:', error);
}
```
### Multiple Subscriptions
```typescript
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
await client.connect();
// Subscribe to multiple entities
const userSub = await client.subscribe('users', (n) => {
console.log('[Users]', n.operation, n.data);
});
const postSub = await client.subscribe('posts', (n) => {
console.log('[Posts]', n.operation, n.data);
}, {
filters: [{ column: 'status', operator: 'eq', value: 'published' }]
});
const commentSub = await client.subscribe('comments', (n) => {
console.log('[Comments]', n.operation, n.data);
});
// Check active subscriptions
console.log('Active:', client.getSubscriptions().length);
// Clean up
await client.unsubscribe(userSub);
await client.unsubscribe(postSub);
await client.unsubscribe(commentSub);
```
## Best Practices
1. **Always Clean Up**: Call `disconnect()` when done to close the connection properly
2. **Use TypeScript**: Leverage type definitions for better type safety
3. **Handle Errors**: Always wrap operations in try-catch blocks
4. **Limit Subscriptions**: Don't create too many subscriptions per connection
5. **Use Filters**: Apply filters to subscriptions to reduce unnecessary notifications
6. **Connection State**: Check `isConnected()` before operations
7. **Event Listeners**: Remove event listeners when no longer needed with `off()`
8. **Reconnection**: Enable auto-reconnection for production apps
## Browser Support
- Chrome/Edge 88+
- Firefox 85+
- Safari 14+
- Node.js 14.16+
## License
MIT

View File

@@ -0,0 +1,7 @@
// Types
export * from './types';
export * from './websocket-types';
// WebSocket Client
export { WebSocketClient } from './websocket-client';
export type { WebSocketClient as default } from './websocket-client';

View File

@@ -0,0 +1,487 @@
import { v4 as uuidv4 } from 'uuid';
import type {
WebSocketClientConfig,
WSMessage,
WSRequestMessage,
WSResponseMessage,
WSNotificationMessage,
WSOperation,
WSOptions,
Subscription,
SubscriptionOptions,
ConnectionState,
WebSocketClientEvents
} from './websocket-types';
export class WebSocketClient {
private ws: WebSocket | null = null;
private config: Required<WebSocketClientConfig>;
private messageHandlers: Map<string, (message: WSResponseMessage) => void> = new Map();
private subscriptions: Map<string, Subscription> = new Map();
private eventListeners: Partial<WebSocketClientEvents> = {};
private state: ConnectionState = 'disconnected';
private reconnectAttempts = 0;
private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
private heartbeatTimer: ReturnType<typeof setInterval> | null = null;
private isManualClose = false;
constructor(config: WebSocketClientConfig) {
this.config = {
url: config.url,
reconnect: config.reconnect ?? true,
reconnectInterval: config.reconnectInterval ?? 3000,
maxReconnectAttempts: config.maxReconnectAttempts ?? 10,
heartbeatInterval: config.heartbeatInterval ?? 30000,
debug: config.debug ?? false
};
}
/**
* Connect to WebSocket server
*/
async connect(): Promise<void> {
if (this.ws?.readyState === WebSocket.OPEN) {
this.log('Already connected');
return;
}
this.isManualClose = false;
this.setState('connecting');
return new Promise((resolve, reject) => {
try {
this.ws = new WebSocket(this.config.url);
this.ws.onopen = () => {
this.log('Connected to WebSocket server');
this.setState('connected');
this.reconnectAttempts = 0;
this.startHeartbeat();
this.emit('connect');
resolve();
};
this.ws.onmessage = (event) => {
this.handleMessage(event.data);
};
this.ws.onerror = (event) => {
this.log('WebSocket error:', event);
const error = new Error('WebSocket connection error');
this.emit('error', error);
reject(error);
};
this.ws.onclose = (event) => {
this.log('WebSocket closed:', event.code, event.reason);
this.stopHeartbeat();
this.setState('disconnected');
this.emit('disconnect', event);
// Attempt reconnection if enabled and not manually closed
if (this.config.reconnect && !this.isManualClose && this.reconnectAttempts < this.config.maxReconnectAttempts) {
this.reconnectAttempts++;
this.log(`Reconnection attempt ${this.reconnectAttempts}/${this.config.maxReconnectAttempts}`);
this.setState('reconnecting');
this.reconnectTimer = setTimeout(() => {
this.connect().catch((err) => {
this.log('Reconnection failed:', err);
});
}, this.config.reconnectInterval);
}
};
} catch (error) {
reject(error);
}
});
}
/**
* Disconnect from WebSocket server
*/
disconnect(): void {
this.isManualClose = true;
if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer);
this.reconnectTimer = null;
}
this.stopHeartbeat();
if (this.ws) {
this.setState('disconnecting');
this.ws.close();
this.ws = null;
}
this.setState('disconnected');
this.messageHandlers.clear();
}
/**
* Send a CRUD request and wait for response
*/
async request<T = any>(
operation: WSOperation,
entity: string,
options?: {
schema?: string;
record_id?: string;
data?: any;
options?: WSOptions;
}
): Promise<T> {
this.ensureConnected();
const id = uuidv4();
const message: WSRequestMessage = {
id,
type: 'request',
operation,
entity,
schema: options?.schema,
record_id: options?.record_id,
data: options?.data,
options: options?.options
};
return new Promise((resolve, reject) => {
// Set up response handler
this.messageHandlers.set(id, (response: WSResponseMessage) => {
if (response.success) {
resolve(response.data);
} else {
reject(new Error(response.error?.message || 'Request failed'));
}
});
// Send message
this.send(message);
// Timeout after 30 seconds
setTimeout(() => {
if (this.messageHandlers.has(id)) {
this.messageHandlers.delete(id);
reject(new Error('Request timeout'));
}
}, 30000);
});
}
/**
* Read records
*/
async read<T = any>(entity: string, options?: {
schema?: string;
record_id?: string;
filters?: import('./types').FilterOption[];
columns?: string[];
sort?: import('./types').SortOption[];
preload?: import('./types').PreloadOption[];
limit?: number;
offset?: number;
}): Promise<T> {
return this.request<T>('read', entity, {
schema: options?.schema,
record_id: options?.record_id,
options: {
filters: options?.filters,
columns: options?.columns,
sort: options?.sort,
preload: options?.preload,
limit: options?.limit,
offset: options?.offset
}
});
}
/**
* Create a record
*/
async create<T = any>(entity: string, data: any, options?: {
schema?: string;
}): Promise<T> {
return this.request<T>('create', entity, {
schema: options?.schema,
data
});
}
/**
* Update a record
*/
async update<T = any>(entity: string, id: string, data: any, options?: {
schema?: string;
}): Promise<T> {
return this.request<T>('update', entity, {
schema: options?.schema,
record_id: id,
data
});
}
/**
* Delete a record
*/
async delete(entity: string, id: string, options?: {
schema?: string;
}): Promise<void> {
await this.request('delete', entity, {
schema: options?.schema,
record_id: id
});
}
/**
* Get metadata for an entity
*/
async meta<T = any>(entity: string, options?: {
schema?: string;
}): Promise<T> {
return this.request<T>('meta', entity, {
schema: options?.schema
});
}
/**
* Subscribe to entity changes
*/
async subscribe(
entity: string,
callback: (notification: WSNotificationMessage) => void,
options?: {
schema?: string;
filters?: import('./types').FilterOption[];
}
): Promise<string> {
this.ensureConnected();
const id = uuidv4();
const message: WSMessage = {
id,
type: 'subscription',
operation: 'subscribe',
entity,
schema: options?.schema,
options: {
filters: options?.filters
}
};
return new Promise((resolve, reject) => {
this.messageHandlers.set(id, (response: WSResponseMessage) => {
if (response.success && response.data?.subscription_id) {
const subscriptionId = response.data.subscription_id;
// Store subscription
this.subscriptions.set(subscriptionId, {
id: subscriptionId,
entity,
schema: options?.schema,
options: { filters: options?.filters },
callback
});
this.log(`Subscribed to ${entity} with ID: ${subscriptionId}`);
resolve(subscriptionId);
} else {
reject(new Error(response.error?.message || 'Subscription failed'));
}
});
this.send(message);
// Timeout
setTimeout(() => {
if (this.messageHandlers.has(id)) {
this.messageHandlers.delete(id);
reject(new Error('Subscription timeout'));
}
}, 10000);
});
}
/**
* Unsubscribe from entity changes
*/
async unsubscribe(subscriptionId: string): Promise<void> {
this.ensureConnected();
const id = uuidv4();
const message: WSMessage = {
id,
type: 'subscription',
operation: 'unsubscribe',
subscription_id: subscriptionId
};
return new Promise((resolve, reject) => {
this.messageHandlers.set(id, (response: WSResponseMessage) => {
if (response.success) {
this.subscriptions.delete(subscriptionId);
this.log(`Unsubscribed from ${subscriptionId}`);
resolve();
} else {
reject(new Error(response.error?.message || 'Unsubscribe failed'));
}
});
this.send(message);
// Timeout
setTimeout(() => {
if (this.messageHandlers.has(id)) {
this.messageHandlers.delete(id);
reject(new Error('Unsubscribe timeout'));
}
}, 10000);
});
}
/**
* Get list of active subscriptions
*/
getSubscriptions(): Subscription[] {
return Array.from(this.subscriptions.values());
}
/**
* Get connection state
*/
getState(): ConnectionState {
return this.state;
}
/**
* Check if connected
*/
isConnected(): boolean {
return this.ws?.readyState === WebSocket.OPEN;
}
/**
* Add event listener
*/
on<K extends keyof WebSocketClientEvents>(event: K, callback: WebSocketClientEvents[K]): void {
this.eventListeners[event] = callback as any;
}
/**
* Remove event listener
*/
off<K extends keyof WebSocketClientEvents>(event: K): void {
delete this.eventListeners[event];
}
// Private methods
private handleMessage(data: string): void {
try {
const message: WSMessage = JSON.parse(data);
this.log('Received message:', message);
this.emit('message', message);
// Handle different message types
switch (message.type) {
case 'response':
this.handleResponse(message as WSResponseMessage);
break;
case 'notification':
this.handleNotification(message as WSNotificationMessage);
break;
case 'pong':
// Heartbeat response
break;
default:
this.log('Unknown message type:', message.type);
}
} catch (error) {
this.log('Error parsing message:', error);
}
}
private handleResponse(message: WSResponseMessage): void {
const handler = this.messageHandlers.get(message.id);
if (handler) {
handler(message);
this.messageHandlers.delete(message.id);
}
}
private handleNotification(message: WSNotificationMessage): void {
const subscription = this.subscriptions.get(message.subscription_id);
if (subscription?.callback) {
subscription.callback(message);
}
}
private send(message: WSMessage): void {
if (!this.ws || this.ws.readyState !== WebSocket.OPEN) {
throw new Error('WebSocket is not connected');
}
const data = JSON.stringify(message);
this.log('Sending message:', message);
this.ws.send(data);
}
private startHeartbeat(): void {
if (this.heartbeatTimer) {
return;
}
this.heartbeatTimer = setInterval(() => {
if (this.isConnected()) {
const pingMessage: WSMessage = {
id: uuidv4(),
type: 'ping'
};
this.send(pingMessage);
}
}, this.config.heartbeatInterval);
}
private stopHeartbeat(): void {
if (this.heartbeatTimer) {
clearInterval(this.heartbeatTimer);
this.heartbeatTimer = null;
}
}
private setState(state: ConnectionState): void {
if (this.state !== state) {
this.state = state;
this.emit('stateChange', state);
}
}
private ensureConnected(): void {
if (!this.isConnected()) {
throw new Error('WebSocket is not connected. Call connect() first.');
}
}
private emit<K extends keyof WebSocketClientEvents>(
event: K,
...args: Parameters<WebSocketClientEvents[K]>
): void {
const listener = this.eventListeners[event];
if (listener) {
(listener as any)(...args);
}
}
private log(...args: any[]): void {
if (this.config.debug) {
console.log('[WebSocketClient]', ...args);
}
}
}
export default WebSocketClient;

View File

@@ -0,0 +1,427 @@
import { WebSocketClient } from './websocket-client';
import type { WSNotificationMessage } from './websocket-types';
/**
* Example 1: Basic Usage
*/
export async function basicUsageExample() {
// Create client
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws',
reconnect: true,
debug: true
});
// Connect
await client.connect();
// Read users
const users = await client.read('users', {
schema: 'public',
filters: [
{ column: 'status', operator: 'eq', value: 'active' }
],
limit: 10,
sort: [
{ column: 'name', direction: 'asc' }
]
});
console.log('Users:', users);
// Create a user
const newUser = await client.create('users', {
name: 'John Doe',
email: 'john@example.com',
status: 'active'
}, { schema: 'public' });
console.log('Created user:', newUser);
// Update user
const updatedUser = await client.update('users', '123', {
name: 'John Updated'
}, { schema: 'public' });
console.log('Updated user:', updatedUser);
// Delete user
await client.delete('users', '123', { schema: 'public' });
// Disconnect
client.disconnect();
}
/**
* Example 2: Real-time Subscriptions
*/
export async function subscriptionExample() {
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws',
debug: true
});
await client.connect();
// Subscribe to user changes
const subscriptionId = await client.subscribe(
'users',
(notification: WSNotificationMessage) => {
console.log('User changed:', notification.operation, notification.data);
switch (notification.operation) {
case 'create':
console.log('New user created:', notification.data);
break;
case 'update':
console.log('User updated:', notification.data);
break;
case 'delete':
console.log('User deleted:', notification.data);
break;
}
},
{
schema: 'public',
filters: [
{ column: 'status', operator: 'eq', value: 'active' }
]
}
);
console.log('Subscribed with ID:', subscriptionId);
// Later: unsubscribe
setTimeout(async () => {
await client.unsubscribe(subscriptionId);
console.log('Unsubscribed');
client.disconnect();
}, 60000);
}
/**
* Example 3: Event Handling
*/
export async function eventHandlingExample() {
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws'
});
// Listen to connection events
client.on('connect', () => {
console.log('Connected!');
});
client.on('disconnect', (event) => {
console.log('Disconnected:', event.code, event.reason);
});
client.on('error', (error) => {
console.error('WebSocket error:', error);
});
client.on('stateChange', (state) => {
console.log('State changed to:', state);
});
client.on('message', (message) => {
console.log('Received message:', message);
});
await client.connect();
// Your operations here...
}
/**
* Example 4: Multiple Subscriptions
*/
export async function multipleSubscriptionsExample() {
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws',
debug: true
});
await client.connect();
// Subscribe to users
const userSubId = await client.subscribe(
'users',
(notification) => {
console.log('[Users]', notification.operation, notification.data);
},
{ schema: 'public' }
);
// Subscribe to posts
const postSubId = await client.subscribe(
'posts',
(notification) => {
console.log('[Posts]', notification.operation, notification.data);
},
{
schema: 'public',
filters: [
{ column: 'status', operator: 'eq', value: 'published' }
]
}
);
// Subscribe to comments
const commentSubId = await client.subscribe(
'comments',
(notification) => {
console.log('[Comments]', notification.operation, notification.data);
},
{ schema: 'public' }
);
console.log('Active subscriptions:', client.getSubscriptions());
// Clean up after 60 seconds
setTimeout(async () => {
await client.unsubscribe(userSubId);
await client.unsubscribe(postSubId);
await client.unsubscribe(commentSubId);
client.disconnect();
}, 60000);
}
/**
* Example 5: Advanced Queries
*/
export async function advancedQueriesExample() {
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws'
});
await client.connect();
// Complex query with filters, sorting, pagination, and preloading
const posts = await client.read('posts', {
schema: 'public',
filters: [
{ column: 'status', operator: 'eq', value: 'published' },
{ column: 'views', operator: 'gte', value: 100 }
],
columns: ['id', 'title', 'content', 'user_id', 'created_at'],
sort: [
{ column: 'created_at', direction: 'desc' },
{ column: 'views', direction: 'desc' }
],
preload: [
{
relation: 'user',
columns: ['id', 'name', 'email']
},
{
relation: 'comments',
columns: ['id', 'content', 'user_id'],
filters: [
{ column: 'status', operator: 'eq', value: 'approved' }
]
}
],
limit: 20,
offset: 0
});
console.log('Posts:', posts);
// Get single record by ID
const post = await client.read('posts', {
schema: 'public',
record_id: '123'
});
console.log('Single post:', post);
client.disconnect();
}
/**
* Example 6: Error Handling
*/
export async function errorHandlingExample() {
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws',
reconnect: true,
maxReconnectAttempts: 5
});
client.on('error', (error) => {
console.error('Connection error:', error);
});
client.on('stateChange', (state) => {
console.log('Connection state:', state);
});
try {
await client.connect();
try {
// Try to read non-existent entity
await client.read('nonexistent', { schema: 'public' });
} catch (error) {
console.error('Read error:', error);
}
try {
// Try to create invalid record
await client.create('users', {
// Missing required fields
}, { schema: 'public' });
} catch (error) {
console.error('Create error:', error);
}
} catch (error) {
console.error('Connection failed:', error);
} finally {
client.disconnect();
}
}
/**
* Example 7: React Integration
*/
export function reactIntegrationExample() {
const exampleCode = `
import { useEffect, useState } from 'react';
import { WebSocketClient } from '@warkypublic/resolvespec-js';
export function useWebSocket(url: string) {
const [client] = useState(() => new WebSocketClient({ url }));
const [isConnected, setIsConnected] = useState(false);
useEffect(() => {
client.on('connect', () => setIsConnected(true));
client.on('disconnect', () => setIsConnected(false));
client.connect();
return () => {
client.disconnect();
};
}, [client]);
return { client, isConnected };
}
export function UsersComponent() {
const { client, isConnected } = useWebSocket('ws://localhost:8080/ws');
const [users, setUsers] = useState([]);
useEffect(() => {
if (!isConnected) return;
// Subscribe to user changes
const subscribeToUsers = async () => {
const subId = await client.subscribe('users', (notification) => {
if (notification.operation === 'create') {
setUsers(prev => [...prev, notification.data]);
} else if (notification.operation === 'update') {
setUsers(prev => prev.map(u =>
u.id === notification.data.id ? notification.data : u
));
} else if (notification.operation === 'delete') {
setUsers(prev => prev.filter(u => u.id !== notification.data.id));
}
}, { schema: 'public' });
// Load initial users
const initialUsers = await client.read('users', {
schema: 'public',
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
});
setUsers(initialUsers);
return () => client.unsubscribe(subId);
};
subscribeToUsers();
}, [client, isConnected]);
const createUser = async (name: string, email: string) => {
await client.create('users', { name, email, status: 'active' }, {
schema: 'public'
});
};
return (
<div>
<h2>Users ({users.length})</h2>
{isConnected ? '🟢 Connected' : '🔴 Disconnected'}
{/* Render users... */}
</div>
);
}
`;
console.log(exampleCode);
}
/**
* Example 8: TypeScript with Typed Models
*/
export async function typedModelsExample() {
// Define your models
interface User {
id: number;
name: string;
email: string;
status: 'active' | 'inactive';
created_at: string;
}
interface Post {
id: number;
title: string;
content: string;
user_id: number;
status: 'draft' | 'published';
views: number;
user?: User;
}
const client = new WebSocketClient({
url: 'ws://localhost:8080/ws'
});
await client.connect();
// Type-safe operations
const users = await client.read<User[]>('users', {
schema: 'public',
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
});
const newUser = await client.create<User>('users', {
name: 'Alice',
email: 'alice@example.com',
status: 'active'
}, { schema: 'public' });
const posts = await client.read<Post[]>('posts', {
schema: 'public',
preload: [
{
relation: 'user',
columns: ['id', 'name', 'email']
}
]
});
// Type-safe subscriptions
await client.subscribe(
'users',
(notification) => {
const user = notification.data as User;
console.log('User changed:', user.name, user.email);
},
{ schema: 'public' }
);
client.disconnect();
}

View File

@@ -0,0 +1,110 @@
// WebSocket Message Types
export type MessageType = 'request' | 'response' | 'notification' | 'subscription' | 'error' | 'ping' | 'pong';
export type WSOperation = 'read' | 'create' | 'update' | 'delete' | 'subscribe' | 'unsubscribe' | 'meta';
// Re-export common types
export type { FilterOption, SortOption, PreloadOption, Operator, SortDirection } from './types';
export interface WSOptions {
filters?: import('./types').FilterOption[];
columns?: string[];
preload?: import('./types').PreloadOption[];
sort?: import('./types').SortOption[];
limit?: number;
offset?: number;
}
export interface WSMessage {
id?: string;
type: MessageType;
operation?: WSOperation;
schema?: string;
entity?: string;
record_id?: string;
data?: any;
options?: WSOptions;
subscription_id?: string;
success?: boolean;
error?: WSErrorInfo;
metadata?: Record<string, any>;
timestamp?: string;
}
export interface WSErrorInfo {
code: string;
message: string;
details?: Record<string, any>;
}
export interface WSRequestMessage {
id: string;
type: 'request';
operation: WSOperation;
schema?: string;
entity: string;
record_id?: string;
data?: any;
options?: WSOptions;
}
export interface WSResponseMessage {
id: string;
type: 'response';
success: boolean;
data?: any;
error?: WSErrorInfo;
metadata?: Record<string, any>;
timestamp: string;
}
export interface WSNotificationMessage {
type: 'notification';
operation: WSOperation;
subscription_id: string;
schema?: string;
entity: string;
data: any;
timestamp: string;
}
export interface WSSubscriptionMessage {
id: string;
type: 'subscription';
operation: 'subscribe' | 'unsubscribe';
schema?: string;
entity: string;
options?: WSOptions;
subscription_id?: string;
}
export interface SubscriptionOptions {
filters?: import('./types').FilterOption[];
onNotification?: (notification: WSNotificationMessage) => void;
}
export interface WebSocketClientConfig {
url: string;
reconnect?: boolean;
reconnectInterval?: number;
maxReconnectAttempts?: number;
heartbeatInterval?: number;
debug?: boolean;
}
export interface Subscription {
id: string;
entity: string;
schema?: string;
options?: WSOptions;
callback?: (notification: WSNotificationMessage) => void;
}
export type ConnectionState = 'connecting' | 'connected' | 'disconnecting' | 'disconnected' | 'reconnecting';
export interface WebSocketClientEvents {
connect: () => void;
disconnect: (event: CloseEvent) => void;
error: (error: Error) => void;
message: (message: WSMessage) => void;
stateChange: (state: ConnectionState) => void;
}

View File

@@ -14,33 +14,33 @@ NC='\033[0m' # No Color
echo -e "${GREEN}=== ResolveSpec Integration Tests ===${NC}\n" echo -e "${GREEN}=== ResolveSpec Integration Tests ===${NC}\n"
# Check if docker-compose is available # Check if podman compose is available
if ! command -v docker-compose &> /dev/null; then if ! command -v podman &> /dev/null; then
echo -e "${RED}Error: docker-compose is not installed${NC}" echo -e "${RED}Error: podman is not installed${NC}"
echo "Please install docker-compose or run PostgreSQL manually" echo "Please install podman or run PostgreSQL manually"
echo "See INTEGRATION_TESTS.md for details" echo "See INTEGRATION_TESTS.md for details"
exit 1 exit 1
fi fi
# Clean up any existing containers and networks from previous runs # Clean up any existing containers and networks from previous runs
echo -e "${YELLOW}Cleaning up existing containers and networks...${NC}" echo -e "${YELLOW}Cleaning up existing containers and networks...${NC}"
docker-compose down -v 2>/dev/null || true podman compose down -v 2>/dev/null || true
# Start PostgreSQL # Start PostgreSQL
echo -e "${YELLOW}Starting PostgreSQL...${NC}" echo -e "${YELLOW}Starting PostgreSQL...${NC}"
docker-compose up -d postgres-test podman compose up -d postgres-test
# Wait for PostgreSQL to be ready # Wait for PostgreSQL to be ready
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}" echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
max_attempts=30 max_attempts=30
attempt=0 attempt=0
while ! docker-compose exec -T postgres-test pg_isready -U postgres > /dev/null 2>&1; do while ! podman compose exec -T postgres-test pg_isready -U postgres > /dev/null 2>&1; do
attempt=$((attempt + 1)) attempt=$((attempt + 1))
if [ $attempt -ge $max_attempts ]; then if [ $attempt -ge $max_attempts ]; then
echo -e "${RED}Error: PostgreSQL failed to start after ${max_attempts} seconds${NC}" echo -e "${RED}Error: PostgreSQL failed to start after ${max_attempts} seconds${NC}"
docker-compose logs postgres-test podman compose logs postgres-test
docker-compose down podman compose down
exit 1 exit 1
fi fi
sleep 1 sleep 1
@@ -51,8 +51,8 @@ echo -e "\n${GREEN}PostgreSQL is ready!${NC}\n"
# Create test databases # Create test databases
echo -e "${YELLOW}Creating test databases...${NC}" echo -e "${YELLOW}Creating test databases...${NC}"
docker-compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE resolvespec_test;" 2>/dev/null || echo " resolvespec_test already exists" podman compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE resolvespec_test;" 2>/dev/null || echo " resolvespec_test already exists"
docker-compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE restheadspec_test;" 2>/dev/null || echo " restheadspec_test already exists" podman compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE restheadspec_test;" 2>/dev/null || echo " restheadspec_test already exists"
echo -e "${GREEN}Test databases ready!${NC}\n" echo -e "${GREEN}Test databases ready!${NC}\n"
# Determine which tests to run # Determine which tests to run
@@ -79,6 +79,6 @@ fi
# Cleanup # Cleanup
echo -e "\n${YELLOW}Stopping PostgreSQL...${NC}" echo -e "\n${YELLOW}Stopping PostgreSQL...${NC}"
docker-compose down podman compose down
exit $EXIT_CODE exit $EXIT_CODE

View File

@@ -19,14 +19,14 @@ Integration tests validate the full functionality of both `pkg/resolvespec` and
- Go 1.19 or later - Go 1.19 or later
- PostgreSQL 12 or later - PostgreSQL 12 or later
- Docker and Docker Compose (optional, for easy setup) - Podman and Podman Compose (optional, for easy setup)
## Quick Start with Docker ## Quick Start with Podman
### 1. Start PostgreSQL with Docker Compose ### 1. Start PostgreSQL with Podman Compose
```bash ```bash
docker-compose up -d postgres-test podman compose up -d postgres-test
``` ```
This starts a PostgreSQL container with the following default settings: This starts a PostgreSQL container with the following default settings:
@@ -52,7 +52,7 @@ go test -tags=integration ./pkg/restheadspec -v
### 3. Stop PostgreSQL ### 3. Stop PostgreSQL
```bash ```bash
docker-compose down podman compose down
``` ```
## Manual PostgreSQL Setup ## Manual PostgreSQL Setup
@@ -161,7 +161,7 @@ If you see "connection refused" errors:
1. Check that PostgreSQL is running: 1. Check that PostgreSQL is running:
```bash ```bash
docker-compose ps podman compose ps
``` ```
2. Verify connection parameters: 2. Verify connection parameters:
@@ -194,10 +194,10 @@ Each test automatically cleans up its data using `TRUNCATE`. If you need a fresh
```bash ```bash
# Stop and remove containers (removes data) # Stop and remove containers (removes data)
docker-compose down -v podman compose down -v
# Restart # Restart
docker-compose up -d postgres-test podman compose up -d postgres-test
``` ```
## CI/CD Integration ## CI/CD Integration

View File

@@ -119,13 +119,13 @@ Integration tests require a PostgreSQL database and use the `// +build integrati
- PostgreSQL 12+ installed and running - PostgreSQL 12+ installed and running
- Create test databases manually (see below) - Create test databases manually (see below)
### Setup with Docker ### Setup with Podman
1. **Start PostgreSQL**: 1. **Start PostgreSQL**:
```bash ```bash
make docker-up make docker-up
# or # or
docker-compose up -d postgres-test podman compose up -d postgres-test
``` ```
2. **Run Tests**: 2. **Run Tests**:
@@ -141,10 +141,10 @@ Integration tests require a PostgreSQL database and use the `// +build integrati
```bash ```bash
make docker-down make docker-down
# or # or
docker-compose down podman compose down
``` ```
### Setup without Docker ### Setup without Podman
1. **Create Databases**: 1. **Create Databases**:
```sql ```sql
@@ -289,8 +289,8 @@ go test -tags=integration ./pkg/resolvespec -v
**Problem**: "connection refused" or "database does not exist" **Problem**: "connection refused" or "database does not exist"
**Solutions**: **Solutions**:
1. Check PostgreSQL is running: `docker-compose ps` 1. Check PostgreSQL is running: `podman compose ps`
2. Verify databases exist: `docker-compose exec postgres-test psql -U postgres -l` 2. Verify databases exist: `podman compose exec postgres-test psql -U postgres -l`
3. Check environment variable: `echo $TEST_DATABASE_URL` 3. Check environment variable: `echo $TEST_DATABASE_URL`
4. Recreate databases: `make clean && make docker-up` 4. Recreate databases: `make clean && make docker-up`