mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06b2404c0c | ||
|
|
32007480c6 | ||
|
|
ab1ce869b6 | ||
|
|
ff72e04428 | ||
|
|
e35f8a4f14 | ||
|
|
5ff9a8a24e | ||
|
|
81b87af6e4 | ||
|
|
f3ba314640 | ||
|
|
93df33e274 | ||
|
|
abd045493a | ||
|
|
a61556d857 | ||
|
|
eaf1133575 | ||
|
|
8172c0495d | ||
|
|
7a3c368121 | ||
|
|
9c5c7689e9 | ||
|
|
08050c960d |
92
README.md
92
README.md
@@ -13,6 +13,8 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
|||||||
|
|
||||||
**🆕 New in v2.1**: RestHeadSpec (HeaderSpec) - Header-based REST API with lifecycle hooks, cursor pagination, and advanced filtering.
|
**🆕 New in v2.1**: RestHeadSpec (HeaderSpec) - Header-based REST API with lifecycle hooks, cursor pagination, and advanced filtering.
|
||||||
|
|
||||||
|
**🆕 New in v3.0**: Explicit route registration - Routes are now created per registered model for better flexibility and control. OPTIONS method support with full CORS headers for cross-origin requests.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
@@ -65,6 +67,12 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
|||||||
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
|
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
|
||||||
- **🆕 Base64 Encoding**: Support for base64-encoded header values
|
- **🆕 Base64 Encoding**: Support for base64-encoded header values
|
||||||
|
|
||||||
|
### Routing & CORS (v3.0+)
|
||||||
|
- **🆕 Explicit Route Registration**: Routes created per registered model instead of dynamic lookups
|
||||||
|
- **🆕 OPTIONS Method Support**: Full OPTIONS method support returning model metadata
|
||||||
|
- **🆕 CORS Headers**: Comprehensive CORS support with all HeadSpec headers allowed
|
||||||
|
- **🆕 Better Route Control**: Customize routes per model with more flexibility
|
||||||
|
|
||||||
## API Structure
|
## API Structure
|
||||||
|
|
||||||
### URL Patterns
|
### URL Patterns
|
||||||
@@ -123,13 +131,15 @@ import "github.com/gorilla/mux"
|
|||||||
// Create handler
|
// Create handler
|
||||||
handler := restheadspec.NewHandlerWithGORM(db)
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
// Register models using schema.table format
|
// IMPORTANT: Register models BEFORE setting up routes
|
||||||
|
// Routes are created explicitly for each registered model
|
||||||
handler.Registry.RegisterModel("public.users", &User{})
|
handler.Registry.RegisterModel("public.users", &User{})
|
||||||
handler.Registry.RegisterModel("public.posts", &Post{})
|
handler.Registry.RegisterModel("public.posts", &Post{})
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes (creates explicit routes for each registered model)
|
||||||
|
// This replaces the old dynamic route lookup approach
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
restheadspec.SetupMuxRoutes(router, handler)
|
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
http.ListenAndServe(":8080", router)
|
http.ListenAndServe(":8080", router)
|
||||||
@@ -172,6 +182,42 @@ restheadspec.SetupMuxRoutes(router, handler)
|
|||||||
|
|
||||||
For complete header documentation, see [pkg/restheadspec/HEADERS.md](pkg/restheadspec/HEADERS.md).
|
For complete header documentation, see [pkg/restheadspec/HEADERS.md](pkg/restheadspec/HEADERS.md).
|
||||||
|
|
||||||
|
### CORS & OPTIONS Support
|
||||||
|
|
||||||
|
ResolveSpec and RestHeadSpec include comprehensive CORS support for cross-origin requests:
|
||||||
|
|
||||||
|
**OPTIONS Method**:
|
||||||
|
```http
|
||||||
|
OPTIONS /public/users HTTP/1.1
|
||||||
|
```
|
||||||
|
Returns metadata with appropriate CORS headers:
|
||||||
|
```http
|
||||||
|
Access-Control-Allow-Origin: *
|
||||||
|
Access-Control-Allow-Methods: GET, POST, OPTIONS
|
||||||
|
Access-Control-Allow-Headers: Content-Type, Authorization, X-Select-Fields, X-FieldFilter-*, ...
|
||||||
|
Access-Control-Max-Age: 86400
|
||||||
|
Access-Control-Allow-Credentials: true
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Features**:
|
||||||
|
- OPTIONS returns model metadata (same as GET metadata endpoint)
|
||||||
|
- All HTTP methods include CORS headers automatically
|
||||||
|
- OPTIONS requests don't require authentication (CORS preflight)
|
||||||
|
- Supports all HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.)
|
||||||
|
- 24-hour max age to reduce preflight requests
|
||||||
|
|
||||||
|
**Configuration**:
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
|
||||||
|
// Get default CORS config
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
|
|
||||||
|
// Customize if needed
|
||||||
|
corsConfig.AllowedOrigins = []string{"https://example.com"}
|
||||||
|
corsConfig.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
|
||||||
|
```
|
||||||
|
|
||||||
### Lifecycle Hooks
|
### Lifecycle Hooks
|
||||||
|
|
||||||
RestHeadSpec supports lifecycle hooks for all CRUD operations:
|
RestHeadSpec supports lifecycle hooks for all CRUD operations:
|
||||||
@@ -687,15 +733,16 @@ handler := resolvespec.NewHandler(dbAdapter, registry)
|
|||||||
```go
|
```go
|
||||||
import "github.com/gorilla/mux"
|
import "github.com/gorilla/mux"
|
||||||
|
|
||||||
// Backward compatible way
|
// Register models first
|
||||||
router := mux.NewRouter()
|
handler.Registry.RegisterModel("public.users", &User{})
|
||||||
resolvespec.SetupRoutes(router, handler)
|
handler.Registry.RegisterModel("public.posts", &Post{})
|
||||||
|
|
||||||
// Or manually:
|
// Setup routes - creates explicit routes for each model
|
||||||
router.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
|
router := mux.NewRouter()
|
||||||
vars := mux.Vars(r)
|
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||||
handler.Handle(w, r, vars)
|
|
||||||
}).Methods("POST")
|
// Routes created: /public/users, /public/posts, etc.
|
||||||
|
// Each route includes GET, POST, and OPTIONS methods with CORS support
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Gin (Custom Integration)
|
#### Gin (Custom Integration)
|
||||||
@@ -950,7 +997,28 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
|
|
||||||
## What's New
|
## What's New
|
||||||
|
|
||||||
### v2.1 (Latest)
|
### v3.0 (Latest - December 2025)
|
||||||
|
|
||||||
|
**Explicit Route Registration (🆕)**:
|
||||||
|
- **Breaking Change**: Routes are now created explicitly for each registered model
|
||||||
|
- **Better Control**: Customize routes per model with more flexibility
|
||||||
|
- **Registration Order**: Models must be registered BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
|
||||||
|
- **Benefits**: More flexible routing, easier to add custom routes per model, better performance
|
||||||
|
|
||||||
|
**OPTIONS Method & CORS Support (🆕)**:
|
||||||
|
- **OPTIONS Endpoint**: Full OPTIONS method support for CORS preflight requests
|
||||||
|
- **Metadata Response**: OPTIONS returns model metadata (same as GET /metadata)
|
||||||
|
- **CORS Headers**: Comprehensive CORS headers on all responses
|
||||||
|
- **Header Support**: All HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.) allowed
|
||||||
|
- **No Auth on OPTIONS**: CORS preflight requests don't require authentication
|
||||||
|
- **Configurable**: Customize CORS settings via `common.CORSConfig`
|
||||||
|
|
||||||
|
**Migration Notes**:
|
||||||
|
- Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
|
||||||
|
- Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
|
||||||
|
- This is a **breaking change** but provides better control and flexibility
|
||||||
|
|
||||||
|
### v2.1
|
||||||
|
|
||||||
**Recursive CRUD Handler (🆕 Nov 11, 2025)**:
|
**Recursive CRUD Handler (🆕 Nov 11, 2025)**:
|
||||||
- **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
|
- **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ func main() {
|
|||||||
handler.RegisterModel("public", modelNames[i], model)
|
handler.RegisterModel("public", modelNames[i], model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup routes using new SetupMuxRoutes function
|
// Setup routes using new SetupMuxRoutes function (without authentication)
|
||||||
resolvespec.SetupMuxRoutes(r, handler)
|
resolvespec.SetupMuxRoutes(r, handler, nil)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
logger.Info("Starting server on :8080")
|
logger.Info("Starting server on :8080")
|
||||||
|
|||||||
@@ -147,8 +147,11 @@ func (b *BunSelectQuery) Column(columns ...string) common.SelectQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||||
b.query = b.query.ColumnExpr(query, args)
|
if len(args) > 0 {
|
||||||
|
b.query = b.query.ColumnExpr(query, args)
|
||||||
|
} else {
|
||||||
|
b.query = b.query.ColumnExpr(query)
|
||||||
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -125,7 +125,12 @@ func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||||
g.db = g.db.Select(query, args...)
|
if len(args) > 0 {
|
||||||
|
g.db = g.db.Select(query, args...)
|
||||||
|
} else {
|
||||||
|
g.db = g.db.Select(query)
|
||||||
|
}
|
||||||
|
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,12 @@ func (b *BunRouterRequest) AllHeaders() map[string]string {
|
|||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnderlyingRequest returns the underlying *http.Request
|
||||||
|
// This is useful when you need to pass the request to other handlers
|
||||||
|
func (b *BunRouterRequest) UnderlyingRequest() *http.Request {
|
||||||
|
return b.req.Request
|
||||||
|
}
|
||||||
|
|
||||||
// StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers
|
// StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers
|
||||||
type StandardBunRouterAdapter struct {
|
type StandardBunRouterAdapter struct {
|
||||||
*BunRouterAdapter
|
*BunRouterAdapter
|
||||||
|
|||||||
@@ -137,6 +137,12 @@ func (h *HTTPRequest) AllHeaders() map[string]string {
|
|||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnderlyingRequest returns the underlying *http.Request
|
||||||
|
// This is useful when you need to pass the request to other handlers
|
||||||
|
func (h *HTTPRequest) UnderlyingRequest() *http.Request {
|
||||||
|
return h.req
|
||||||
|
}
|
||||||
|
|
||||||
// HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter
|
// HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter
|
||||||
type HTTPResponseWriter struct {
|
type HTTPResponseWriter struct {
|
||||||
resp http.ResponseWriter
|
resp http.ResponseWriter
|
||||||
@@ -166,6 +172,12 @@ func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
|
|||||||
return json.NewEncoder(h.resp).Encode(data)
|
return json.NewEncoder(h.resp).Encode(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
|
||||||
|
// This is useful when you need to pass the response writer to other handlers
|
||||||
|
func (h *HTTPResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||||
|
return h.resp
|
||||||
|
}
|
||||||
|
|
||||||
// StandardMuxAdapter creates routes compatible with standard http.HandlerFunc
|
// StandardMuxAdapter creates routes compatible with standard http.HandlerFunc
|
||||||
type StandardMuxAdapter struct {
|
type StandardMuxAdapter struct {
|
||||||
*MuxAdapter
|
*MuxAdapter
|
||||||
|
|||||||
119
pkg/common/cors.go
Normal file
119
pkg/common/cors.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CORSConfig holds CORS configuration
|
||||||
|
type CORSConfig struct {
|
||||||
|
AllowedOrigins []string
|
||||||
|
AllowedMethods []string
|
||||||
|
AllowedHeaders []string
|
||||||
|
MaxAge int
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
|
||||||
|
func DefaultCORSConfig() CORSConfig {
|
||||||
|
return CORSConfig{
|
||||||
|
AllowedOrigins: []string{"*"},
|
||||||
|
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||||
|
AllowedHeaders: GetHeadSpecHeaders(),
|
||||||
|
MaxAge: 86400, // 24 hours
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHeadSpecHeaders returns all headers used by HeadSpec
|
||||||
|
func GetHeadSpecHeaders() []string {
|
||||||
|
return []string{
|
||||||
|
// Standard headers
|
||||||
|
"Content-Type",
|
||||||
|
"Authorization",
|
||||||
|
"Accept",
|
||||||
|
"Accept-Language",
|
||||||
|
"Content-Language",
|
||||||
|
|
||||||
|
// Field Selection
|
||||||
|
"X-Select-Fields",
|
||||||
|
"X-Not-Select-Fields",
|
||||||
|
"X-Clean-JSON",
|
||||||
|
|
||||||
|
// Filtering & Search
|
||||||
|
"X-FieldFilter-*",
|
||||||
|
"X-SearchFilter-*",
|
||||||
|
"X-SearchOp-*",
|
||||||
|
"X-SearchOr-*",
|
||||||
|
"X-SearchAnd-*",
|
||||||
|
"X-SearchCols",
|
||||||
|
"X-Custom-SQL-W",
|
||||||
|
"X-Custom-SQL-W-*",
|
||||||
|
"X-Custom-SQL-Or",
|
||||||
|
"X-Custom-SQL-Or-*",
|
||||||
|
|
||||||
|
// Joins & Relations
|
||||||
|
"X-Preload",
|
||||||
|
"X-Preload-*",
|
||||||
|
"X-Expand",
|
||||||
|
"X-Expand-*",
|
||||||
|
"X-Custom-SQL-Join",
|
||||||
|
"X-Custom-SQL-Join-*",
|
||||||
|
|
||||||
|
// Sorting & Pagination
|
||||||
|
"X-Sort",
|
||||||
|
"X-Sort-*",
|
||||||
|
"X-Limit",
|
||||||
|
"X-Offset",
|
||||||
|
"X-Cursor-Forward",
|
||||||
|
"X-Cursor-Backward",
|
||||||
|
|
||||||
|
// Advanced Features
|
||||||
|
"X-AdvSQL-*",
|
||||||
|
"X-CQL-Sel-*",
|
||||||
|
"X-Distinct",
|
||||||
|
"X-SkipCount",
|
||||||
|
"X-SkipCache",
|
||||||
|
"X-Fetch-RowNumber",
|
||||||
|
"X-PKRow",
|
||||||
|
|
||||||
|
// Response Format
|
||||||
|
"X-SimpleAPI",
|
||||||
|
"X-DetailAPI",
|
||||||
|
"X-Syncfusion",
|
||||||
|
"X-Single-Record-As-Object",
|
||||||
|
|
||||||
|
// Transaction Control
|
||||||
|
"X-Transaction-Atomic",
|
||||||
|
|
||||||
|
// X-Files - comprehensive JSON configuration
|
||||||
|
"X-Files",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCORSHeaders sets CORS headers on a response writer
|
||||||
|
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||||
|
// Set allowed origins
|
||||||
|
if len(config.AllowedOrigins) > 0 {
|
||||||
|
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set allowed methods
|
||||||
|
if len(config.AllowedMethods) > 0 {
|
||||||
|
w.SetHeader("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set allowed headers
|
||||||
|
if len(config.AllowedHeaders) > 0 {
|
||||||
|
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set max age
|
||||||
|
if config.MaxAge > 0 {
|
||||||
|
w.SetHeader("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow credentials
|
||||||
|
w.SetHeader("Access-Control-Allow-Credentials", "true")
|
||||||
|
|
||||||
|
// Expose headers that clients can read
|
||||||
|
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
|
||||||
|
}
|
||||||
97
pkg/common/handler_example.go
Normal file
97
pkg/common/handler_example.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
// Example showing how to use the common handler interfaces
|
||||||
|
// This file demonstrates the handler interface hierarchy and usage patterns
|
||||||
|
|
||||||
|
// ProcessWithAnyHandler demonstrates using the base SpecHandler interface
|
||||||
|
// which works with any handler type (resolvespec, restheadspec, or funcspec)
|
||||||
|
func ProcessWithAnyHandler(handler SpecHandler) Database {
|
||||||
|
// All handlers expose GetDatabase() through the SpecHandler interface
|
||||||
|
return handler.GetDatabase()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessCRUDRequest demonstrates using the CRUDHandler interface
|
||||||
|
// which works with resolvespec.Handler and restheadspec.Handler
|
||||||
|
func ProcessCRUDRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
|
||||||
|
// Both resolvespec and restheadspec handlers implement Handle()
|
||||||
|
handler.Handle(w, r, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessMetadataRequest demonstrates getting metadata from CRUD handlers
|
||||||
|
func ProcessMetadataRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
|
||||||
|
// Both resolvespec and restheadspec handlers implement HandleGet()
|
||||||
|
handler.HandleGet(w, r, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example usage patterns (not executable, just for documentation):
|
||||||
|
/*
|
||||||
|
// Example 1: Using with resolvespec.Handler
|
||||||
|
func ExampleResolveSpec() {
|
||||||
|
db := // ... get database
|
||||||
|
registry := // ... get registry
|
||||||
|
|
||||||
|
handler := resolvespec.NewHandler(db, registry)
|
||||||
|
|
||||||
|
// Can be used as SpecHandler
|
||||||
|
var specHandler SpecHandler = handler
|
||||||
|
database := specHandler.GetDatabase()
|
||||||
|
|
||||||
|
// Can be used as CRUDHandler
|
||||||
|
var crudHandler CRUDHandler = handler
|
||||||
|
crudHandler.Handle(w, r, params)
|
||||||
|
crudHandler.HandleGet(w, r, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 2: Using with restheadspec.Handler
|
||||||
|
func ExampleRestHeadSpec() {
|
||||||
|
db := // ... get database
|
||||||
|
registry := // ... get registry
|
||||||
|
|
||||||
|
handler := restheadspec.NewHandler(db, registry)
|
||||||
|
|
||||||
|
// Can be used as SpecHandler
|
||||||
|
var specHandler SpecHandler = handler
|
||||||
|
database := specHandler.GetDatabase()
|
||||||
|
|
||||||
|
// Can be used as CRUDHandler
|
||||||
|
var crudHandler CRUDHandler = handler
|
||||||
|
crudHandler.Handle(w, r, params)
|
||||||
|
crudHandler.HandleGet(w, r, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 3: Using with funcspec.Handler
|
||||||
|
func ExampleFuncSpec() {
|
||||||
|
db := // ... get database
|
||||||
|
|
||||||
|
handler := funcspec.NewHandler(db)
|
||||||
|
|
||||||
|
// Can be used as SpecHandler
|
||||||
|
var specHandler SpecHandler = handler
|
||||||
|
database := specHandler.GetDatabase()
|
||||||
|
|
||||||
|
// Can be used as QueryHandler
|
||||||
|
var queryHandler QueryHandler = handler
|
||||||
|
// funcspec has different methods: SqlQueryList() and SqlQuery()
|
||||||
|
// which return HTTP handler functions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 4: Polymorphic handler processing
|
||||||
|
func ProcessHandlers(handlers []SpecHandler) {
|
||||||
|
for _, handler := range handlers {
|
||||||
|
// All handlers expose the database
|
||||||
|
db := handler.GetDatabase()
|
||||||
|
|
||||||
|
// Type switch for specific handler types
|
||||||
|
switch h := handler.(type) {
|
||||||
|
case CRUDHandler:
|
||||||
|
// This is resolvespec or restheadspec
|
||||||
|
// Can call Handle() and HandleGet()
|
||||||
|
_ = h
|
||||||
|
case QueryHandler:
|
||||||
|
// This is funcspec
|
||||||
|
// Can call SqlQueryList() and SqlQuery()
|
||||||
|
_ = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
@@ -122,6 +122,7 @@ type Request interface {
|
|||||||
PathParam(key string) string
|
PathParam(key string) string
|
||||||
QueryParam(key string) string
|
QueryParam(key string) string
|
||||||
AllQueryParams() map[string]string // Get all query parameters as a map
|
AllQueryParams() map[string]string // Get all query parameters as a map
|
||||||
|
UnderlyingRequest() *http.Request // Get the underlying *http.Request for forwarding to other handlers
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResponseWriter interface abstracts HTTP response
|
// ResponseWriter interface abstracts HTTP response
|
||||||
@@ -130,6 +131,7 @@ type ResponseWriter interface {
|
|||||||
WriteHeader(statusCode int)
|
WriteHeader(statusCode int)
|
||||||
Write(data []byte) (int, error)
|
Write(data []byte) (int, error)
|
||||||
WriteJSON(data interface{}) error
|
WriteJSON(data interface{}) error
|
||||||
|
UnderlyingResponseWriter() http.ResponseWriter // Get the underlying http.ResponseWriter for forwarding to other handlers
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPHandlerFunc type for HTTP handlers
|
// HTTPHandlerFunc type for HTTP handlers
|
||||||
@@ -164,6 +166,10 @@ func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
|||||||
return json.NewEncoder(s.w).Encode(data)
|
return json.NewEncoder(s.w).Encode(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||||
|
return s.w
|
||||||
|
}
|
||||||
|
|
||||||
// StandardRequest adapts *http.Request to Request interface
|
// StandardRequest adapts *http.Request to Request interface
|
||||||
type StandardRequest struct {
|
type StandardRequest struct {
|
||||||
r *http.Request
|
r *http.Request
|
||||||
@@ -228,6 +234,10 @@ func (s *StandardRequest) AllQueryParams() map[string]string {
|
|||||||
return params
|
return params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) UnderlyingRequest() *http.Request {
|
||||||
|
return s.r
|
||||||
|
}
|
||||||
|
|
||||||
// TableNameProvider interface for models that provide table names
|
// TableNameProvider interface for models that provide table names
|
||||||
type TableNameProvider interface {
|
type TableNameProvider interface {
|
||||||
TableName() string
|
TableName() string
|
||||||
@@ -246,3 +256,39 @@ type PrimaryKeyNameProvider interface {
|
|||||||
type SchemaProvider interface {
|
type SchemaProvider interface {
|
||||||
SchemaName() string
|
SchemaName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SpecHandler interface represents common functionality across all spec handlers
|
||||||
|
// This is the base interface implemented by:
|
||||||
|
// - resolvespec.Handler: Handles CRUD operations via request body with explicit operation field
|
||||||
|
// - restheadspec.Handler: Handles CRUD operations via HTTP methods (GET/POST/PUT/DELETE)
|
||||||
|
// - funcspec.Handler: Handles custom SQL query execution with dynamic parameters
|
||||||
|
//
|
||||||
|
// The interface hierarchy is:
|
||||||
|
//
|
||||||
|
// SpecHandler (base)
|
||||||
|
// ├── CRUDHandler (resolvespec, restheadspec)
|
||||||
|
// └── QueryHandler (funcspec)
|
||||||
|
type SpecHandler interface {
|
||||||
|
// GetDatabase returns the underlying database connection
|
||||||
|
GetDatabase() Database
|
||||||
|
}
|
||||||
|
|
||||||
|
// CRUDHandler interface for handlers that support CRUD operations
|
||||||
|
// This is implemented by resolvespec.Handler and restheadspec.Handler
|
||||||
|
type CRUDHandler interface {
|
||||||
|
SpecHandler
|
||||||
|
|
||||||
|
// Handle processes API requests through router-agnostic interface
|
||||||
|
Handle(w ResponseWriter, r Request, params map[string]string)
|
||||||
|
|
||||||
|
// HandleGet processes GET requests for metadata
|
||||||
|
HandleGet(w ResponseWriter, r Request, params map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryHandler interface for handlers that execute SQL queries
|
||||||
|
// This is implemented by funcspec.Handler
|
||||||
|
// Note: funcspec uses standard http.ResponseWriter and *http.Request instead of common interfaces
|
||||||
|
type QueryHandler interface {
|
||||||
|
SpecHandler
|
||||||
|
// Methods are defined in funcspec package due to different function signature requirements
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,12 @@ func NewHandler(db common.Database) *Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDatabase returns the underlying database connection
|
||||||
|
// Implements common.SpecHandler interface
|
||||||
|
func (h *Handler) GetDatabase() common.Database {
|
||||||
|
return h.db
|
||||||
|
}
|
||||||
|
|
||||||
// Hooks returns the hook registry for this handler
|
// Hooks returns the hook registry for this handler
|
||||||
// Use this to register custom hooks for operations
|
// Use this to register custom hooks for operations
|
||||||
func (h *Handler) Hooks() *HookRegistry {
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
@@ -157,8 +163,9 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
// Remove unused input variables
|
// Remove unused input variables
|
||||||
if pBlankparms {
|
if pBlankparms {
|
||||||
for _, kw := range inputvars {
|
for _, kw := range inputvars {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, "")
|
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||||
logger.Debug("Removed unused variable: %s", kw)
|
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||||
|
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,7 +238,8 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dbobjlist = rows
|
// Normalize PostgreSQL types for proper JSON marshaling
|
||||||
|
dbobjlist = normalizePostgresTypesList(rows)
|
||||||
|
|
||||||
if pNoCount {
|
if pNoCount {
|
||||||
total = int64(len(dbobjlist))
|
total = int64(len(dbobjlist))
|
||||||
@@ -495,8 +503,9 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
// Remove unused input variables
|
// Remove unused input variables
|
||||||
if pBlankparms {
|
if pBlankparms {
|
||||||
for _, kw := range inputvars {
|
for _, kw := range inputvars {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, "")
|
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||||
logger.Debug("Removed unused variable: %s", kw)
|
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||||
|
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -524,7 +533,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(rows) > 0 {
|
if len(rows) > 0 {
|
||||||
dbobj = rows[0]
|
dbobj = normalizePostgresTypes(rows[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute AfterSQLExec hook
|
// Execute AfterSQLExec hook
|
||||||
@@ -749,8 +758,8 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[rid_session]") {
|
if strings.Contains(sqlquery, "[rid_session]") {
|
||||||
sessionID := userCtx.SessionID
|
sessionID, _ := strconv.ParseInt(userCtx.SessionID, 10, 64)
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("'%s'", sessionID))
|
sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("%d", sessionID))
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(sqlquery, "[method]") {
|
if strings.Contains(sqlquery, "[method]") {
|
||||||
@@ -864,6 +873,38 @@ func IsNumeric(s string) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getReplacementForBlankParam determines the replacement value for an unused parameter
|
||||||
|
// based on whether it appears within quotes in the SQL query.
|
||||||
|
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
|
||||||
|
func getReplacementForBlankParam(sqlquery, param string) string {
|
||||||
|
// Find the parameter in the query
|
||||||
|
idx := strings.Index(sqlquery, param)
|
||||||
|
if idx < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check characters immediately before and after the parameter
|
||||||
|
var charBefore, charAfter byte
|
||||||
|
|
||||||
|
if idx > 0 {
|
||||||
|
charBefore = sqlquery[idx-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
endIdx := idx + len(param)
|
||||||
|
if endIdx < len(sqlquery) {
|
||||||
|
charAfter = sqlquery[endIdx]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if parameter is surrounded by quotes (single quote or dollar sign for PostgreSQL dollar-quoted strings)
|
||||||
|
if (charBefore == '\'' || charBefore == '$') && (charAfter == '\'' || charAfter == '$') {
|
||||||
|
// Parameter is in quotes, return empty string
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parameter is not in quotes, return NULL
|
||||||
|
return "NULL"
|
||||||
|
}
|
||||||
|
|
||||||
// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows
|
// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows
|
||||||
// func makeResultReceiver(length int) []interface{} {
|
// func makeResultReceiver(length int) []interface{} {
|
||||||
// result := make([]interface{}, length)
|
// result := make([]interface{}, length)
|
||||||
@@ -906,3 +947,67 @@ func sendError(w http.ResponseWriter, status int, code, message string, err erro
|
|||||||
})
|
})
|
||||||
_, _ = w.Write(data)
|
_, _ = w.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizePostgresTypesList normalizes a list of result maps to handle PostgreSQL types correctly
|
||||||
|
func normalizePostgresTypesList(rows []map[string]interface{}) []map[string]interface{} {
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := make([]map[string]interface{}, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
normalized[i] = normalizePostgresTypes(row)
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizePostgresTypes normalizes a result map to handle PostgreSQL types correctly for JSON marshaling
|
||||||
|
// This is necessary because when scanning into map[string]interface{}, PostgreSQL types like jsonb, bytea, etc.
|
||||||
|
// are scanned as []byte which would be base64-encoded when marshaled to JSON.
|
||||||
|
func normalizePostgresTypes(row map[string]interface{}) map[string]interface{} {
|
||||||
|
if row == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := make(map[string]interface{}, len(row))
|
||||||
|
for key, value := range row {
|
||||||
|
normalized[key] = normalizePostgresValue(value)
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizePostgresValue normalizes a single value to the appropriate Go type for JSON marshaling
|
||||||
|
func normalizePostgresValue(value interface{}) interface{} {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []byte:
|
||||||
|
// Check if it's valid JSON (jsonb type)
|
||||||
|
// Try to unmarshal as JSON first
|
||||||
|
var jsonObj interface{}
|
||||||
|
if err := json.Unmarshal(v, &jsonObj); err == nil {
|
||||||
|
// It's valid JSON, return as json.RawMessage so it's not double-encoded
|
||||||
|
return json.RawMessage(v)
|
||||||
|
}
|
||||||
|
// Not valid JSON, could be bytea - keep as []byte for base64 encoding
|
||||||
|
return v
|
||||||
|
|
||||||
|
case []interface{}:
|
||||||
|
// Recursively normalize array elements
|
||||||
|
normalized := make([]interface{}, len(v))
|
||||||
|
for i, elem := range v {
|
||||||
|
normalized[i] = normalizePostgresValue(elem)
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
case map[string]interface{}:
|
||||||
|
// Recursively normalize nested maps
|
||||||
|
return normalizePostgresTypes(v)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// For other types (int, float, string, bool, etc.), return as-is
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -784,7 +784,7 @@ func TestReplaceMetaVariables(t *testing.T) {
|
|||||||
userCtx := &security.UserContext{
|
userCtx := &security.UserContext{
|
||||||
UserID: 123,
|
UserID: 123,
|
||||||
UserName: "testuser",
|
UserName: "testuser",
|
||||||
SessionID: "session-abc",
|
SessionID: "456",
|
||||||
}
|
}
|
||||||
|
|
||||||
metainfo := map[string]interface{}{
|
metainfo := map[string]interface{}{
|
||||||
@@ -819,7 +819,7 @@ func TestReplaceMetaVariables(t *testing.T) {
|
|||||||
name: "Replace [rid_session]",
|
name: "Replace [rid_session]",
|
||||||
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
|
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
|
||||||
expectedCheck: func(result string) bool {
|
expectedCheck: func(result string) bool {
|
||||||
return strings.Contains(result, "'session-abc'")
|
return strings.Contains(result, "456")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -835,3 +835,65 @@ func TestReplaceMetaVariables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
|
||||||
|
func TestGetReplacementForBlankParam(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
param string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Parameter in single quotes",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE name = '[username]'",
|
||||||
|
param: "[username]",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter in dollar quotes",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE data = $[jsondata]$",
|
||||||
|
param: "[jsondata]",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter not in quotes",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
|
||||||
|
param: "[user_id]",
|
||||||
|
expected: "NULL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter not in quotes with AND",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND status = 1",
|
||||||
|
param: "[user_id]",
|
||||||
|
expected: "NULL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter in mixed quote context - before quote",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = 'test'",
|
||||||
|
param: "[user_id]",
|
||||||
|
expected: "NULL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter in mixed quote context - in quotes",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE name = '[username]' AND id = 1",
|
||||||
|
param: "[username]",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter with dollar quote tag",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE body = $tag$[content]$tag$",
|
||||||
|
param: "[content]",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := getReplacementForBlankParam(tt.sqlQuery, tt.param)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected replacement '%s', got '%s' for query: %s", tt.expected, result, tt.sqlQuery)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
83
pkg/funcspec/security_adapter.go
Normal file
83
pkg/funcspec/security_adapter.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package funcspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
||||||
|
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
||||||
|
// We provide audit logging for data access tracking
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
||||||
|
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||||
|
return security.LogDataAccess(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 2: BeforeQuery - Audit logging before single query execution
|
||||||
|
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||||
|
return security.LogDataAccess(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Note: Row-level security and column masking are challenging in funcspec
|
||||||
|
// because the SQL query is fully user-defined. Security should be implemented
|
||||||
|
// at the SQL function level or through database policies (RLS).
|
||||||
|
}
|
||||||
|
|
||||||
|
// funcSpecSecurityContext adapts funcspec.HookContext to security.SecurityContext interface
|
||||||
|
type funcSpecSecurityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFuncSpecSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &funcSpecSecurityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) GetContext() context.Context {
|
||||||
|
return f.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) GetUserID() (int, bool) {
|
||||||
|
if f.ctx.UserContext == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(f.ctx.UserContext.UserID), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) GetSchema() string {
|
||||||
|
// funcspec doesn't have a schema concept, extract from SQL query or use default
|
||||||
|
return "public"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) GetEntity() string {
|
||||||
|
// funcspec doesn't have an entity concept, could parse from SQL or use a placeholder
|
||||||
|
return "sql_query"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) GetModel() interface{} {
|
||||||
|
// funcspec doesn't use models in the same way as restheadspec
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) GetQuery() interface{} {
|
||||||
|
// In funcspec, the query is a string, not a query builder object
|
||||||
|
return f.ctx.SQLQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) SetQuery(query interface{}) {
|
||||||
|
// In funcspec, we could modify the SQL string, but this should be done cautiously
|
||||||
|
if sqlQuery, ok := query.(string); ok {
|
||||||
|
f.ctx.SQLQuery = sqlQuery
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) GetResult() interface{} {
|
||||||
|
return f.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *funcSpecSecurityContext) SetResult(result interface{}) {
|
||||||
|
f.ctx.Result = result
|
||||||
|
}
|
||||||
@@ -23,6 +23,15 @@ func Init(dev bool) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UpdateLoggerPath(path string, dev bool) {
|
||||||
|
defaultConfig := zap.NewProductionConfig()
|
||||||
|
if dev {
|
||||||
|
defaultConfig = zap.NewDevelopmentConfig()
|
||||||
|
}
|
||||||
|
defaultConfig.OutputPaths = []string{path}
|
||||||
|
UpdateLogger(&defaultConfig)
|
||||||
|
}
|
||||||
|
|
||||||
func UpdateLogger(config *zap.Config) {
|
func UpdateLogger(config *zap.Config) {
|
||||||
defaultConfig := zap.NewProductionConfig()
|
defaultConfig := zap.NewProductionConfig()
|
||||||
defaultConfig.OutputPaths = []string{"resolvespec.log"}
|
defaultConfig.OutputPaths = []string{"resolvespec.log"}
|
||||||
|
|||||||
@@ -16,11 +16,17 @@ import (
|
|||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// FallbackHandler is a function that handles requests when no model is found
|
||||||
|
// It receives the same parameters as the Handle method
|
||||||
|
type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[string]string)
|
||||||
|
|
||||||
// Handler handles API requests using database and model abstractions
|
// Handler handles API requests using database and model abstractions
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
db common.Database
|
db common.Database
|
||||||
registry common.ModelRegistry
|
registry common.ModelRegistry
|
||||||
nestedProcessor *common.NestedCUDProcessor
|
nestedProcessor *common.NestedCUDProcessor
|
||||||
|
hooks *HookRegistry
|
||||||
|
fallbackHandler FallbackHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new API handler with database and registry abstractions
|
// NewHandler creates a new API handler with database and registry abstractions
|
||||||
@@ -28,12 +34,31 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
|||||||
handler := &Handler{
|
handler := &Handler{
|
||||||
db: db,
|
db: db,
|
||||||
registry: registry,
|
registry: registry,
|
||||||
|
hooks: NewHookRegistry(),
|
||||||
}
|
}
|
||||||
// Initialize nested processor
|
// Initialize nested processor
|
||||||
handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler)
|
handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler)
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hooks returns the hook registry for this handler
|
||||||
|
// Use this to register custom hooks for operations
|
||||||
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
|
return h.hooks
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackHandler sets a fallback handler to be called when no model is found
|
||||||
|
// If not set, the handler will simply return (pass through to next route)
|
||||||
|
func (h *Handler) SetFallbackHandler(fallback FallbackHandler) {
|
||||||
|
h.fallbackHandler = fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDatabase returns the underlying database connection
|
||||||
|
// Implements common.SpecHandler interface
|
||||||
|
func (h *Handler) GetDatabase() common.Database {
|
||||||
|
return h.db
|
||||||
|
}
|
||||||
|
|
||||||
// handlePanic is a helper function to handle panics with stack traces
|
// handlePanic is a helper function to handle panics with stack traces
|
||||||
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
||||||
stack := debug.Stack()
|
stack := debug.Stack()
|
||||||
@@ -75,8 +100,14 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
// Get model and populate context with request-scoped data
|
// Get model and populate context with request-scoped data
|
||||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid entity: %v", err)
|
// Model not found - call fallback handler if set, otherwise pass through
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
|
logger.Debug("Model not found for %s.%s", schema, entity)
|
||||||
|
if h.fallbackHandler != nil {
|
||||||
|
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
|
||||||
|
h.fallbackHandler(w, r, params)
|
||||||
|
} else {
|
||||||
|
logger.Debug("No fallback handler set, passing through to next route")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,6 +151,8 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
h.handleUpdate(ctx, w, id, req.ID, req.Data, req.Options)
|
h.handleUpdate(ctx, w, id, req.ID, req.Data, req.Options)
|
||||||
case "delete":
|
case "delete":
|
||||||
h.handleDelete(ctx, w, id, req.Data)
|
h.handleDelete(ctx, w, id, req.Data)
|
||||||
|
case "meta":
|
||||||
|
h.handleMeta(ctx, w, schema, entity, model)
|
||||||
default:
|
default:
|
||||||
logger.Error("Invalid operation: %s", req.Operation)
|
logger.Error("Invalid operation: %s", req.Operation)
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil)
|
h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil)
|
||||||
@@ -142,8 +175,14 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
|
|
||||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to get model: %v", err)
|
// Model not found - call fallback handler if set, otherwise pass through
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
|
logger.Debug("Model not found for %s.%s", schema, entity)
|
||||||
|
if h.fallbackHandler != nil {
|
||||||
|
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
|
||||||
|
h.fallbackHandler(w, r, params)
|
||||||
|
} else {
|
||||||
|
logger.Debug("No fallback handler set, passing through to next route")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,6 +190,21 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
h.sendResponse(w, metadata, nil)
|
h.sendResponse(w, metadata, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleMeta processes meta operation requests
|
||||||
|
func (h *Handler) handleMeta(ctx context.Context, w common.ResponseWriter, schema, entity string, model interface{}) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleMeta", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Info("Getting metadata for %s.%s via meta operation", schema, entity)
|
||||||
|
|
||||||
|
metadata := h.generateMetadata(schema, entity, model)
|
||||||
|
h.sendResponse(w, metadata, nil)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
|
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
152
pkg/resolvespec/hooks.go
Normal file
152
pkg/resolvespec/hooks.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package resolvespec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookType defines the type of hook to execute
|
||||||
|
type HookType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Read operation hooks
|
||||||
|
BeforeRead HookType = "before_read"
|
||||||
|
AfterRead HookType = "after_read"
|
||||||
|
|
||||||
|
// Create operation hooks
|
||||||
|
BeforeCreate HookType = "before_create"
|
||||||
|
AfterCreate HookType = "after_create"
|
||||||
|
|
||||||
|
// Update operation hooks
|
||||||
|
BeforeUpdate HookType = "before_update"
|
||||||
|
AfterUpdate HookType = "after_update"
|
||||||
|
|
||||||
|
// Delete operation hooks
|
||||||
|
BeforeDelete HookType = "before_delete"
|
||||||
|
AfterDelete HookType = "after_delete"
|
||||||
|
|
||||||
|
// Scan/Execute operation hooks (for query building)
|
||||||
|
BeforeScan HookType = "before_scan"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookContext contains all the data available to a hook
|
||||||
|
type HookContext struct {
|
||||||
|
Context context.Context
|
||||||
|
Handler *Handler // Reference to the handler for accessing database, registry, etc.
|
||||||
|
Schema string
|
||||||
|
Entity string
|
||||||
|
Model interface{}
|
||||||
|
Options common.RequestOptions
|
||||||
|
Writer common.ResponseWriter
|
||||||
|
Request common.Request
|
||||||
|
|
||||||
|
// Operation-specific fields
|
||||||
|
ID string
|
||||||
|
Data interface{} // For create/update operations
|
||||||
|
Result interface{} // For after hooks
|
||||||
|
Error error // For after hooks
|
||||||
|
|
||||||
|
// Query chain - allows hooks to modify the query before execution
|
||||||
|
Query common.SelectQuery
|
||||||
|
|
||||||
|
// Allow hooks to abort the operation
|
||||||
|
Abort bool // If set to true, the operation will be aborted
|
||||||
|
AbortMessage string // Message to return if aborted
|
||||||
|
AbortCode int // HTTP status code if aborted
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookFunc is the signature for hook functions
|
||||||
|
// It receives a HookContext and can modify it or return an error
|
||||||
|
// If an error is returned, the operation will be aborted
|
||||||
|
type HookFunc func(*HookContext) error
|
||||||
|
|
||||||
|
// HookRegistry manages all registered hooks
|
||||||
|
type HookRegistry struct {
|
||||||
|
hooks map[HookType][]HookFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHookRegistry creates a new hook registry
|
||||||
|
func NewHookRegistry() *HookRegistry {
|
||||||
|
return &HookRegistry{
|
||||||
|
hooks: make(map[HookType][]HookFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a new hook for the specified hook type
|
||||||
|
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||||
|
if r.hooks == nil {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
}
|
||||||
|
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||||
|
logger.Info("Registered resolvespec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterMultiple registers a hook for multiple hook types
|
||||||
|
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||||
|
for _, hookType := range hookTypes {
|
||||||
|
r.Register(hookType, hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs all hooks for the specified type in order
|
||||||
|
// If any hook returns an error, execution stops and the error is returned
|
||||||
|
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||||
|
hooks, exists := r.hooks[hookType]
|
||||||
|
if !exists || len(hooks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Executing %d resolvespec hook(s) for %s", len(hooks), hookType)
|
||||||
|
|
||||||
|
for i, hook := range hooks {
|
||||||
|
if err := hook(ctx); err != nil {
|
||||||
|
logger.Error("Resolvespec hook %d for %s failed: %v", i+1, hookType, err)
|
||||||
|
return fmt.Errorf("hook execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hook requested abort
|
||||||
|
if ctx.Abort {
|
||||||
|
logger.Warn("Resolvespec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||||
|
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all hooks for the specified type
|
||||||
|
func (r *HookRegistry) Clear(hookType HookType) {
|
||||||
|
delete(r.hooks, hookType)
|
||||||
|
logger.Info("Cleared all resolvespec hooks for %s", hookType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAll removes all registered hooks
|
||||||
|
func (r *HookRegistry) ClearAll() {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
logger.Info("Cleared all resolvespec hooks")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of hooks registered for a specific type
|
||||||
|
func (r *HookRegistry) Count(hookType HookType) int {
|
||||||
|
if hooks, exists := r.hooks[hookType]; exists {
|
||||||
|
return len(hooks)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasHooks returns true if there are any hooks registered for the specified type
|
||||||
|
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||||
|
return r.Count(hookType) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllHookTypes returns all hook types that have registered hooks
|
||||||
|
func (r *HookRegistry) GetAllHookTypes() []HookType {
|
||||||
|
types := make([]HookType, 0, len(r.hooks))
|
||||||
|
for hookType := range r.hooks {
|
||||||
|
types = append(types, hookType)
|
||||||
|
}
|
||||||
|
return types
|
||||||
|
}
|
||||||
@@ -2,12 +2,14 @@ package resolvespec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
"github.com/uptrace/bunrouter"
|
"github.com/uptrace/bunrouter"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
@@ -37,28 +39,122 @@ func NewStandardBunRouter() *router.StandardBunRouterAdapter {
|
|||||||
return router.NewStandardBunRouterAdapter()
|
return router.NewStandardBunRouterAdapter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MiddlewareFunc is a function that wraps an http.Handler with additional functionality
|
||||||
|
type MiddlewareFunc func(http.Handler) http.Handler
|
||||||
|
|
||||||
// SetupMuxRoutes sets up routes for the ResolveSpec API with Mux
|
// SetupMuxRoutes sets up routes for the ResolveSpec API with Mux
|
||||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||||
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
|
// Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
|
||||||
vars := mux.Vars(r)
|
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
// Get all registered models from the registry
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
allModels := handler.registry.GetAllModels()
|
||||||
handler.Handle(respAdapter, reqAdapter, vars)
|
|
||||||
}).Methods("POST")
|
|
||||||
|
|
||||||
muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) {
|
// Loop through each registered model and create explicit routes
|
||||||
vars := mux.Vars(r)
|
for fullName := range allModels {
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
// Parse the full name (e.g., "public.users" or just "users")
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
schema, entity := parseModelName(fullName)
|
||||||
handler.Handle(respAdapter, reqAdapter, vars)
|
|
||||||
}).Methods("POST")
|
|
||||||
|
|
||||||
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
|
// Build the route paths
|
||||||
vars := mux.Vars(r)
|
entityPath := buildRoutePath(schema, entity)
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||||
|
|
||||||
|
// Create handler functions for this specific entity
|
||||||
|
postEntityHandler := createMuxHandler(handler, schema, entity, "")
|
||||||
|
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||||
|
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||||
|
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||||
|
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
||||||
|
|
||||||
|
// Apply authentication middleware if provided
|
||||||
|
if authMiddleware != nil {
|
||||||
|
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
|
||||||
|
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
|
||||||
|
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
|
||||||
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register routes for this entity
|
||||||
|
muxRouter.Handle(entityPath, postEntityHandler).Methods("POST")
|
||||||
|
muxRouter.Handle(entityWithIDPath, postEntityWithIDHandler).Methods("POST")
|
||||||
|
muxRouter.Handle(entityPath, getEntityHandler).Methods("GET")
|
||||||
|
muxRouter.Handle(entityPath, optionsEntityHandler).Methods("OPTIONS")
|
||||||
|
muxRouter.Handle(entityWithIDPath, optionsEntityWithIDHandler).Methods("OPTIONS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create Mux handler for a specific entity with CORS support
|
||||||
|
func createMuxHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set CORS headers
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
|
||||||
|
vars := make(map[string]string)
|
||||||
|
vars["schema"] = schema
|
||||||
|
vars["entity"] = entity
|
||||||
|
if idParam != "" {
|
||||||
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create Mux GET handler for a specific entity with CORS support
|
||||||
|
func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set CORS headers
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
|
||||||
|
vars := make(map[string]string)
|
||||||
|
vars["schema"] = schema
|
||||||
|
vars["entity"] = entity
|
||||||
|
if idParam != "" {
|
||||||
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
}).Methods("GET")
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create Mux OPTIONS handler that returns metadata
|
||||||
|
func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMethods []string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set CORS headers with the allowed methods for this route
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
|
corsConfig.AllowedMethods = allowedMethods
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
|
||||||
|
// Return metadata in the OPTIONS response body
|
||||||
|
vars := make(map[string]string)
|
||||||
|
vars["schema"] = schema
|
||||||
|
vars["entity"] = entity
|
||||||
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseModelName parses a model name like "public.users" into schema and entity
|
||||||
|
// If no schema is present, returns empty string for schema
|
||||||
|
func parseModelName(fullName string) (schema, entity string) {
|
||||||
|
parts := strings.Split(fullName, ".")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
return parts[0], parts[1]
|
||||||
|
}
|
||||||
|
return "", fullName
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRoutePath builds a route path from schema and entity
|
||||||
|
// If schema is empty, returns just "/entity", otherwise "/{schema}/{entity}"
|
||||||
|
func buildRoutePath(schema, entity string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return "/" + entity
|
||||||
|
}
|
||||||
|
return "/" + schema + "/" + entity
|
||||||
}
|
}
|
||||||
|
|
||||||
// Example usage functions for documentation:
|
// Example usage functions for documentation:
|
||||||
@@ -68,12 +164,20 @@ func ExampleWithGORM(db *gorm.DB) {
|
|||||||
// Create handler using GORM
|
// Create handler using GORM
|
||||||
handler := NewHandlerWithGORM(db)
|
handler := NewHandlerWithGORM(db)
|
||||||
|
|
||||||
// Setup router
|
// Setup router without authentication
|
||||||
muxRouter := mux.NewRouter()
|
muxRouter := mux.NewRouter()
|
||||||
SetupMuxRoutes(muxRouter, handler)
|
SetupMuxRoutes(muxRouter, handler, nil)
|
||||||
|
|
||||||
// Register models
|
// Register models
|
||||||
// handler.RegisterModel("public", "users", &User{})
|
// handler.RegisterModel("public", "users", &User{})
|
||||||
|
|
||||||
|
// To add authentication, pass a middleware function:
|
||||||
|
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
// secList := security.NewSecurityList(myProvider)
|
||||||
|
// authMiddleware := func(h http.Handler) http.Handler {
|
||||||
|
// return security.NewAuthHandler(secList, h)
|
||||||
|
// }
|
||||||
|
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExampleWithBun shows how to switch to Bun ORM
|
// ExampleWithBun shows how to switch to Bun ORM
|
||||||
@@ -88,60 +192,118 @@ func ExampleWithBun(bunDB *bun.DB) {
|
|||||||
// Create handler
|
// Create handler
|
||||||
handler := NewHandler(dbAdapter, registry)
|
handler := NewHandler(dbAdapter, registry)
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes without authentication
|
||||||
muxRouter := mux.NewRouter()
|
muxRouter := mux.NewRouter()
|
||||||
SetupMuxRoutes(muxRouter, handler)
|
SetupMuxRoutes(muxRouter, handler, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||||
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
||||||
r := bunRouter.GetBunRouter()
|
r := bunRouter.GetBunRouter()
|
||||||
|
|
||||||
r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// Get all registered models from the registry
|
||||||
params := map[string]string{
|
allModels := handler.registry.GetAllModels()
|
||||||
"schema": req.Param("schema"),
|
|
||||||
"entity": req.Param("entity"),
|
|
||||||
}
|
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// CORS config
|
||||||
params := map[string]string{
|
corsConfig := common.DefaultCORSConfig()
|
||||||
"schema": req.Param("schema"),
|
|
||||||
"entity": req.Param("entity"),
|
|
||||||
"id": req.Param("id"),
|
|
||||||
}
|
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// Loop through each registered model and create explicit routes
|
||||||
params := map[string]string{
|
for fullName := range allModels {
|
||||||
"schema": req.Param("schema"),
|
// Parse the full name (e.g., "public.users" or just "users")
|
||||||
"entity": req.Param("entity"),
|
schema, entity := parseModelName(fullName)
|
||||||
}
|
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// Build the route paths
|
||||||
params := map[string]string{
|
entityPath := buildRoutePath(schema, entity)
|
||||||
"schema": req.Param("schema"),
|
entityWithIDPath := entityPath + "/:id"
|
||||||
"entity": req.Param("entity"),
|
|
||||||
"id": req.Param("id"),
|
// Create closure variables to capture current schema and entity
|
||||||
}
|
currentSchema := schema
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
currentEntity := entity
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
// POST route without ID
|
||||||
return nil
|
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
})
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// POST route with ID
|
||||||
|
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
"id": req.Param("id"),
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// GET route without ID
|
||||||
|
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// GET route with ID
|
||||||
|
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
"id": req.Param("id"),
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// OPTIONS route without ID (returns metadata)
|
||||||
|
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
optionsCorsConfig := corsConfig
|
||||||
|
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||||
|
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// OPTIONS route with ID (returns metadata)
|
||||||
|
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
optionsCorsConfig := corsConfig
|
||||||
|
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
|
||||||
|
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExampleWithBunRouter shows how to use bunrouter from uptrace
|
// ExampleWithBunRouter shows how to use bunrouter from uptrace
|
||||||
|
|||||||
85
pkg/resolvespec/security_hooks.go
Normal file
85
pkg/resolvespec/security_hooks.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package resolvespec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 1: BeforeRead - Load security rules
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LoadSecurityRules(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 2: BeforeScan - Apply row-level security filters
|
||||||
|
handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.ApplyRowSecurity(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 3: AfterRead - Apply column-level security (masking)
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 4 (Optional): Audit logging
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LogDataAccess(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("Security hooks registered for resolvespec handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// securityContext adapts resolvespec.HookContext to security.SecurityContext interface
|
||||||
|
type securityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &securityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetContext() context.Context {
|
||||||
|
return s.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetUserID() (int, bool) {
|
||||||
|
return security.GetUserID(s.ctx.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetSchema() string {
|
||||||
|
return s.ctx.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetEntity() string {
|
||||||
|
return s.ctx.Entity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetModel() interface{} {
|
||||||
|
return s.ctx.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetQuery() interface{} {
|
||||||
|
return s.ctx.Query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetQuery(query interface{}) {
|
||||||
|
if q, ok := query.(common.SelectQuery); ok {
|
||||||
|
s.ctx.Query = q
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetResult() interface{} {
|
||||||
|
return s.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetResult(result interface{}) {
|
||||||
|
s.ctx.Result = result
|
||||||
|
}
|
||||||
@@ -17,6 +17,10 @@ import (
|
|||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// FallbackHandler is a function that handles requests when no model is found
|
||||||
|
// It receives the same parameters as the Handle method
|
||||||
|
type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[string]string)
|
||||||
|
|
||||||
// Handler handles API requests using database and model abstractions
|
// Handler handles API requests using database and model abstractions
|
||||||
// This handler reads filters, columns, and options from HTTP headers
|
// This handler reads filters, columns, and options from HTTP headers
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
@@ -24,6 +28,7 @@ type Handler struct {
|
|||||||
registry common.ModelRegistry
|
registry common.ModelRegistry
|
||||||
hooks *HookRegistry
|
hooks *HookRegistry
|
||||||
nestedProcessor *common.NestedCUDProcessor
|
nestedProcessor *common.NestedCUDProcessor
|
||||||
|
fallbackHandler FallbackHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new API handler with database and registry abstractions
|
// NewHandler creates a new API handler with database and registry abstractions
|
||||||
@@ -38,12 +43,24 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
|||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDatabase returns the underlying database connection
|
||||||
|
// Implements common.SpecHandler interface
|
||||||
|
func (h *Handler) GetDatabase() common.Database {
|
||||||
|
return h.db
|
||||||
|
}
|
||||||
|
|
||||||
// Hooks returns the hook registry for this handler
|
// Hooks returns the hook registry for this handler
|
||||||
// Use this to register custom hooks for operations
|
// Use this to register custom hooks for operations
|
||||||
func (h *Handler) Hooks() *HookRegistry {
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
return h.hooks
|
return h.hooks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFallbackHandler sets a fallback handler to be called when no model is found
|
||||||
|
// If not set, the handler will simply return (pass through to next route)
|
||||||
|
func (h *Handler) SetFallbackHandler(fallback FallbackHandler) {
|
||||||
|
h.fallbackHandler = fallback
|
||||||
|
}
|
||||||
|
|
||||||
// handlePanic is a helper function to handle panics with stack traces
|
// handlePanic is a helper function to handle panics with stack traces
|
||||||
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
||||||
stack := debug.Stack()
|
stack := debug.Stack()
|
||||||
@@ -75,8 +92,14 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
// Get model and populate context with request-scoped data
|
// Get model and populate context with request-scoped data
|
||||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid entity: %v", err)
|
// Model not found - call fallback handler if set, otherwise pass through
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
|
logger.Debug("Model not found for %s.%s", schema, entity)
|
||||||
|
if h.fallbackHandler != nil {
|
||||||
|
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
|
||||||
|
h.fallbackHandler(w, r, params)
|
||||||
|
} else {
|
||||||
|
logger.Debug("No fallback handler set, passing through to next route")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,13 +146,25 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
h.handleRead(ctx, w, "", options)
|
h.handleRead(ctx, w, "", options)
|
||||||
}
|
}
|
||||||
case "POST":
|
case "POST":
|
||||||
// Create operation
|
// Read request body
|
||||||
body, err := r.Body()
|
body, err := r.Body()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to read request body: %v", err)
|
logger.Error("Failed to read request body: %v", err)
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err)
|
h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to detect if this is a meta operation request
|
||||||
|
var bodyMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &bodyMap); err == nil {
|
||||||
|
if operation, ok := bodyMap["operation"].(string); ok && operation == "meta" {
|
||||||
|
logger.Info("Detected meta operation request for %s.%s", schema, entity)
|
||||||
|
h.handleMeta(ctx, w, schema, entity, model)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not a meta operation, proceed with normal create/update
|
||||||
var data interface{}
|
var data interface{}
|
||||||
if err := json.Unmarshal(body, &data); err != nil {
|
if err := json.Unmarshal(body, &data); err != nil {
|
||||||
logger.Error("Failed to decode request body: %v", err)
|
logger.Error("Failed to decode request body: %v", err)
|
||||||
@@ -191,8 +226,14 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
|
|
||||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to get model: %v", err)
|
// Model not found - call fallback handler if set, otherwise pass through
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
|
logger.Debug("Model not found for %s.%s", schema, entity)
|
||||||
|
if h.fallbackHandler != nil {
|
||||||
|
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
|
||||||
|
h.fallbackHandler(w, r, params)
|
||||||
|
} else {
|
||||||
|
logger.Debug("No fallback handler set, passing through to next route")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,6 +241,21 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
h.sendResponse(w, metadata, nil)
|
h.sendResponse(w, metadata, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleMeta processes meta operation requests
|
||||||
|
func (h *Handler) handleMeta(ctx context.Context, w common.ResponseWriter, schema, entity string, model interface{}) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleMeta", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Info("Getting metadata for %s.%s via meta operation", schema, entity)
|
||||||
|
|
||||||
|
metadata := h.generateMetadata(schema, entity, model)
|
||||||
|
h.sendResponse(w, metadata, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// parseOptionsFromHeaders is now implemented in headers.go
|
// parseOptionsFromHeaders is now implemented in headers.go
|
||||||
|
|
||||||
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
||||||
@@ -277,7 +333,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if len(options.ComputedQL) > 0 {
|
if len(options.ComputedQL) > 0 {
|
||||||
for colName, colExpr := range options.ComputedQL {
|
for colName, colExpr := range options.ComputedQL {
|
||||||
logger.Debug("Applying computed column: %s", colName)
|
logger.Debug("Applying computed column: %s", colName)
|
||||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
if strings.Contains(colName, "cql") {
|
||||||
|
query = query.ColumnExpr(fmt.Sprintf("(%s)::text AS %s", colExpr, colName))
|
||||||
|
} else {
|
||||||
|
query = query.ColumnExpr(fmt.Sprintf("(%s)AS %s", colExpr, colName))
|
||||||
|
}
|
||||||
|
|
||||||
for colIndex := range options.Columns {
|
for colIndex := range options.Columns {
|
||||||
if options.Columns[colIndex] == colName {
|
if options.Columns[colIndex] == colName {
|
||||||
// Remove the computed column from the selected columns to avoid duplication
|
// Remove the computed column from the selected columns to avoid duplication
|
||||||
@@ -291,7 +352,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if len(options.ComputedColumns) > 0 {
|
if len(options.ComputedColumns) > 0 {
|
||||||
for _, cu := range options.ComputedColumns {
|
for _, cu := range options.ComputedColumns {
|
||||||
logger.Debug("Applying computed column: %s", cu.Name)
|
logger.Debug("Applying computed column: %s", cu.Name)
|
||||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
if strings.Contains(cu.Name, "cql") {
|
||||||
|
query = query.ColumnExpr(fmt.Sprintf("(%s)::text AS %s", cu.Expression, cu.Name))
|
||||||
|
} else {
|
||||||
|
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||||
|
}
|
||||||
|
|
||||||
for colIndex := range options.Columns {
|
for colIndex := range options.Columns {
|
||||||
if options.Columns[colIndex] == cu.Name {
|
if options.Columns[colIndex] == cu.Name {
|
||||||
// Remove the computed column from the selected columns to avoid duplication
|
// Remove the computed column from the selected columns to avoid duplication
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package restheadspec
|
package restheadspec
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,6 +43,12 @@ func (m *MockRequest) AllQueryParams() map[string]string {
|
|||||||
return m.queryParams
|
return m.queryParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) UnderlyingRequest() *http.Request {
|
||||||
|
// For testing purposes, return nil
|
||||||
|
// In real scenarios, you might want to construct a proper http.Request
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseOptionsFromQueryParams(t *testing.T) {
|
func TestParseOptionsFromQueryParams(t *testing.T) {
|
||||||
handler := NewHandler(nil, nil)
|
handler := NewHandler(nil, nil)
|
||||||
|
|
||||||
|
|||||||
@@ -54,12 +54,14 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
"github.com/uptrace/bunrouter"
|
"github.com/uptrace/bunrouter"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -90,31 +92,130 @@ func NewStandardBunRouter() *router.StandardBunRouterAdapter {
|
|||||||
return router.NewStandardBunRouterAdapter()
|
return router.NewStandardBunRouterAdapter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MiddlewareFunc is a function that wraps an http.Handler with additional functionality
|
||||||
|
type MiddlewareFunc func(http.Handler) http.Handler
|
||||||
|
|
||||||
// SetupMuxRoutes sets up routes for the RestHeadSpec API with Mux
|
// SetupMuxRoutes sets up routes for the RestHeadSpec API with Mux
|
||||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||||
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}
|
// Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
|
||||||
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
|
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||||
vars := mux.Vars(r)
|
// Get all registered models from the registry
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
allModels := handler.registry.GetAllModels()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, vars)
|
|
||||||
}).Methods("GET", "POST")
|
|
||||||
|
|
||||||
// GET, PUT, PATCH, DELETE for /{schema}/{entity}/{id}
|
// Loop through each registered model and create explicit routes
|
||||||
muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) {
|
for fullName := range allModels {
|
||||||
vars := mux.Vars(r)
|
// Parse the full name (e.g., "public.users" or just "users")
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
schema, entity := parseModelName(fullName)
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, vars)
|
|
||||||
}).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
|
|
||||||
|
|
||||||
// GET for metadata (using HandleGet)
|
// Build the route paths
|
||||||
muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) {
|
entityPath := buildRoutePath(schema, entity)
|
||||||
vars := mux.Vars(r)
|
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||||
reqAdapter := router.NewHTTPRequest(r)
|
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
||||||
|
|
||||||
|
// Create handler functions for this specific entity
|
||||||
|
entityHandler := createMuxHandler(handler, schema, entity, "")
|
||||||
|
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||||
|
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||||
|
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||||
|
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
||||||
|
|
||||||
|
// Apply authentication middleware if provided
|
||||||
|
if authMiddleware != nil {
|
||||||
|
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
|
||||||
|
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
|
||||||
|
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
|
||||||
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register routes for this entity
|
||||||
|
// GET, POST for /{schema}/{entity}
|
||||||
|
muxRouter.Handle(entityPath, entityHandler).Methods("GET", "POST")
|
||||||
|
|
||||||
|
// GET, PUT, PATCH, DELETE, POST for /{schema}/{entity}/{id}
|
||||||
|
muxRouter.Handle(entityWithIDPath, entityWithIDHandler).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
|
||||||
|
|
||||||
|
// GET for metadata (using HandleGet)
|
||||||
|
muxRouter.Handle(metadataPath, metadataHandler).Methods("GET")
|
||||||
|
|
||||||
|
// OPTIONS for CORS preflight - returns metadata
|
||||||
|
muxRouter.Handle(entityPath, optionsEntityHandler).Methods("OPTIONS")
|
||||||
|
muxRouter.Handle(entityWithIDPath, optionsEntityWithIDHandler).Methods("OPTIONS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create Mux handler for a specific entity with CORS support
|
||||||
|
func createMuxHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set CORS headers
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
|
||||||
|
vars := make(map[string]string)
|
||||||
|
vars["schema"] = schema
|
||||||
|
vars["entity"] = entity
|
||||||
|
if idParam != "" {
|
||||||
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create Mux GET handler for a specific entity with CORS support
|
||||||
|
func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set CORS headers
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
|
||||||
|
vars := make(map[string]string)
|
||||||
|
vars["schema"] = schema
|
||||||
|
vars["entity"] = entity
|
||||||
|
if idParam != "" {
|
||||||
|
vars["id"] = mux.Vars(r)[idParam]
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
}).Methods("GET")
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create Mux OPTIONS handler that returns metadata
|
||||||
|
func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMethods []string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Set CORS headers with the allowed methods for this route
|
||||||
|
corsConfig := common.DefaultCORSConfig()
|
||||||
|
corsConfig.AllowedMethods = allowedMethods
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
|
||||||
|
// Return metadata in the OPTIONS response body
|
||||||
|
vars := make(map[string]string)
|
||||||
|
vars["schema"] = schema
|
||||||
|
vars["entity"] = entity
|
||||||
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseModelName parses a model name like "public.users" into schema and entity
|
||||||
|
// If no schema is present, returns empty string for schema
|
||||||
|
func parseModelName(fullName string) (schema, entity string) {
|
||||||
|
parts := strings.Split(fullName, ".")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
return parts[0], parts[1]
|
||||||
|
}
|
||||||
|
return "", fullName
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRoutePath builds a route path from schema and entity
|
||||||
|
// If schema is empty, returns just "/entity", otherwise "/{schema}/{entity}"
|
||||||
|
func buildRoutePath(schema, entity string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return "/" + entity
|
||||||
|
}
|
||||||
|
return "/" + schema + "/" + entity
|
||||||
}
|
}
|
||||||
|
|
||||||
// Example usage functions for documentation:
|
// Example usage functions for documentation:
|
||||||
@@ -124,12 +225,20 @@ func ExampleWithGORM(db *gorm.DB) {
|
|||||||
// Create handler using GORM
|
// Create handler using GORM
|
||||||
handler := NewHandlerWithGORM(db)
|
handler := NewHandlerWithGORM(db)
|
||||||
|
|
||||||
// Setup router
|
// Setup router without authentication
|
||||||
muxRouter := mux.NewRouter()
|
muxRouter := mux.NewRouter()
|
||||||
SetupMuxRoutes(muxRouter, handler)
|
SetupMuxRoutes(muxRouter, handler, nil)
|
||||||
|
|
||||||
// Register models
|
// Register models
|
||||||
// handler.registry.RegisterModel("public.users", &User{})
|
// handler.registry.RegisterModel("public.users", &User{})
|
||||||
|
|
||||||
|
// To add authentication, pass a middleware function:
|
||||||
|
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
// secList := security.NewSecurityList(myProvider)
|
||||||
|
// authMiddleware := func(h http.Handler) http.Handler {
|
||||||
|
// return security.NewAuthHandler(secList, h)
|
||||||
|
// }
|
||||||
|
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExampleWithBun shows how to switch to Bun ORM
|
// ExampleWithBun shows how to switch to Bun ORM
|
||||||
@@ -144,110 +253,169 @@ func ExampleWithBun(bunDB *bun.DB) {
|
|||||||
// Create handler
|
// Create handler
|
||||||
handler := NewHandler(dbAdapter, registry)
|
handler := NewHandler(dbAdapter, registry)
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes without authentication
|
||||||
muxRouter := mux.NewRouter()
|
muxRouter := mux.NewRouter()
|
||||||
SetupMuxRoutes(muxRouter, handler)
|
SetupMuxRoutes(muxRouter, handler, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
||||||
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
||||||
r := bunRouter.GetBunRouter()
|
r := bunRouter.GetBunRouter()
|
||||||
|
|
||||||
// GET and POST for /:schema/:entity
|
// Get all registered models from the registry
|
||||||
r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
|
allModels := handler.registry.GetAllModels()
|
||||||
params := map[string]string{
|
|
||||||
"schema": req.Param("schema"),
|
|
||||||
"entity": req.Param("entity"),
|
|
||||||
}
|
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// CORS config
|
||||||
params := map[string]string{
|
corsConfig := common.DefaultCORSConfig()
|
||||||
"schema": req.Param("schema"),
|
|
||||||
"entity": req.Param("entity"),
|
|
||||||
}
|
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// GET, PUT, PATCH, DELETE for /:schema/:entity/:id
|
// Loop through each registered model and create explicit routes
|
||||||
r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
for fullName := range allModels {
|
||||||
params := map[string]string{
|
// Parse the full name (e.g., "public.users" or just "users")
|
||||||
"schema": req.Param("schema"),
|
schema, entity := parseModelName(fullName)
|
||||||
"entity": req.Param("entity"),
|
|
||||||
"id": req.Param("id"),
|
|
||||||
}
|
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// Build the route paths
|
||||||
params := map[string]string{
|
entityPath := buildRoutePath(schema, entity)
|
||||||
"schema": req.Param("schema"),
|
entityWithIDPath := entityPath + "/:id"
|
||||||
"entity": req.Param("entity"),
|
metadataPath := entityPath + "/metadata"
|
||||||
"id": req.Param("id"),
|
|
||||||
}
|
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// Create closure variables to capture current schema and entity
|
||||||
params := map[string]string{
|
currentSchema := schema
|
||||||
"schema": req.Param("schema"),
|
currentEntity := entity
|
||||||
"entity": req.Param("entity"),
|
|
||||||
"id": req.Param("id"),
|
|
||||||
}
|
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Handle("PATCH", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
// GET and POST for /{schema}/{entity}
|
||||||
params := map[string]string{
|
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
"schema": req.Param("schema"),
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
"entity": req.Param("entity"),
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
"id": req.Param("id"),
|
params := map[string]string{
|
||||||
}
|
"schema": currentSchema,
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
"entity": currentEntity,
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
}
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
return nil
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
})
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
r.Handle("DELETE", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
params := map[string]string{
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
"schema": req.Param("schema"),
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
"entity": req.Param("entity"),
|
params := map[string]string{
|
||||||
"id": req.Param("id"),
|
"schema": currentSchema,
|
||||||
}
|
"entity": currentEntity,
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
}
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
// Metadata endpoint
|
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
||||||
r.Handle("GET", "/:schema/:entity/metadata", func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
params := map[string]string{
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
"schema": req.Param("schema"),
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
"entity": req.Param("entity"),
|
params := map[string]string{
|
||||||
}
|
"schema": currentSchema,
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
"entity": currentEntity,
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
"id": req.Param("id"),
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
}
|
||||||
return nil
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
})
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
"id": req.Param("id"),
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
"id": req.Param("id"),
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
"id": req.Param("id"),
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
"id": req.Param("id"),
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Metadata endpoint
|
||||||
|
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// OPTIONS route without ID (returns metadata)
|
||||||
|
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
optionsCorsConfig := corsConfig
|
||||||
|
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||||
|
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// OPTIONS route with ID (returns metadata)
|
||||||
|
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
optionsCorsConfig := corsConfig
|
||||||
|
optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}
|
||||||
|
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": currentSchema,
|
||||||
|
"entity": currentEntity,
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExampleBunRouterWithBunDB shows usage with both BunRouter and Bun DB
|
// ExampleBunRouterWithBunDB shows usage with both BunRouter and Bun DB
|
||||||
|
|||||||
82
pkg/restheadspec/security_hooks.go
Normal file
82
pkg/restheadspec/security_hooks.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 1: BeforeRead - Load security rules
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LoadSecurityRules(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 2: BeforeScan - Apply row-level security filters
|
||||||
|
handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.ApplyRowSecurity(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 3: AfterRead - Apply column-level security (masking)
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 4 (Optional): Audit logging
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LogDataAccess(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("Security hooks registered for restheadspec handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// securityContext adapts restheadspec.HookContext to security.SecurityContext interface
|
||||||
|
type securityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &securityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetContext() context.Context {
|
||||||
|
return s.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetUserID() (int, bool) {
|
||||||
|
return security.GetUserID(s.ctx.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetSchema() string {
|
||||||
|
return s.ctx.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetEntity() string {
|
||||||
|
return s.ctx.Entity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetModel() interface{} {
|
||||||
|
return s.ctx.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetQuery() interface{} {
|
||||||
|
return s.ctx.Query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetQuery(query interface{}) {
|
||||||
|
s.ctx.Query = query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetResult() interface{} {
|
||||||
|
return s.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetResult(result interface{}) {
|
||||||
|
s.ctx.Result = result
|
||||||
|
}
|
||||||
@@ -91,7 +91,8 @@ security.UserContext{
|
|||||||
RemoteID: "remote_xyz", // Remote system ID
|
RemoteID: "remote_xyz", // Remote system ID
|
||||||
Roles: []string{"admin"}, // User roles
|
Roles: []string{"admin"}, // User roles
|
||||||
Email: "john@example.com", // User email
|
Email: "john@example.com", // User email
|
||||||
Claims: map[string]any{}, // Additional metadata
|
Claims: map[string]any{}, // Additional authentication claims
|
||||||
|
Meta: map[string]any{}, // Additional metadata (JSON-serializable)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -621,6 +622,67 @@ func main() {
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Authentication Modes
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Required authentication (default)
|
||||||
|
// Authentication must succeed or returns 401
|
||||||
|
router.Use(security.NewAuthMiddleware(securityList))
|
||||||
|
|
||||||
|
// Skip authentication for specific routes
|
||||||
|
// Always sets guest user context
|
||||||
|
func PublicRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := security.SkipAuth(r.Context())
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
// Guest context will be set
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional authentication for specific routes
|
||||||
|
// Tries to authenticate, falls back to guest if it fails
|
||||||
|
func HomeRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := security.OptionalAuth(r.Context())
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
userCtx, _ := security.GetUserContext(r.Context())
|
||||||
|
if userCtx.UserID == 0 {
|
||||||
|
// Guest user
|
||||||
|
} else {
|
||||||
|
// Authenticated user
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Comparison:**
|
||||||
|
- **Required**: Auth must succeed or return 401 (default)
|
||||||
|
- **SkipAuth**: Never tries to authenticate, always guest
|
||||||
|
- **OptionalAuth**: Tries to authenticate, guest on failure
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Standalone Handlers
|
||||||
|
|
||||||
|
```go
|
||||||
|
// NewAuthHandler - Required authentication (returns 401 on failure)
|
||||||
|
authHandler := security.NewAuthHandler(securityList, myHandler)
|
||||||
|
http.Handle("/api/protected", authHandler)
|
||||||
|
|
||||||
|
// NewOptionalAuthHandler - Optional authentication (guest on failure)
|
||||||
|
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
||||||
|
http.Handle("/home", optionalHandler)
|
||||||
|
|
||||||
|
// Example handler
|
||||||
|
func myHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := security.GetUserContext(r.Context())
|
||||||
|
if userCtx.UserID == 0 {
|
||||||
|
// Guest user
|
||||||
|
} else {
|
||||||
|
// Authenticated user
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Context Helpers
|
## Context Helpers
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@@ -635,6 +697,7 @@ sessionID, ok := security.GetSessionID(ctx)
|
|||||||
remoteID, ok := security.GetRemoteID(ctx)
|
remoteID, ok := security.GetRemoteID(ctx)
|
||||||
roles, ok := security.GetUserRoles(ctx)
|
roles, ok := security.GetUserRoles(ctx)
|
||||||
email, ok := security.GetUserEmail(ctx)
|
email, ok := security.GetUserEmail(ctx)
|
||||||
|
meta, ok := security.GetUserMeta(ctx)
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -56,9 +56,10 @@ rowSec := security.NewDatabaseRowSecurityProvider(db)
|
|||||||
// 2. Combine providers
|
// 2. Combine providers
|
||||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
|
||||||
// 3. Setup security
|
// 3. Create handler and register security hooks
|
||||||
handler := restheadspec.NewHandlerWithGORM(db)
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
securityList := security.SetupSecurityProvider(handler, provider)
|
securityList := security.NewSecurityList(provider)
|
||||||
|
restheadspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
|
||||||
// 4. Apply middleware
|
// 4. Apply middleware
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
@@ -69,6 +70,38 @@ router.Use(security.SetSecurityMiddleware(securityList))
|
|||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
|
### Spec-Agnostic Design
|
||||||
|
|
||||||
|
The security system is **completely spec-agnostic** - it doesn't depend on any specific spec implementation. Instead, each spec (restheadspec, funcspec, resolvespec) implements its own security integration by adapting to the `SecurityContext` interface.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ Security Package (Generic) │
|
||||||
|
│ - SecurityContext interface │
|
||||||
|
│ - Security providers │
|
||||||
|
│ - Core security logic │
|
||||||
|
└─────────────────────────────────────┘
|
||||||
|
▲ ▲ ▲
|
||||||
|
│ │ │
|
||||||
|
┌──────┘ │ └──────┐
|
||||||
|
│ │ │
|
||||||
|
┌───▼────┐ ┌────▼─────┐ ┌────▼──────┐
|
||||||
|
│RestHead│ │ FuncSpec │ │ResolveSpec│
|
||||||
|
│ Spec │ │ │ │ │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│Adapts │ │ Adapts │ │ Adapts │
|
||||||
|
│to │ │ to │ │ to │
|
||||||
|
│Security│ │ Security │ │ Security │
|
||||||
|
│Context │ │ Context │ │ Context │
|
||||||
|
└────────┘ └──────────┘ └───────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
- ✅ No circular dependencies
|
||||||
|
- ✅ Each spec can customize security integration
|
||||||
|
- ✅ Easy to add new specs
|
||||||
|
- ✅ Security logic is reusable across all specs
|
||||||
|
|
||||||
### Core Interfaces
|
### Core Interfaces
|
||||||
|
|
||||||
The security system is built on three main interfaces:
|
The security system is built on three main interfaces:
|
||||||
@@ -113,19 +146,42 @@ type SecurityProvider interface {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 4. SecurityContext (Spec Integration Interface)
|
||||||
|
Each spec implements this interface to integrate with the security system:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type SecurityContext interface {
|
||||||
|
GetContext() context.Context
|
||||||
|
GetUserID() (int, bool)
|
||||||
|
GetSchema() string
|
||||||
|
GetEntity() string
|
||||||
|
GetModel() interface{}
|
||||||
|
GetQuery() interface{}
|
||||||
|
SetQuery(interface{})
|
||||||
|
GetResult() interface{}
|
||||||
|
SetResult(interface{})
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Implementation Examples:**
|
||||||
|
- `restheadspec`: Adapts `restheadspec.HookContext` → `SecurityContext`
|
||||||
|
- `funcspec`: Adapts `funcspec.HookContext` → `SecurityContext`
|
||||||
|
- `resolvespec`: Adapts `resolvespec.HookContext` → `SecurityContext`
|
||||||
|
|
||||||
### UserContext
|
### UserContext
|
||||||
Enhanced user context with complete user information:
|
Enhanced user context with complete user information:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
UserID int // User's unique ID
|
UserID int // User's unique ID
|
||||||
UserName string // Username
|
UserName string // Username
|
||||||
UserLevel int // User privilege level
|
UserLevel int // User privilege level
|
||||||
SessionID string // Current session ID
|
SessionID string // Current session ID
|
||||||
RemoteID string // Remote system ID
|
RemoteID string // Remote system ID
|
||||||
Roles []string // User roles
|
Roles []string // User roles
|
||||||
Email string // User email
|
Email string // User email
|
||||||
Claims map[string]any // Additional metadata
|
Claims map[string]any // Additional authentication claims
|
||||||
|
Meta map[string]any // Additional metadata (can hold any JSON-serializable values)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -196,7 +252,7 @@ rowSec := security.NewConfigRowSecurityProvider(templates, blocked)
|
|||||||
|
|
||||||
## Usage Examples
|
## Usage Examples
|
||||||
|
|
||||||
### Example 1: Complete Database-Backed Security with Sessions
|
### Example 1: Complete Database-Backed Security with Sessions (restheadspec)
|
||||||
|
|
||||||
```go
|
```go
|
||||||
func main() {
|
func main() {
|
||||||
@@ -206,16 +262,20 @@ func main() {
|
|||||||
// db.Exec("CREATE TABLE users ...")
|
// db.Exec("CREATE TABLE users ...")
|
||||||
// db.Exec("CREATE TABLE user_sessions ...")
|
// db.Exec("CREATE TABLE user_sessions ...")
|
||||||
|
|
||||||
|
// Create handler
|
||||||
handler := restheadspec.NewHandlerWithGORM(db)
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
// Create providers
|
// Create security providers
|
||||||
auth := security.NewDatabaseAuthenticator(db) // Session-based auth
|
auth := security.NewDatabaseAuthenticator(db) // Session-based auth
|
||||||
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||||
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||||
|
|
||||||
// Combine
|
// Combine providers
|
||||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
securityList := security.SetupSecurityProvider(handler, provider)
|
securityList := security.NewSecurityList(provider)
|
||||||
|
|
||||||
|
// Register security hooks for this spec
|
||||||
|
restheadspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
@@ -308,14 +368,85 @@ func main() {
|
|||||||
colSec := security.NewConfigColumnSecurityProvider(columnRules)
|
colSec := security.NewConfigColumnSecurityProvider(columnRules)
|
||||||
rowSec := security.NewConfigRowSecurityProvider(rowTemplates, nil)
|
rowSec := security.NewConfigRowSecurityProvider(rowTemplates, nil)
|
||||||
|
|
||||||
|
// Combine providers and register hooks
|
||||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
securityList := security.SetupSecurityProvider(handler, provider)
|
securityList := security.NewSecurityList(provider)
|
||||||
|
restheadspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
|
||||||
// Setup routes...
|
// Setup routes...
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Example 3: Custom Provider
|
### Example 3: FuncSpec Security (SQL Query API)
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/funcspec"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
db := setupDatabase()
|
||||||
|
|
||||||
|
// Create funcspec handler
|
||||||
|
handler := funcspec.NewHandler(db)
|
||||||
|
|
||||||
|
// Create security providers
|
||||||
|
auth := security.NewJWTAuthenticator("secret-key", db)
|
||||||
|
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||||
|
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||||
|
|
||||||
|
// Combine providers
|
||||||
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
|
||||||
|
// Register security hooks (audit logging)
|
||||||
|
funcspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
|
||||||
|
// Note: funcspec operates on raw SQL queries, so row/column
|
||||||
|
// security is limited. Security should be enforced at the
|
||||||
|
// SQL function level or via database policies.
|
||||||
|
|
||||||
|
// Setup routes...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example 4: ResolveSpec Security (REST API)
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
db := setupDatabase()
|
||||||
|
registry := common.NewModelRegistry()
|
||||||
|
|
||||||
|
// Register models
|
||||||
|
registry.RegisterModel("public.users", &User{})
|
||||||
|
registry.RegisterModel("public.orders", &Order{})
|
||||||
|
|
||||||
|
// Create resolvespec handler
|
||||||
|
handler := resolvespec.NewHandler(db, registry)
|
||||||
|
|
||||||
|
// Create security providers
|
||||||
|
auth := security.NewDatabaseAuthenticator(db)
|
||||||
|
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||||
|
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||||
|
|
||||||
|
// Combine providers
|
||||||
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
|
||||||
|
// Register security hooks for resolvespec
|
||||||
|
resolvespec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
|
||||||
|
// Setup routes...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example 5: Custom Provider
|
||||||
|
|
||||||
Implement your own provider for complete control:
|
Implement your own provider for complete control:
|
||||||
|
|
||||||
@@ -344,9 +475,18 @@ func (p *MySecurityProvider) GetRowSecurity(ctx context.Context, userID int, sch
|
|||||||
// Your custom row security logic
|
// Your custom row security logic
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use it
|
// Use it with any spec
|
||||||
provider := &MySecurityProvider{db: db}
|
provider := &MySecurityProvider{db: db}
|
||||||
securityList := security.SetupSecurityProvider(handler, provider)
|
securityList := security.NewSecurityList(provider)
|
||||||
|
|
||||||
|
// Register with restheadspec
|
||||||
|
restheadspec.RegisterSecurityHooks(restHandler, securityList)
|
||||||
|
|
||||||
|
// Or with funcspec
|
||||||
|
funcspec.RegisterSecurityHooks(funcHandler, securityList)
|
||||||
|
|
||||||
|
// Or with resolvespec
|
||||||
|
resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Security Features
|
## Security Features
|
||||||
@@ -418,30 +558,45 @@ securityList := security.SetupSecurityProvider(handler, provider)
|
|||||||
```
|
```
|
||||||
HTTP Request
|
HTTP Request
|
||||||
↓
|
↓
|
||||||
NewAuthMiddleware
|
NewAuthMiddleware (security package)
|
||||||
├─ Calls provider.Authenticate(request)
|
├─ Calls provider.Authenticate(request)
|
||||||
└─ Adds UserContext to context
|
└─ Adds UserContext to context
|
||||||
↓
|
↓
|
||||||
SetSecurityMiddleware
|
SetSecurityMiddleware (security package)
|
||||||
└─ Adds SecurityList to context
|
└─ Adds SecurityList to context
|
||||||
↓
|
↓
|
||||||
Handler.Handle()
|
Spec Handler (restheadspec/funcspec/resolvespec)
|
||||||
↓
|
↓
|
||||||
BeforeRead Hook
|
BeforeRead Hook (registered by spec)
|
||||||
├─ Calls provider.GetColumnSecurity()
|
├─ Adapts spec's HookContext → SecurityContext
|
||||||
└─ Calls provider.GetRowSecurity()
|
├─ Calls security.LoadSecurityRules(secCtx, securityList)
|
||||||
|
│ ├─ Calls provider.GetColumnSecurity()
|
||||||
|
│ └─ Calls provider.GetRowSecurity()
|
||||||
|
└─ Caches security rules
|
||||||
↓
|
↓
|
||||||
BeforeScan Hook
|
BeforeScan Hook (registered by spec)
|
||||||
└─ Applies row security (adds WHERE clause)
|
├─ Adapts spec's HookContext → SecurityContext
|
||||||
|
├─ Calls security.ApplyRowSecurity(secCtx, securityList)
|
||||||
|
└─ Applies row security (adds WHERE clause to query)
|
||||||
↓
|
↓
|
||||||
Database Query (with security filters)
|
Database Query (with security filters)
|
||||||
↓
|
↓
|
||||||
AfterRead Hook
|
AfterRead Hook (registered by spec)
|
||||||
└─ Applies column security (masks/hides fields)
|
├─ Adapts spec's HookContext → SecurityContext
|
||||||
|
├─ Calls security.ApplyColumnSecurity(secCtx, securityList)
|
||||||
|
├─ Applies column security (masks/hides fields)
|
||||||
|
└─ Calls security.LogDataAccess(secCtx)
|
||||||
↓
|
↓
|
||||||
HTTP Response (secured data)
|
HTTP Response (secured data)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Key Points:**
|
||||||
|
- Security package is spec-agnostic and provides core logic
|
||||||
|
- Each spec registers its own hooks that adapt to SecurityContext
|
||||||
|
- Security rules are loaded once and cached for the request
|
||||||
|
- Row security is applied to the query (database level)
|
||||||
|
- Column security is applied to results (application level)
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
The interface-based design makes testing straightforward:
|
The interface-based design makes testing straightforward:
|
||||||
@@ -474,7 +629,9 @@ func TestMyHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Migration from Callbacks
|
## Migration Guide
|
||||||
|
|
||||||
|
### From Old Callback System
|
||||||
|
|
||||||
If you're upgrading from the old callback-based system:
|
If you're upgrading from the old callback-based system:
|
||||||
|
|
||||||
@@ -488,7 +645,7 @@ security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
|||||||
|
|
||||||
**New:**
|
**New:**
|
||||||
```go
|
```go
|
||||||
// Wrap your functions in a provider
|
// 1. Wrap your functions in a provider
|
||||||
type MyProvider struct{}
|
type MyProvider struct{}
|
||||||
|
|
||||||
func (p *MyProvider) Authenticate(r *http.Request) (*security.UserContext, error) {
|
func (p *MyProvider) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||||
@@ -512,11 +669,34 @@ func (p *MyProvider) Logout(ctx context.Context, req security.LogoutRequest) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use it
|
// 2. Create security list and register hooks
|
||||||
provider := &MyProvider{}
|
provider := &MyProvider{}
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
|
||||||
|
// 3. Register with your spec
|
||||||
|
restheadspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
```
|
||||||
|
|
||||||
|
### From Old SetupSecurityProvider API
|
||||||
|
|
||||||
|
If you're upgrading from the previous interface-based system:
|
||||||
|
|
||||||
|
**Old:**
|
||||||
|
```go
|
||||||
securityList := security.SetupSecurityProvider(handler, provider)
|
securityList := security.SetupSecurityProvider(handler, provider)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**New:**
|
||||||
|
```go
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
||||||
|
```
|
||||||
|
|
||||||
|
The main changes:
|
||||||
|
1. Security package no longer knows about specific spec types
|
||||||
|
2. Each spec registers its own security hooks
|
||||||
|
3. More flexible - same security provider works with all specs
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
| File | Description |
|
| File | Description |
|
||||||
@@ -629,6 +809,142 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Middleware and Handler API
|
||||||
|
|
||||||
|
### NewAuthMiddleware
|
||||||
|
Standard middleware that authenticates all requests:
|
||||||
|
|
||||||
|
```go
|
||||||
|
router.Use(security.NewAuthMiddleware(securityList))
|
||||||
|
```
|
||||||
|
|
||||||
|
Routes can skip authentication using the `SkipAuth` helper:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func PublicHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := security.SkipAuth(r.Context())
|
||||||
|
// This route will bypass authentication
|
||||||
|
// A guest user context will be set instead
|
||||||
|
}
|
||||||
|
|
||||||
|
router.Handle("/public", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := security.SkipAuth(r.Context())
|
||||||
|
PublicHandler(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
When authentication is skipped, a guest user context is automatically set:
|
||||||
|
- UserID: 0
|
||||||
|
- UserName: "guest"
|
||||||
|
- Roles: ["guest"]
|
||||||
|
- RemoteID: Request's remote address
|
||||||
|
|
||||||
|
Routes can use optional authentication with the `OptionalAuth` helper:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func OptionalAuthHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := security.OptionalAuth(r.Context())
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
// This route will try to authenticate
|
||||||
|
// If authentication succeeds, authenticated user context is set
|
||||||
|
// If authentication fails, guest user context is set instead
|
||||||
|
|
||||||
|
userCtx, _ := security.GetUserContext(r.Context())
|
||||||
|
if userCtx.UserID == 0 {
|
||||||
|
// Guest user
|
||||||
|
fmt.Fprintf(w, "Welcome, guest!")
|
||||||
|
} else {
|
||||||
|
// Authenticated user
|
||||||
|
fmt.Fprintf(w, "Welcome back, %s!", userCtx.UserName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
router.Handle("/home", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := security.OptionalAuth(r.Context())
|
||||||
|
OptionalAuthHandler(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Authentication Modes Summary:**
|
||||||
|
- **Required (default)**: Authentication must succeed or returns 401
|
||||||
|
- **SkipAuth**: Bypasses authentication entirely, always sets guest context
|
||||||
|
- **OptionalAuth**: Tries authentication, falls back to guest context if it fails
|
||||||
|
|
||||||
|
### NewAuthHandler
|
||||||
|
|
||||||
|
Standalone authentication handler (without middleware wrapping):
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Use when you need authentication logic without middleware
|
||||||
|
authHandler := security.NewAuthHandler(securityList, myHandler)
|
||||||
|
http.Handle("/api/protected", authHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
### NewOptionalAuthHandler
|
||||||
|
|
||||||
|
Standalone optional authentication handler that tries to authenticate but falls back to guest:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Use for routes that should work for both authenticated and guest users
|
||||||
|
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
||||||
|
http.Handle("/home", optionalHandler)
|
||||||
|
|
||||||
|
// Example handler that checks user context
|
||||||
|
func myHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCtx, _ := security.GetUserContext(r.Context())
|
||||||
|
if userCtx.UserID == 0 {
|
||||||
|
fmt.Fprintf(w, "Welcome, guest!")
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(w, "Welcome back, %s!", userCtx.UserName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Helper Functions
|
||||||
|
|
||||||
|
Extract user information from context:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get full user context
|
||||||
|
userCtx, ok := security.GetUserContext(ctx)
|
||||||
|
|
||||||
|
// Get specific fields
|
||||||
|
userID, ok := security.GetUserID(ctx)
|
||||||
|
userName, ok := security.GetUserName(ctx)
|
||||||
|
userLevel, ok := security.GetUserLevel(ctx)
|
||||||
|
sessionID, ok := security.GetSessionID(ctx)
|
||||||
|
remoteID, ok := security.GetRemoteID(ctx)
|
||||||
|
roles, ok := security.GetUserRoles(ctx)
|
||||||
|
email, ok := security.GetUserEmail(ctx)
|
||||||
|
meta, ok := security.GetUserMeta(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Metadata Support
|
||||||
|
|
||||||
|
The `Meta` field in `UserContext` can hold any JSON-serializable values:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Set metadata during login
|
||||||
|
loginReq := security.LoginRequest{
|
||||||
|
Username: "user@example.com",
|
||||||
|
Password: "password",
|
||||||
|
Meta: map[string]any{
|
||||||
|
"department": "engineering",
|
||||||
|
"location": "US",
|
||||||
|
"preferences": map[string]any{
|
||||||
|
"theme": "dark",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Access metadata in handlers
|
||||||
|
meta, ok := security.GetUserMeta(ctx)
|
||||||
|
if ok {
|
||||||
|
department := meta["department"].(string)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Part of the ResolveSpec project.
|
Part of the ResolveSpec project.
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ func (a *HeaderAuthenticatorExample) Authenticate(r *http.Request) (*UserContext
|
|||||||
RemoteID: r.Header.Get("X-Remote-ID"),
|
RemoteID: r.Header.Get("X-Remote-ID"),
|
||||||
Email: r.Header.Get("X-User-Email"),
|
Email: r.Header.Get("X-User-Email"),
|
||||||
Roles: parseRoles(r.Header.Get("X-User-Roles")),
|
Roles: parseRoles(r.Header.Get("X-User-Roles")),
|
||||||
|
Claims: make(map[string]any),
|
||||||
|
Meta: make(map[string]any),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,6 +127,8 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
|||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
UserLevel: user.UserLevel,
|
UserLevel: user.UserLevel,
|
||||||
Roles: parseRoles(user.Roles),
|
Roles: parseRoles(user.Roles),
|
||||||
|
Claims: req.Claims,
|
||||||
|
Meta: req.Meta,
|
||||||
},
|
},
|
||||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -242,6 +246,9 @@ func (a *DatabaseAuthenticatorExample) Login(ctx context.Context, req LoginReque
|
|||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
UserLevel: user.UserLevel,
|
UserLevel: user.UserLevel,
|
||||||
Roles: parseRoles(user.Roles),
|
Roles: parseRoles(user.Roles),
|
||||||
|
SessionID: sessionToken,
|
||||||
|
Claims: req.Claims,
|
||||||
|
Meta: req.Meta,
|
||||||
},
|
},
|
||||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||||
}, nil
|
}, nil
|
||||||
@@ -320,6 +327,8 @@ func (a *DatabaseAuthenticatorExample) Authenticate(r *http.Request) (*UserConte
|
|||||||
UserLevel: session.UserLevel,
|
UserLevel: session.UserLevel,
|
||||||
SessionID: sessionToken,
|
SessionID: sessionToken,
|
||||||
Roles: parseRoles(session.Roles),
|
Roles: parseRoles(session.Roles),
|
||||||
|
Claims: make(map[string]any),
|
||||||
|
Meta: make(map[string]any),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,9 +379,12 @@ func (a *DatabaseAuthenticatorExample) RefreshToken(ctx context.Context, refresh
|
|||||||
return &LoginResponse{
|
return &LoginResponse{
|
||||||
Token: newSessionToken,
|
Token: newSessionToken,
|
||||||
User: &UserContext{
|
User: &UserContext{
|
||||||
UserID: session.UserID,
|
UserID: session.UserID,
|
||||||
UserName: session.Username,
|
UserName: session.Username,
|
||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
|
SessionID: newSessionToken,
|
||||||
|
Claims: make(map[string]any),
|
||||||
|
Meta: make(map[string]any),
|
||||||
},
|
},
|
||||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -1,51 +1,43 @@
|
|||||||
package security
|
package security
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
||||||
func RegisterSecurityHooks(handler *restheadspec.Handler, securityList *SecurityList) {
|
// This interface abstracts the common security context needs across different specs
|
||||||
|
type SecurityContext interface {
|
||||||
// Hook 1: BeforeRead - Load security rules
|
GetContext() context.Context
|
||||||
handler.Hooks().Register(restheadspec.BeforeRead, func(hookCtx *restheadspec.HookContext) error {
|
GetUserID() (int, bool)
|
||||||
return LoadSecurityRules(hookCtx, securityList)
|
GetSchema() string
|
||||||
})
|
GetEntity() string
|
||||||
|
GetModel() interface{}
|
||||||
// Hook 2: BeforeScan - Apply row-level security filters
|
GetQuery() interface{}
|
||||||
handler.Hooks().Register(restheadspec.BeforeScan, func(hookCtx *restheadspec.HookContext) error {
|
SetQuery(interface{})
|
||||||
return ApplyRowSecurity(hookCtx, securityList)
|
GetResult() interface{}
|
||||||
})
|
SetResult(interface{})
|
||||||
|
|
||||||
// Hook 3: AfterRead - Apply column-level security (masking)
|
|
||||||
handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error {
|
|
||||||
return ApplyColumnSecurity(hookCtx, securityList)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Hook 4 (Optional): Audit logging
|
|
||||||
handler.Hooks().Register(restheadspec.AfterRead, LogDataAccess)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSecurityRules loads security configuration for the user and entity
|
// loadSecurityRules loads security configuration for the user and entity (generic version)
|
||||||
func LoadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
func loadSecurityRules(secCtx SecurityContext, securityList *SecurityList) error {
|
||||||
// Extract user ID from context
|
// Extract user ID from context
|
||||||
userID, ok := GetUserID(hookCtx.Context)
|
userID, ok := secCtx.GetUserID()
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.Warn("No user ID in context for security check")
|
logger.Warn("No user ID in context for security check")
|
||||||
return fmt.Errorf("authentication required")
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
schema := hookCtx.Schema
|
schema := secCtx.GetSchema()
|
||||||
tablename := hookCtx.Entity
|
tablename := secCtx.GetEntity()
|
||||||
|
|
||||||
logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
||||||
|
|
||||||
// Load column security rules using the provider
|
// Load column security rules using the provider
|
||||||
err := securityList.LoadColumnSecurity(hookCtx.Context, userID, schema, tablename, false)
|
err := securityList.LoadColumnSecurity(secCtx.GetContext(), userID, schema, tablename, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to load column security: %v", err)
|
logger.Warn("Failed to load column security: %v", err)
|
||||||
// Don't fail the request if no security rules exist
|
// Don't fail the request if no security rules exist
|
||||||
@@ -53,7 +45,7 @@ func LoadSecurityRules(hookCtx *restheadspec.HookContext, securityList *Security
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load row security rules using the provider
|
// Load row security rules using the provider
|
||||||
_, err = securityList.LoadRowSecurity(hookCtx.Context, userID, schema, tablename, false)
|
_, err = securityList.LoadRowSecurity(secCtx.GetContext(), userID, schema, tablename, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to load row security: %v", err)
|
logger.Warn("Failed to load row security: %v", err)
|
||||||
// Don't fail the request if no security rules exist
|
// Don't fail the request if no security rules exist
|
||||||
@@ -63,15 +55,15 @@ func LoadSecurityRules(hookCtx *restheadspec.HookContext, securityList *Security
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyRowSecurity applies row-level security filters to the query
|
// applyRowSecurity applies row-level security filters to the query (generic version)
|
||||||
func ApplyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
func applyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error {
|
||||||
userID, ok := GetUserID(hookCtx.Context)
|
userID, ok := secCtx.GetUserID()
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil // No user context, skip
|
return nil // No user context, skip
|
||||||
}
|
}
|
||||||
|
|
||||||
schema := hookCtx.Schema
|
schema := secCtx.GetSchema()
|
||||||
tablename := hookCtx.Entity
|
tablename := secCtx.GetEntity()
|
||||||
|
|
||||||
// Get row security template
|
// Get row security template
|
||||||
rowSec, err := securityList.GetRowSecurityTemplate(userID, schema, tablename)
|
rowSec, err := securityList.GetRowSecurityTemplate(userID, schema, tablename)
|
||||||
@@ -89,8 +81,14 @@ func ApplyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityL
|
|||||||
|
|
||||||
// If there's a security template, apply it as a WHERE clause
|
// If there's a security template, apply it as a WHERE clause
|
||||||
if rowSec.Template != "" {
|
if rowSec.Template != "" {
|
||||||
|
model := secCtx.GetModel()
|
||||||
|
if model == nil {
|
||||||
|
logger.Debug("No model available for row security on %s.%s", schema, tablename)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get primary key name from model
|
// Get primary key name from model
|
||||||
modelType := reflect.TypeOf(hookCtx.Model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Ptr {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
@@ -117,39 +115,45 @@ func ApplyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityL
|
|||||||
userID, schema, tablename, whereClause)
|
userID, schema, tablename, whereClause)
|
||||||
|
|
||||||
// Apply the WHERE clause to the query
|
// Apply the WHERE clause to the query
|
||||||
// The query is in hookCtx.Query
|
query := secCtx.GetQuery()
|
||||||
if selectQuery, ok := hookCtx.Query.(interface {
|
if selectQuery, ok := query.(interface {
|
||||||
Where(string, ...interface{}) interface{}
|
Where(string, ...interface{}) interface{}
|
||||||
}); ok {
|
}); ok {
|
||||||
hookCtx.Query = selectQuery.Where(whereClause)
|
secCtx.SetQuery(selectQuery.Where(whereClause))
|
||||||
} else {
|
} else {
|
||||||
logger.Error("Unable to apply WHERE clause - query doesn't support Where method")
|
logger.Debug("Query doesn't support Where method, skipping row security")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyColumnSecurity applies column-level security (masking/hiding) to results
|
// applyColumnSecurity applies column-level security (masking/hiding) to results (generic version)
|
||||||
func ApplyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
func applyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) error {
|
||||||
userID, ok := GetUserID(hookCtx.Context)
|
userID, ok := secCtx.GetUserID()
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil // No user context, skip
|
return nil // No user context, skip
|
||||||
}
|
}
|
||||||
|
|
||||||
schema := hookCtx.Schema
|
schema := secCtx.GetSchema()
|
||||||
tablename := hookCtx.Entity
|
tablename := secCtx.GetEntity()
|
||||||
|
|
||||||
// Get result data
|
// Get result data
|
||||||
result := hookCtx.Result
|
result := secCtx.GetResult()
|
||||||
if result == nil {
|
if result == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Applying column security for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
logger.Debug("Applying column security for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
||||||
|
|
||||||
|
model := secCtx.GetModel()
|
||||||
|
if model == nil {
|
||||||
|
logger.Debug("No model available for column security on %s.%s", schema, tablename)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get model type
|
// Get model type
|
||||||
modelType := reflect.TypeOf(hookCtx.Model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Ptr {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
@@ -169,37 +173,59 @@ func ApplyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *Securi
|
|||||||
|
|
||||||
// Update the result with masked data
|
// Update the result with masked data
|
||||||
if maskedResult.IsValid() && maskedResult.CanInterface() {
|
if maskedResult.IsValid() && maskedResult.CanInterface() {
|
||||||
hookCtx.Result = maskedResult.Interface()
|
secCtx.SetResult(maskedResult.Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogDataAccess logs all data access for audit purposes
|
// logDataAccess logs all data access for audit purposes (generic version)
|
||||||
func LogDataAccess(hookCtx *restheadspec.HookContext) error {
|
func logDataAccess(secCtx SecurityContext) error {
|
||||||
userID, _ := GetUserID(hookCtx.Context)
|
userID, _ := secCtx.GetUserID()
|
||||||
|
|
||||||
logger.Info("AUDIT: User %d accessed %s.%s with filters: %+v",
|
logger.Info("AUDIT: User %d accessed %s.%s",
|
||||||
userID,
|
userID,
|
||||||
hookCtx.Schema,
|
secCtx.GetSchema(),
|
||||||
hookCtx.Entity,
|
secCtx.GetEntity(),
|
||||||
hookCtx.Options.Filters,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Write to audit log table or external audit service
|
// TODO: Write to audit log table or external audit service
|
||||||
// auditLog := AuditLog{
|
// auditLog := AuditLog{
|
||||||
// UserID: userID,
|
// UserID: userID,
|
||||||
// Schema: hookCtx.Schema,
|
// Schema: secCtx.GetSchema(),
|
||||||
// Entity: hookCtx.Entity,
|
// Entity: secCtx.GetEntity(),
|
||||||
// Action: "READ",
|
// Action: "READ",
|
||||||
// Timestamp: time.Now(),
|
// Timestamp: time.Now(),
|
||||||
// Filters: hookCtx.Options.Filters,
|
|
||||||
// }
|
// }
|
||||||
// db.Create(&auditLog)
|
// db.Create(&auditLog)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogDataAccess is a public wrapper for logDataAccess that accepts a SecurityContext
|
||||||
|
// This allows other packages to use the audit logging functionality
|
||||||
|
func LogDataAccess(secCtx SecurityContext) error {
|
||||||
|
return logDataAccess(secCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSecurityRules is a public wrapper for loadSecurityRules that accepts a SecurityContext
|
||||||
|
// This allows other packages to load security rules using the generic interface
|
||||||
|
func LoadSecurityRules(secCtx SecurityContext, securityList *SecurityList) error {
|
||||||
|
return loadSecurityRules(secCtx, securityList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyRowSecurity is a public wrapper for applyRowSecurity that accepts a SecurityContext
|
||||||
|
// This allows other packages to apply row-level security using the generic interface
|
||||||
|
func ApplyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error {
|
||||||
|
return applyRowSecurity(secCtx, securityList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyColumnSecurity is a public wrapper for applyColumnSecurity that accepts a SecurityContext
|
||||||
|
// This allows other packages to apply column-level security using the generic interface
|
||||||
|
func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) error {
|
||||||
|
return applyColumnSecurity(secCtx, securityList)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
|
|
||||||
func contains(s, substr string) bool {
|
func contains(s, substr string) bool {
|
||||||
|
|||||||
@@ -7,35 +7,37 @@ import (
|
|||||||
|
|
||||||
// UserContext holds authenticated user information
|
// UserContext holds authenticated user information
|
||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
UserID int
|
UserID int `json:"user_id"`
|
||||||
UserName string
|
UserName string `json:"user_name"`
|
||||||
UserLevel int
|
UserLevel int `json:"user_level"`
|
||||||
SessionID string
|
SessionID string `json:"session_id"`
|
||||||
RemoteID string
|
RemoteID string `json:"remote_id"`
|
||||||
Roles []string
|
Roles []string `json:"roles"`
|
||||||
Email string
|
Email string `json:"email"`
|
||||||
Claims map[string]any
|
Claims map[string]any `json:"claims"`
|
||||||
|
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginRequest contains credentials for login
|
// LoginRequest contains credentials for login
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
Username string
|
Username string `json:"username"`
|
||||||
Password string
|
Password string `json:"password"`
|
||||||
Claims map[string]any // Additional login data
|
Claims map[string]any `json:"claims"` // Additional login data
|
||||||
|
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginResponse contains the result of a login attempt
|
// LoginResponse contains the result of a login attempt
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
Token string
|
Token string `json:"token"`
|
||||||
RefreshToken string
|
RefreshToken string `json:"refresh_token"`
|
||||||
User *UserContext
|
User *UserContext `json:"user"`
|
||||||
ExpiresIn int64 // Token expiration in seconds
|
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogoutRequest contains information for logout
|
// LogoutRequest contains information for logout
|
||||||
type LogoutRequest struct {
|
type LogoutRequest struct {
|
||||||
Token string
|
Token string `json:"token"`
|
||||||
UserID int
|
UserID int `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticator handles user authentication operations
|
// Authenticator handles user authentication operations
|
||||||
|
|||||||
@@ -10,21 +10,145 @@ type contextKey string
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// Context keys for user information
|
// Context keys for user information
|
||||||
UserIDKey contextKey = "user_id"
|
UserIDKey contextKey = "user_id"
|
||||||
UserNameKey contextKey = "user_name"
|
UserNameKey contextKey = "user_name"
|
||||||
UserLevelKey contextKey = "user_level"
|
UserLevelKey contextKey = "user_level"
|
||||||
SessionIDKey contextKey = "session_id"
|
SessionIDKey contextKey = "session_id"
|
||||||
RemoteIDKey contextKey = "remote_id"
|
RemoteIDKey contextKey = "remote_id"
|
||||||
UserRolesKey contextKey = "user_roles"
|
UserRolesKey contextKey = "user_roles"
|
||||||
UserEmailKey contextKey = "user_email"
|
UserEmailKey contextKey = "user_email"
|
||||||
UserContextKey contextKey = "user_context"
|
UserContextKey contextKey = "user_context"
|
||||||
|
UserMetaKey contextKey = "user_meta"
|
||||||
|
SkipAuthKey contextKey = "skip_auth"
|
||||||
|
OptionalAuthKey contextKey = "optional_auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SkipAuth returns a context with skip auth flag set to true
|
||||||
|
// Use this to mark routes that should bypass authentication middleware
|
||||||
|
func SkipAuth(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, SkipAuthKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionalAuth returns a context with optional auth flag set to true
|
||||||
|
// Use this to mark routes that should try to authenticate, but fall back to guest if authentication fails
|
||||||
|
func OptionalAuth(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, OptionalAuthKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createGuestContext creates a guest user context for unauthenticated requests
|
||||||
|
func createGuestContext(r *http.Request) *UserContext {
|
||||||
|
return &UserContext{
|
||||||
|
UserID: 0,
|
||||||
|
UserName: "guest",
|
||||||
|
UserLevel: 0,
|
||||||
|
SessionID: "",
|
||||||
|
RemoteID: r.RemoteAddr,
|
||||||
|
Roles: []string{"guest"},
|
||||||
|
Email: "",
|
||||||
|
Claims: map[string]any{},
|
||||||
|
Meta: map[string]any{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setUserContext adds a user context to the request context
|
||||||
|
func setUserContext(r *http.Request, userCtx *UserContext) *http.Request {
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserContextKey, userCtx)
|
||||||
|
ctx = context.WithValue(ctx, UserIDKey, userCtx.UserID)
|
||||||
|
ctx = context.WithValue(ctx, UserNameKey, userCtx.UserName)
|
||||||
|
ctx = context.WithValue(ctx, UserLevelKey, userCtx.UserLevel)
|
||||||
|
ctx = context.WithValue(ctx, SessionIDKey, userCtx.SessionID)
|
||||||
|
ctx = context.WithValue(ctx, RemoteIDKey, userCtx.RemoteID)
|
||||||
|
ctx = context.WithValue(ctx, UserRolesKey, userCtx.Roles)
|
||||||
|
|
||||||
|
if userCtx.Email != "" {
|
||||||
|
ctx = context.WithValue(ctx, UserEmailKey, userCtx.Email)
|
||||||
|
}
|
||||||
|
if len(userCtx.Meta) > 0 {
|
||||||
|
ctx = context.WithValue(ctx, UserMetaKey, userCtx.Meta)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateRequest performs authentication and adds user context to the request
|
||||||
|
// This is the shared authentication logic used by both handler and middleware
|
||||||
|
func authenticateRequest(w http.ResponseWriter, r *http.Request, provider SecurityProvider) (*http.Request, bool) {
|
||||||
|
// Call the provider's Authenticate method
|
||||||
|
userCtx, err := provider.Authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return setUserContext(r, userCtx), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthHandler creates an authentication handler that can be used standalone
|
||||||
|
// This handler performs authentication and returns 401 if authentication fails
|
||||||
|
// Use this when you need authentication logic without middleware wrapping
|
||||||
|
func NewAuthHandler(securityList *SecurityList, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get the security provider
|
||||||
|
provider := securityList.Provider()
|
||||||
|
if provider == nil {
|
||||||
|
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate the request
|
||||||
|
authenticatedReq, ok := authenticateRequest(w, r, provider)
|
||||||
|
if !ok {
|
||||||
|
return // authenticateRequest already wrote the error response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue with authenticated context
|
||||||
|
next.ServeHTTP(w, authenticatedReq)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOptionalAuthHandler creates an optional authentication handler that can be used standalone
|
||||||
|
// This handler tries to authenticate but falls back to guest context if authentication fails
|
||||||
|
// Use this for routes that should show personalized content for authenticated users but still work for guests
|
||||||
|
func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get the security provider
|
||||||
|
provider := securityList.Provider()
|
||||||
|
if provider == nil {
|
||||||
|
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to authenticate
|
||||||
|
userCtx, err := provider.Authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
// Authentication failed - set guest context and continue
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authentication succeeded - set user context
|
||||||
|
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware creates an authentication middleware with the given security list
|
// NewAuthMiddleware creates an authentication middleware with the given security list
|
||||||
// This middleware extracts user authentication from the request and adds it to context
|
// This middleware extracts user authentication from the request and adds it to context
|
||||||
|
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
|
||||||
|
// Routes can use optional authentication by setting OptionalAuthKey context value (use OptionalAuth helper)
|
||||||
|
// When authentication is skipped or fails with optional auth, a guest user context is set instead
|
||||||
func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check if this route should skip authentication
|
||||||
|
if skip, ok := r.Context().Value(SkipAuthKey).(bool); ok && skip {
|
||||||
|
// Set guest user context for skipped routes
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Get the security provider
|
// Get the security provider
|
||||||
provider := securityList.Provider()
|
provider := securityList.Provider()
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
@@ -32,31 +156,25 @@ func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handl
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the provider's Authenticate method
|
// Check if this route has optional authentication
|
||||||
|
optional, _ := r.Context().Value(OptionalAuthKey).(bool)
|
||||||
|
|
||||||
|
// Try to authenticate
|
||||||
userCtx, err := provider.Authenticate(r)
|
userCtx, err := provider.Authenticate(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if optional {
|
||||||
|
// Optional auth failed - set guest context and continue
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Required auth failed - return error
|
||||||
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add user information to context
|
// Authentication succeeded - set user context
|
||||||
ctx := r.Context()
|
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||||
ctx = context.WithValue(ctx, UserContextKey, userCtx)
|
|
||||||
ctx = context.WithValue(ctx, UserIDKey, userCtx.UserID)
|
|
||||||
ctx = context.WithValue(ctx, UserNameKey, userCtx.UserName)
|
|
||||||
ctx = context.WithValue(ctx, UserLevelKey, userCtx.UserLevel)
|
|
||||||
ctx = context.WithValue(ctx, SessionIDKey, userCtx.SessionID)
|
|
||||||
ctx = context.WithValue(ctx, RemoteIDKey, userCtx.RemoteID)
|
|
||||||
|
|
||||||
if len(userCtx.Roles) > 0 {
|
|
||||||
ctx = context.WithValue(ctx, UserRolesKey, userCtx.Roles)
|
|
||||||
}
|
|
||||||
if userCtx.Email != "" {
|
|
||||||
ctx = context.WithValue(ctx, UserEmailKey, userCtx.Email)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Continue with authenticated context
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -119,3 +237,164 @@ func GetUserEmail(ctx context.Context) (string, bool) {
|
|||||||
email, ok := ctx.Value(UserEmailKey).(string)
|
email, ok := ctx.Value(UserEmailKey).(string)
|
||||||
return email, ok
|
return email, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserMeta extracts user metadata from context
|
||||||
|
func GetUserMeta(ctx context.Context) (map[string]any, bool) {
|
||||||
|
meta, ok := ctx.Value(UserMetaKey).(map[string]any)
|
||||||
|
return meta, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// // Handler adapters for resolvespec/restheadspec compatibility
|
||||||
|
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions
|
||||||
|
|
||||||
|
// // SpecHandlerAdapter is an interface for handler adapters that need authentication
|
||||||
|
// // Implement this interface to create adapters for custom handler types
|
||||||
|
// type SpecHandlerAdapter interface {
|
||||||
|
// // AdaptToHTTPHandler converts the custom handler to a standard http.Handler
|
||||||
|
// AdaptToHTTPHandler() http.Handler
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // ResolveSpecHandlerAdapter adapts a resolvespec/restheadspec handler method to http.Handler
|
||||||
|
// type ResolveSpecHandlerAdapter struct {
|
||||||
|
// // HandlerMethod is the method to call (e.g., handler.Handle, handler.HandleGet)
|
||||||
|
// HandlerMethod func(w any, r any, params map[string]string)
|
||||||
|
// // Params are the route parameters (e.g., {"schema": "public", "entity": "users"})
|
||||||
|
// Params map[string]string
|
||||||
|
// // RequestAdapter converts *http.Request to the custom Request interface
|
||||||
|
// // Use router.NewHTTPRequest from pkg/common/adapters/router
|
||||||
|
// RequestAdapter func(*http.Request) any
|
||||||
|
// // ResponseAdapter converts http.ResponseWriter to the custom ResponseWriter interface
|
||||||
|
// // Use router.NewHTTPResponseWriter from pkg/common/adapters/router
|
||||||
|
// ResponseAdapter func(http.ResponseWriter) any
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // AdaptToHTTPHandler implements SpecHandlerAdapter
|
||||||
|
// func (a *ResolveSpecHandlerAdapter) AdaptToHTTPHandler() http.Handler {
|
||||||
|
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// req := a.RequestAdapter(r)
|
||||||
|
// resp := a.ResponseAdapter(w)
|
||||||
|
// a.HandlerMethod(resp, req, a.Params)
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // WrapSpecHandler wraps a spec handler adapter with authentication
|
||||||
|
// // Use this to apply NewAuthHandler or NewOptionalAuthHandler to resolvespec/restheadspec handlers
|
||||||
|
// //
|
||||||
|
// // Example with required auth:
|
||||||
|
// //
|
||||||
|
// // adapter := &security.ResolveSpecHandlerAdapter{
|
||||||
|
// // HandlerMethod: handler.Handle,
|
||||||
|
// // Params: map[string]string{"schema": "public", "entity": "users"},
|
||||||
|
// // RequestAdapter: func(r *http.Request) any { return router.NewHTTPRequest(r) },
|
||||||
|
// // ResponseAdapter: func(w http.ResponseWriter) any { return router.NewHTTPResponseWriter(w) },
|
||||||
|
// // }
|
||||||
|
// // authHandler := security.WrapSpecHandler(securityList, adapter, false)
|
||||||
|
// // muxRouter.Handle("/api/users", authHandler)
|
||||||
|
// func WrapSpecHandler(securityList *SecurityList, adapter SpecHandlerAdapter, optional bool) http.Handler {
|
||||||
|
// httpHandler := adapter.AdaptToHTTPHandler()
|
||||||
|
// if optional {
|
||||||
|
// return NewOptionalAuthHandler(securityList, httpHandler)
|
||||||
|
// }
|
||||||
|
// return NewAuthHandler(securityList, httpHandler)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // MuxRouteBuilder helps build authenticated routes with Gorilla Mux
|
||||||
|
// type MuxRouteBuilder struct {
|
||||||
|
// securityList *SecurityList
|
||||||
|
// requestAdapter func(*http.Request) any
|
||||||
|
// responseAdapter func(http.ResponseWriter) any
|
||||||
|
// paramExtractor func(*http.Request) map[string]string
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // NewMuxRouteBuilder creates a route builder for Gorilla Mux with standard router adapters
|
||||||
|
// // Usage:
|
||||||
|
// //
|
||||||
|
// // builder := security.NewMuxRouteBuilder(securityList, router.NewHTTPRequest, router.NewHTTPResponseWriter)
|
||||||
|
// func NewMuxRouteBuilder(
|
||||||
|
// securityList *SecurityList,
|
||||||
|
// requestAdapter func(*http.Request) any,
|
||||||
|
// responseAdapter func(http.ResponseWriter) any,
|
||||||
|
// ) *MuxRouteBuilder {
|
||||||
|
// return &MuxRouteBuilder{
|
||||||
|
// securityList: securityList,
|
||||||
|
// requestAdapter: requestAdapter,
|
||||||
|
// responseAdapter: responseAdapter,
|
||||||
|
// paramExtractor: nil, // Will be set per route using mux.Vars
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // HandleAuth creates an authenticated route handler
|
||||||
|
// // pattern: the route pattern (e.g., "/{schema}/{entity}")
|
||||||
|
// // handler: the handler method to call (e.g., handler.Handle)
|
||||||
|
// // optional: true for optional auth (guest fallback), false for required auth (401 on failure)
|
||||||
|
// // methods: HTTP methods (e.g., "GET", "POST")
|
||||||
|
// //
|
||||||
|
// // Usage:
|
||||||
|
// //
|
||||||
|
// // builder.HandleAuth(router, "/{schema}/{entity}", handler.Handle, false, "POST")
|
||||||
|
// func (b *MuxRouteBuilder) HandleAuth(
|
||||||
|
// router interface {
|
||||||
|
// HandleFunc(pattern string, f func(http.ResponseWriter, *http.Request)) interface{ Methods(...string) interface{} }
|
||||||
|
// },
|
||||||
|
// pattern string,
|
||||||
|
// handlerMethod func(w any, r any, params map[string]string),
|
||||||
|
// optional bool,
|
||||||
|
// methods ...string,
|
||||||
|
// ) {
|
||||||
|
// router.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// // Extract params using the registered extractor or default to empty map
|
||||||
|
// var params map[string]string
|
||||||
|
// if b.paramExtractor != nil {
|
||||||
|
// params = b.paramExtractor(r)
|
||||||
|
// } else {
|
||||||
|
// params = make(map[string]string)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// adapter := &ResolveSpecHandlerAdapter{
|
||||||
|
// HandlerMethod: handlerMethod,
|
||||||
|
// Params: params,
|
||||||
|
// RequestAdapter: b.requestAdapter,
|
||||||
|
// ResponseAdapter: b.responseAdapter,
|
||||||
|
// }
|
||||||
|
// authHandler := WrapSpecHandler(b.securityList, adapter, optional)
|
||||||
|
// authHandler.ServeHTTP(w, r)
|
||||||
|
// }).Methods(methods...)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // SetParamExtractor sets a custom parameter extractor function
|
||||||
|
// // For Gorilla Mux, you would use: builder.SetParamExtractor(mux.Vars)
|
||||||
|
// func (b *MuxRouteBuilder) SetParamExtractor(extractor func(*http.Request) map[string]string) {
|
||||||
|
// b.paramExtractor = extractor
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // SetupAuthenticatedSpecRoutes sets up all standard resolvespec/restheadspec routes with authentication
|
||||||
|
// // This is a convenience function that sets up the common route patterns
|
||||||
|
// //
|
||||||
|
// // Usage:
|
||||||
|
// //
|
||||||
|
// // security.SetupAuthenticatedSpecRoutes(router, handler, securityList, router.NewHTTPRequest, router.NewHTTPResponseWriter, mux.Vars)
|
||||||
|
// func SetupAuthenticatedSpecRoutes(
|
||||||
|
// router interface {
|
||||||
|
// HandleFunc(pattern string, f func(http.ResponseWriter, *http.Request)) interface{ Methods(...string) interface{} }
|
||||||
|
// },
|
||||||
|
// handler interface {
|
||||||
|
// Handle(w any, r any, params map[string]string)
|
||||||
|
// HandleGet(w any, r any, params map[string]string)
|
||||||
|
// },
|
||||||
|
// securityList *SecurityList,
|
||||||
|
// requestAdapter func(*http.Request) any,
|
||||||
|
// responseAdapter func(http.ResponseWriter) any,
|
||||||
|
// paramExtractor func(*http.Request) map[string]string,
|
||||||
|
// ) {
|
||||||
|
// builder := NewMuxRouteBuilder(securityList, requestAdapter, responseAdapter)
|
||||||
|
// builder.SetParamExtractor(paramExtractor)
|
||||||
|
|
||||||
|
// // POST /{schema}/{entity}
|
||||||
|
// builder.HandleAuth(router, "/{schema}/{entity}", handler.Handle, false, "POST")
|
||||||
|
|
||||||
|
// // POST /{schema}/{entity}/{id}
|
||||||
|
// builder.HandleAuth(router, "/{schema}/{entity}/{id}", handler.Handle, false, "POST")
|
||||||
|
|
||||||
|
// // GET /{schema}/{entity}
|
||||||
|
// builder.HandleAuth(router, "/{schema}/{entity}", handler.HandleGet, false, "GET")
|
||||||
|
// }
|
||||||
|
|||||||
@@ -15,26 +15,26 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ColumnSecurity struct {
|
type ColumnSecurity struct {
|
||||||
Schema string
|
Schema string `json:"schema"`
|
||||||
Tablename string
|
Tablename string `json:"tablename"`
|
||||||
Path []string
|
Path []string `json:"path"`
|
||||||
ExtraFilters map[string]string
|
ExtraFilters map[string]string `json:"extra_filters"`
|
||||||
UserID int
|
UserID int `json:"user_id"`
|
||||||
Accesstype string `json:"accesstype"`
|
Accesstype string `json:"accesstype"`
|
||||||
MaskStart int
|
MaskStart int `json:"mask_start"`
|
||||||
MaskEnd int
|
MaskEnd int `json:"mask_end"`
|
||||||
MaskInvert bool
|
MaskInvert bool `json:"mask_invert"`
|
||||||
MaskChar string
|
MaskChar string `json:"mask_char"`
|
||||||
Control string `json:"control"`
|
Control string `json:"control"`
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RowSecurity struct {
|
type RowSecurity struct {
|
||||||
Schema string
|
Schema string `json:"schema"`
|
||||||
Tablename string
|
Tablename string `json:"tablename"`
|
||||||
Template string
|
Template string `json:"template"`
|
||||||
HasBlock bool
|
HasBlock bool `json:"has_block"`
|
||||||
UserID int
|
UserID int `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Type) string {
|
func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Type) string {
|
||||||
|
|||||||
@@ -1,292 +0,0 @@
|
|||||||
package security
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetupSecurityProvider initializes and configures the security provider
|
|
||||||
// This function creates a SecurityList with the given provider and registers hooks
|
|
||||||
//
|
|
||||||
// Example usage:
|
|
||||||
//
|
|
||||||
// // Create your security provider (use composite or single provider)
|
|
||||||
// auth := security.NewJWTAuthenticator("your-secret-key", db)
|
|
||||||
// colSec := security.NewDatabaseColumnSecurityProvider(db)
|
|
||||||
// rowSec := security.NewDatabaseRowSecurityProvider(db)
|
|
||||||
// provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
|
||||||
//
|
|
||||||
// // Setup security with the provider
|
|
||||||
// handler := restheadspec.NewHandlerWithGORM(db)
|
|
||||||
// securityList := security.SetupSecurityProvider(handler, provider)
|
|
||||||
//
|
|
||||||
// // Apply middleware
|
|
||||||
// router.Use(security.NewAuthMiddleware(securityList))
|
|
||||||
// router.Use(security.SetSecurityMiddleware(securityList))
|
|
||||||
func SetupSecurityProvider(handler *restheadspec.Handler, provider SecurityProvider) *SecurityList {
|
|
||||||
if provider == nil {
|
|
||||||
panic("security provider cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create security list with the provider
|
|
||||||
securityList := NewSecurityList(provider)
|
|
||||||
|
|
||||||
// Register all security hooks
|
|
||||||
RegisterSecurityHooks(handler, securityList)
|
|
||||||
|
|
||||||
return securityList
|
|
||||||
}
|
|
||||||
|
|
||||||
// Example 1: Complete Setup with Composite Provider and Database-Backed Security
|
|
||||||
// ===============================================================================
|
|
||||||
// Note: Security providers use *sql.DB, but restheadspec.Handler may use *gorm.DB
|
|
||||||
// You can get *sql.DB from gorm.DB using: sqlDB, _ := gormDB.DB()
|
|
||||||
|
|
||||||
func ExampleDatabaseSecurity(gormDB interface{}, sqlDB *sql.DB) (http.Handler, error) {
|
|
||||||
// Step 1: Create the ResolveSpec handler
|
|
||||||
// handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB))
|
|
||||||
handler := &restheadspec.Handler{} // Placeholder - use your handler initialization
|
|
||||||
|
|
||||||
// Step 2: Register your models
|
|
||||||
// handler.RegisterModel("public", "users", User{})
|
|
||||||
// handler.RegisterModel("public", "orders", Order{})
|
|
||||||
|
|
||||||
// Step 3: Create security provider components (using sql.DB)
|
|
||||||
auth := NewJWTAuthenticator("your-secret-key", sqlDB)
|
|
||||||
colSec := NewDatabaseColumnSecurityProvider(sqlDB)
|
|
||||||
rowSec := NewDatabaseRowSecurityProvider(sqlDB)
|
|
||||||
|
|
||||||
// Step 4: Combine into composite provider
|
|
||||||
provider := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
|
||||||
|
|
||||||
// Step 5: Setup security
|
|
||||||
securityList := SetupSecurityProvider(handler, provider)
|
|
||||||
|
|
||||||
// Step 6: Create router and setup routes
|
|
||||||
router := mux.NewRouter()
|
|
||||||
restheadspec.SetupMuxRoutes(router, handler)
|
|
||||||
|
|
||||||
// Step 7: Apply middleware in correct order
|
|
||||||
router.Use(NewAuthMiddleware(securityList))
|
|
||||||
router.Use(SetSecurityMiddleware(securityList))
|
|
||||||
|
|
||||||
return router, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Example 2: Simple Header-Based Authentication
|
|
||||||
// ==============================================
|
|
||||||
|
|
||||||
func ExampleHeaderAuthentication(gormDB interface{}, sqlDB *sql.DB) (*mux.Router, error) {
|
|
||||||
// handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB))
|
|
||||||
handler := &restheadspec.Handler{} // Placeholder - use your handler initialization
|
|
||||||
|
|
||||||
// Use header-based auth with database security providers
|
|
||||||
auth := NewHeaderAuthenticatorExample()
|
|
||||||
colSec := NewDatabaseColumnSecurityProvider(sqlDB)
|
|
||||||
rowSec := NewDatabaseRowSecurityProvider(sqlDB)
|
|
||||||
|
|
||||||
provider := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
|
||||||
securityList := SetupSecurityProvider(handler, provider)
|
|
||||||
|
|
||||||
router := mux.NewRouter()
|
|
||||||
restheadspec.SetupMuxRoutes(router, handler)
|
|
||||||
|
|
||||||
router.Use(NewAuthMiddleware(securityList))
|
|
||||||
router.Use(SetSecurityMiddleware(securityList))
|
|
||||||
|
|
||||||
return router, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Example 3: Config-Based Security (No Database for Security)
|
|
||||||
// ===========================================================
|
|
||||||
|
|
||||||
func ExampleConfigSecurity(gormDB interface{}) (*mux.Router, error) {
|
|
||||||
// handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB))
|
|
||||||
handler := &restheadspec.Handler{} // Placeholder - use your handler initialization
|
|
||||||
|
|
||||||
// Define column security rules in code
|
|
||||||
columnRules := map[string][]ColumnSecurity{
|
|
||||||
"public.employees": {
|
|
||||||
{
|
|
||||||
Schema: "public",
|
|
||||||
Tablename: "employees",
|
|
||||||
Path: []string{"ssn"},
|
|
||||||
Accesstype: "mask",
|
|
||||||
MaskStart: 5,
|
|
||||||
MaskChar: "*",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Schema: "public",
|
|
||||||
Tablename: "employees",
|
|
||||||
Path: []string{"salary"},
|
|
||||||
Accesstype: "hide",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define row security templates
|
|
||||||
rowTemplates := map[string]string{
|
|
||||||
"public.orders": "user_id = {UserID}",
|
|
||||||
"public.documents": "user_id = {UserID} OR is_public = true",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define blocked tables
|
|
||||||
blockedTables := map[string]bool{
|
|
||||||
"public.admin_logs": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create providers
|
|
||||||
auth := NewHeaderAuthenticatorExample()
|
|
||||||
colSec := NewConfigColumnSecurityProvider(columnRules)
|
|
||||||
rowSec := NewConfigRowSecurityProvider(rowTemplates, blockedTables)
|
|
||||||
|
|
||||||
provider := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
|
||||||
securityList := SetupSecurityProvider(handler, provider)
|
|
||||||
|
|
||||||
router := mux.NewRouter()
|
|
||||||
restheadspec.SetupMuxRoutes(router, handler)
|
|
||||||
|
|
||||||
router.Use(NewAuthMiddleware(securityList))
|
|
||||||
router.Use(SetSecurityMiddleware(securityList))
|
|
||||||
|
|
||||||
return router, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Example 4: Custom Security Provider
|
|
||||||
// ====================================
|
|
||||||
|
|
||||||
// You can implement your own SecurityProvider by implementing all three interfaces
|
|
||||||
type CustomSecurityProvider struct {
|
|
||||||
// Your custom fields
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CustomSecurityProvider) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
|
||||||
// Your custom login logic
|
|
||||||
return nil, fmt.Errorf("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CustomSecurityProvider) Logout(ctx context.Context, req LogoutRequest) error {
|
|
||||||
// Your custom logout logic
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CustomSecurityProvider) Authenticate(r *http.Request) (*UserContext, error) {
|
|
||||||
// Your custom authentication logic
|
|
||||||
return nil, fmt.Errorf("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CustomSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
|
||||||
// Your custom column security logic
|
|
||||||
return []ColumnSecurity{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CustomSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
|
||||||
// Your custom row security logic
|
|
||||||
return RowSecurity{
|
|
||||||
Schema: schema,
|
|
||||||
Tablename: table,
|
|
||||||
UserID: userID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Example 5: Adding Login/Logout Endpoints
|
|
||||||
// =========================================
|
|
||||||
|
|
||||||
func SetupAuthRoutes(router *mux.Router, securityList *SecurityList) {
|
|
||||||
// Login endpoint
|
|
||||||
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Parse login request
|
|
||||||
var loginReq LoginRequest
|
|
||||||
// json.NewDecoder(r.Body).Decode(&loginReq)
|
|
||||||
|
|
||||||
// Call provider's Login method
|
|
||||||
resp, err := securityList.Provider().Login(r.Context(), loginReq)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return token
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
// json.NewEncoder(w).Encode(resp)
|
|
||||||
fmt.Fprintf(w, `{"token": "%s", "expires_in": %d}`, resp.Token, resp.ExpiresIn)
|
|
||||||
}).Methods("POST")
|
|
||||||
|
|
||||||
// Logout endpoint
|
|
||||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Extract token from header
|
|
||||||
token := r.Header.Get("Authorization")
|
|
||||||
|
|
||||||
// Get user ID from context (if authenticated)
|
|
||||||
userID, _ := GetUserID(r.Context())
|
|
||||||
|
|
||||||
// Call provider's Logout method
|
|
||||||
err := securityList.Provider().Logout(r.Context(), LogoutRequest{
|
|
||||||
Token: token,
|
|
||||||
UserID: userID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
fmt.Fprint(w, `{"success": true}`)
|
|
||||||
}).Methods("POST")
|
|
||||||
|
|
||||||
// Optional: Token refresh endpoint
|
|
||||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
refreshToken := r.Header.Get("X-Refresh-Token")
|
|
||||||
|
|
||||||
// Check if provider supports refresh
|
|
||||||
if refreshable, ok := securityList.Provider().(Refreshable); ok {
|
|
||||||
resp, err := refreshable.RefreshToken(r.Context(), refreshToken)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
fmt.Fprintf(w, `{"token": "%s", "expires_in": %d}`, resp.Token, resp.ExpiresIn)
|
|
||||||
} else {
|
|
||||||
http.Error(w, "Token refresh not supported", http.StatusNotImplemented)
|
|
||||||
}
|
|
||||||
}).Methods("POST")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Example 6: Complete Server Setup
|
|
||||||
// =================================
|
|
||||||
|
|
||||||
func CompleteServerExample(gormDB interface{}, sqlDB *sql.DB) http.Handler {
|
|
||||||
// Create handler and register models
|
|
||||||
// handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB))
|
|
||||||
handler := &restheadspec.Handler{} // Placeholder - use your handler initialization
|
|
||||||
// handler.RegisterModel("public", "users", User{})
|
|
||||||
|
|
||||||
// Setup security (using sql.DB for security providers)
|
|
||||||
auth := NewJWTAuthenticator("secret-key", sqlDB)
|
|
||||||
colSec := NewDatabaseColumnSecurityProvider(sqlDB)
|
|
||||||
rowSec := NewDatabaseRowSecurityProvider(sqlDB)
|
|
||||||
provider := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
|
||||||
securityList := SetupSecurityProvider(handler, provider)
|
|
||||||
|
|
||||||
// Create router
|
|
||||||
router := mux.NewRouter()
|
|
||||||
|
|
||||||
// Add auth routes (login/logout)
|
|
||||||
SetupAuthRoutes(router, securityList)
|
|
||||||
|
|
||||||
// Add API routes with security middleware
|
|
||||||
apiRouter := router.PathPrefix("/api").Subrouter()
|
|
||||||
restheadspec.SetupMuxRoutes(apiRouter, handler)
|
|
||||||
apiRouter.Use(NewAuthMiddleware(securityList))
|
|
||||||
apiRouter.Use(SetSecurityMiddleware(securityList))
|
|
||||||
|
|
||||||
return router
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user