mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-03-12 08:26:49 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 254102bfac | |||
| 6c27419dbc | |||
| 377336caf4 | |||
| 79720d5421 | |||
| e7ab0a20d6 | |||
| e4087104a9 |
@@ -2,6 +2,7 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -925,3 +926,36 @@ func extractLeftSideOfComparison(cond string) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FilterValueToSlice converts a filter value to []interface{} for use with IN operators.
|
||||
// JSON-decoded arrays arrive as []interface{}, but typed slices (e.g. []string) also work.
|
||||
// Returns a single-element slice if the value is not a slice type.
|
||||
func FilterValueToSlice(v interface{}) []interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Slice {
|
||||
result := make([]interface{}, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
result[i] = rv.Index(i).Interface()
|
||||
}
|
||||
return result
|
||||
}
|
||||
return []interface{}{v}
|
||||
}
|
||||
|
||||
// BuildInCondition builds a parameterized IN condition from a filter value.
|
||||
// Returns the condition string (e.g. "col IN (?,?)") and the individual values as args.
|
||||
// Returns ("", nil) if the value is empty or not a slice.
|
||||
func BuildInCondition(column string, v interface{}) (query string, args []interface{}) {
|
||||
values := FilterValueToSlice(v)
|
||||
if len(values) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
placeholders := make([]string, len(values))
|
||||
for i := range values {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), values
|
||||
}
|
||||
|
||||
@@ -2,14 +2,38 @@ package funcspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
||||
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
||||
// We provide audit logging for data access tracking
|
||||
// We provide auth enforcement and audit logging for data access tracking
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeQueryList - Auth check before list query execution
|
||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = "authentication required"
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 0: BeforeQuery - Auth check before single query execution
|
||||
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
|
||||
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = "authentication required"
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
|
||||
// ModelRules defines the permissions and security settings for a model
|
||||
type ModelRules struct {
|
||||
CanPublicRead bool // Whether the model can be read (GET operations)
|
||||
CanPublicUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanPublicCreate bool // Whether the model can be created (POST operations)
|
||||
CanPublicDelete bool // Whether the model can be deleted (DELETE operations)
|
||||
CanRead bool // Whether the model can be read (GET operations)
|
||||
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanCreate bool // Whether the model can be created (POST operations)
|
||||
@@ -22,6 +26,10 @@ func DefaultModelRules() ModelRules {
|
||||
CanUpdate: true,
|
||||
CanCreate: true,
|
||||
CanDelete: true,
|
||||
CanPublicRead: false,
|
||||
CanPublicUpdate: false,
|
||||
CanPublicCreate: false,
|
||||
CanPublicDelete: false,
|
||||
SecurityDisabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ MQTTSpec is an MQTT-based database query framework that enables real-time databa
|
||||
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
||||
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
||||
- **Database Agnostic**: GORM and Bun ORM support
|
||||
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
|
||||
- **Lifecycle Hooks**: 13 hooks for authentication, authorization, validation, and auditing
|
||||
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
||||
- **Thread-safe**: Proper concurrency handling throughout
|
||||
|
||||
@@ -326,10 +326,11 @@ When any client creates/updates/deletes a user matching the subscription filters
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
MQTTSpec provides 13 lifecycle hooks for implementing cross-cutting concerns:
|
||||
|
||||
### Hook Types
|
||||
|
||||
- `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
||||
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
||||
- `BeforeRead` / `AfterRead` - Read operations
|
||||
@@ -339,6 +340,20 @@ MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
||||
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
||||
|
||||
### Security Hooks (Recommended)
|
||||
|
||||
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList := security.NewSecurityList(provider)
|
||||
mqttspec.RegisterSecurityHooks(handler, securityList)
|
||||
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||
```
|
||||
|
||||
### Authentication Example (JWT)
|
||||
|
||||
```go
|
||||
@@ -657,7 +672,7 @@ handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
||||
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
||||
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
||||
| **Hooks** | Same 12 hooks | Same 12 hooks |
|
||||
| **Hooks** | Same 13 hooks | Same 13 hooks |
|
||||
| **CRUD Operations** | Identical | Identical |
|
||||
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
||||
|
||||
|
||||
@@ -284,6 +284,15 @@ func (h *Handler) handleRequest(client *Client, msg *Message) {
|
||||
},
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
hookCtx.Operation = string(msg.Operation)
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
if hookCtx.Abort {
|
||||
h.sendError(client.ID, msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Route to operation handler
|
||||
switch msg.Operation {
|
||||
case OperationRead:
|
||||
|
||||
@@ -20,8 +20,11 @@ type (
|
||||
HookRegistry = websocketspec.HookRegistry
|
||||
)
|
||||
|
||||
// Hook type constants - all 12 lifecycle hooks
|
||||
// Hook type constants - all lifecycle hooks
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch
|
||||
BeforeHandle = websocketspec.BeforeHandle
|
||||
|
||||
// CRUD operation hooks
|
||||
BeforeRead = websocketspec.BeforeRead
|
||||
AfterRead = websocketspec.AfterRead
|
||||
|
||||
108
pkg/mqttspec/security_hooks.go
Normal file
108
pkg/mqttspec/security_hooks.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the MQTT handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LoadSecurityRules(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 3 (Optional): Audit logging
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for mqttspec handler")
|
||||
}
|
||||
|
||||
// securityContext adapts mqttspec.HookContext to security.SecurityContext interface
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
// GetQuery retrieves a stored query from hook metadata
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
if s.ctx.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
return s.ctx.Metadata["query"]
|
||||
}
|
||||
|
||||
// SetQuery stores the query in hook metadata
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if s.ctx.Metadata == nil {
|
||||
s.ctx.Metadata = make(map[string]interface{})
|
||||
}
|
||||
s.ctx.Metadata["query"] = query
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
@@ -644,6 +644,7 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
|
||||
}
|
||||
|
||||
// Schema.Table format
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
handler.registry.RegisterModel("core.posts", &Post{})
|
||||
@@ -654,11 +655,13 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
|
||||
@@ -44,8 +44,8 @@ func TestBuildFilterCondition(t *testing.T) {
|
||||
Operator: "in",
|
||||
Value: []string{"active", "pending"},
|
||||
},
|
||||
expectedCondition: "status IN (?)",
|
||||
expectedArgsCount: 1,
|
||||
expectedCondition: "status IN (?,?)",
|
||||
expectedArgsCount: 2,
|
||||
},
|
||||
{
|
||||
name: "LIKE operator",
|
||||
|
||||
@@ -138,6 +138,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
validator := common.NewColumnValidator(model)
|
||||
req.Options = validator.FilterRequestOptions(req.Options)
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
beforeCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Writer: w,
|
||||
Request: r,
|
||||
Operation: req.Operation,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||
code := http.StatusUnauthorized
|
||||
if beforeCtx.AbortCode != 0 {
|
||||
code = beforeCtx.AbortCode
|
||||
}
|
||||
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Operation {
|
||||
case "read":
|
||||
h.handleRead(ctx, w, id, req.Options)
|
||||
@@ -1236,6 +1256,24 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
logger.Info("Deleting records from %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeDelete hooks (covers model-rule checks before any deletion)
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
ID: id,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Error("BeforeDelete hook failed: %v", err)
|
||||
h.sendError(w, http.StatusForbidden, "delete_forbidden", "Delete operation not allowed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle batch delete from request data
|
||||
if data != nil {
|
||||
switch v := data.(type) {
|
||||
@@ -1508,8 +1546,10 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
if condition == "" {
|
||||
return "", nil
|
||||
}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
@@ -1550,8 +1590,10 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
if condition == "" {
|
||||
return query
|
||||
}
|
||||
default:
|
||||
return query
|
||||
}
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
@@ -43,6 +47,9 @@ type HookContext struct {
|
||||
Writer common.ResponseWriter
|
||||
Request common.Request
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
|
||||
@@ -70,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||
|
||||
// Create handler functions for this specific entity
|
||||
postEntityHandler := createMuxHandler(handler, schema, entity, "")
|
||||
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||
var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||
var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||
var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
||||
|
||||
// Apply authentication middleware if provided
|
||||
if authMiddleware != nil {
|
||||
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
|
||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
|
||||
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
|
||||
postEntityHandler = authMiddleware(postEntityHandler)
|
||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
|
||||
getEntityHandler = authMiddleware(getEntityHandler)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
}
|
||||
|
||||
@@ -225,7 +225,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = handler(w, req)
|
||||
// Replace the embedded *http.Request with the middleware-enriched one
|
||||
// so that auth context (user ID, etc.) is visible to the handler.
|
||||
enrichedReq := req
|
||||
enrichedReq.Request = r
|
||||
_ = handler(w, enrichedReq)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
|
||||
@@ -2,6 +2,7 @@ package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -10,6 +11,17 @@ import (
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
@@ -34,6 +46,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvespec handler")
|
||||
}
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
||||
```
|
||||
|
||||
**Available Hook Types**:
|
||||
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||
* `BeforeRead`, `AfterRead`
|
||||
* `BeforeCreate`, `AfterCreate`
|
||||
* `BeforeUpdate`, `AfterUpdate`
|
||||
@@ -157,11 +158,13 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
||||
* `Handler`: Access to handler, database, and registry
|
||||
* `Schema`, `Entity`, `TableName`: Request info
|
||||
* `Model`: The registered model type
|
||||
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
|
||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||
* `ID`: Record ID (for single-record operations)
|
||||
* `Data`: Request data (for create/update)
|
||||
* `Result`: Operation result (for after hooks)
|
||||
* `Writer`: Response writer (allows hooks to modify response)
|
||||
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
|
||||
@@ -133,6 +133,41 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
// Add request-scoped data to context (including options)
|
||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||
|
||||
// Derive operation for auth check
|
||||
var operation string
|
||||
switch method {
|
||||
case "GET":
|
||||
operation = "read"
|
||||
case "POST":
|
||||
operation = "create"
|
||||
case "PUT", "PATCH":
|
||||
operation = "update"
|
||||
case "DELETE":
|
||||
operation = "delete"
|
||||
default:
|
||||
operation = "read"
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
beforeCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Writer: w,
|
||||
Request: r,
|
||||
Operation: operation,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||
code := http.StatusUnauthorized
|
||||
if beforeCtx.AbortCode != 0 {
|
||||
code = beforeCtx.AbortCode
|
||||
}
|
||||
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "GET":
|
||||
if id != "" {
|
||||
@@ -1498,8 +1533,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %s: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -1572,8 +1607,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -1630,8 +1665,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -2111,7 +2146,11 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
// Column is already cast to TEXT if needed
|
||||
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
||||
case "in":
|
||||
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
|
||||
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||
if cond == "" {
|
||||
return query
|
||||
}
|
||||
return applyWhere(cond, inArgs...)
|
||||
case "between":
|
||||
// Handle between operator - exclusive (> val1 AND < val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
@@ -2204,7 +2243,8 @@ func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.Fi
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "in":
|
||||
return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value}
|
||||
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||
return cond, inArgs
|
||||
case "between":
|
||||
// Handle between operator - exclusive (> val1 AND < val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
@@ -42,6 +46,9 @@ type HookContext struct {
|
||||
Model interface{}
|
||||
Options ExtendedRequestOptions
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
@@ -56,6 +63,14 @@ type HookContext struct {
|
||||
// Response writer - allows hooks to modify response
|
||||
Writer common.ResponseWriter
|
||||
|
||||
// Request - the original HTTP request
|
||||
Request common.Request
|
||||
|
||||
// Allow hooks to abort the operation
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
|
||||
// Tx provides access to the database/transaction for executing additional SQL
|
||||
// This allows hooks to run custom queries in addition to the main Query chain
|
||||
Tx common.Database
|
||||
@@ -110,6 +125,12 @@ func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if hook requested abort
|
||||
if ctx.Abort {
|
||||
logger.Warn("Hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// logger.Debug("All hooks for %s executed successfully", hookType)
|
||||
|
||||
@@ -125,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
||||
|
||||
// Create handler functions for this specific entity
|
||||
entityHandler := createMuxHandler(handler, schema, entity, "")
|
||||
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||
var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||
var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||
var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
||||
|
||||
// Apply authentication middleware if provided
|
||||
if authMiddleware != nil {
|
||||
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
|
||||
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
|
||||
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
|
||||
entityHandler = authMiddleware(entityHandler)
|
||||
entityWithIDHandler = authMiddleware(entityWithIDHandler)
|
||||
metadataHandler = authMiddleware(metadataHandler)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
}
|
||||
|
||||
@@ -289,7 +289,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = handler(w, req)
|
||||
// Replace the embedded *http.Request with the middleware-enriched one
|
||||
// so that auth context (user ID, etc.) is visible to the handler.
|
||||
enrichedReq := req
|
||||
enrichedReq.Request = r
|
||||
_ = handler(w, enrichedReq)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
|
||||
@@ -2,6 +2,7 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
@@ -9,6 +10,17 @@ import (
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
@@ -33,6 +45,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for restheadspec handler")
|
||||
}
|
||||
|
||||
|
||||
@@ -405,11 +405,16 @@ assert.Equal(t, "user_id = {UserID}", row.Template)
|
||||
```
|
||||
HTTP Request
|
||||
↓
|
||||
NewAuthMiddleware → calls provider.Authenticate()
|
||||
↓ (adds UserContext to context)
|
||||
NewOptionalAuthMiddleware → calls provider.Authenticate()
|
||||
↓ (adds UserContext or guest context; never 401)
|
||||
SetSecurityMiddleware → adds SecurityList to context
|
||||
↓
|
||||
Handler.Handle()
|
||||
Handler.Handle() → resolves model
|
||||
↓
|
||||
BeforeHandle Hook → CheckModelAuthAllowed(secCtx, operation)
|
||||
├─ SecurityDisabled → allow
|
||||
├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||
└─ UserID == 0 → abort 401
|
||||
↓
|
||||
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
|
||||
↓
|
||||
@@ -693,15 +698,30 @@ http.Handle("/api/protected", authHandler)
|
||||
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
||||
http.Handle("/home", optionalHandler)
|
||||
|
||||
// Example handler
|
||||
func myHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
if userCtx.UserID == 0 {
|
||||
// Guest user
|
||||
} else {
|
||||
// Authenticated user
|
||||
}
|
||||
}
|
||||
// NewOptionalAuthMiddleware - For spec routes; auth enforcement deferred to BeforeHandle
|
||||
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // includes BeforeHandle
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model-Level Access Control
|
||||
|
||||
```go
|
||||
// Register model with rules (pkg/modelregistry)
|
||||
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||
SecurityDisabled: false, // skip all auth when true
|
||||
CanPublicRead: true, // unauthenticated reads allowed
|
||||
CanPublicCreate: false, // requires auth
|
||||
CanPublicUpdate: false, // requires auth
|
||||
CanPublicDelete: false, // requires auth
|
||||
CanUpdate: true, // authenticated can update
|
||||
CanDelete: false, // authenticated cannot delete (enforced in BeforeDelete)
|
||||
})
|
||||
|
||||
// CheckModelAuthAllowed used automatically in BeforeHandle hook
|
||||
// No code needed — call RegisterSecurityHooks and it's applied
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -751,14 +751,25 @@ resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
|
||||
```
|
||||
HTTP Request
|
||||
↓
|
||||
NewAuthMiddleware (security package)
|
||||
NewOptionalAuthMiddleware (security package) ← recommended for spec routes
|
||||
├─ Calls provider.Authenticate(request)
|
||||
└─ Adds UserContext to context
|
||||
├─ On success: adds authenticated UserContext to context
|
||||
└─ On failure: adds guest UserContext (UserID=0) to context
|
||||
↓
|
||||
SetSecurityMiddleware (security package)
|
||||
└─ Adds SecurityList to context
|
||||
↓
|
||||
Spec Handler (restheadspec/funcspec/resolvespec)
|
||||
Spec Handler (restheadspec/funcspec/resolvespec/websocketspec/mqttspec)
|
||||
└─ Resolves schema + entity + model from request
|
||||
↓
|
||||
BeforeHandle Hook (registered by spec via RegisterSecurityHooks)
|
||||
├─ Adapts spec's HookContext → SecurityContext
|
||||
├─ Calls security.CheckModelAuthAllowed(secCtx, operation)
|
||||
│ ├─ Loads model rules from context or registry
|
||||
│ ├─ SecurityDisabled → allow
|
||||
│ ├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||
│ └─ UserID == 0 → 401 unauthorized
|
||||
└─ On error: aborts with 401
|
||||
↓
|
||||
BeforeRead Hook (registered by spec)
|
||||
├─ Adapts spec's HookContext → SecurityContext
|
||||
@@ -784,7 +795,8 @@ HTTP Response (secured data)
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Security package is spec-agnostic and provides core logic
|
||||
- `NewOptionalAuthMiddleware` never rejects — it sets guest context on auth failure; `BeforeHandle` enforces auth after model resolution
|
||||
- `BeforeHandle` fires after model resolution, giving access to model rules and user context simultaneously
|
||||
- Each spec registers its own hooks that adapt to SecurityContext
|
||||
- Security rules are loaded once and cached for the request
|
||||
- Row security is applied to the query (database level)
|
||||
@@ -1002,15 +1014,49 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
|
||||
}
|
||||
```
|
||||
|
||||
## Model-Level Access Control
|
||||
|
||||
Use `ModelRules` (from `pkg/modelregistry`) to control per-entity auth behavior:
|
||||
|
||||
```go
|
||||
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||
SecurityDisabled: false, // true = skip all auth checks
|
||||
CanPublicRead: true, // unauthenticated GET allowed
|
||||
CanPublicCreate: false, // requires auth
|
||||
CanPublicUpdate: false, // requires auth
|
||||
CanPublicDelete: false, // requires auth
|
||||
CanUpdate: true, // authenticated users can update
|
||||
CanDelete: false, // authenticated users cannot delete
|
||||
})
|
||||
```
|
||||
|
||||
`CheckModelAuthAllowed(secCtx, operation)` applies these rules in `BeforeHandle`:
|
||||
1. `SecurityDisabled` → allow all
|
||||
2. `CanPublicRead/Create/Update/Delete` → allow unauthenticated for that operation
|
||||
3. Guest (UserID == 0) → return 401
|
||||
4. Authenticated → allow (operation-specific `CanUpdate`/`CanDelete` checked in `BeforeUpdate`/`BeforeDelete`)
|
||||
|
||||
---
|
||||
|
||||
## Middleware and Handler API
|
||||
|
||||
### NewAuthMiddleware
|
||||
Standard middleware that authenticates all requests:
|
||||
Standard middleware that authenticates all requests and returns 401 on failure:
|
||||
|
||||
```go
|
||||
router.Use(security.NewAuthMiddleware(securityList))
|
||||
```
|
||||
|
||||
### NewOptionalAuthMiddleware
|
||||
Middleware for spec routes — always continues; sets guest context on auth failure:
|
||||
|
||||
```go
|
||||
// Use with RegisterSecurityHooks — auth enforcement is deferred to BeforeHandle
|
||||
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // registers BeforeHandle
|
||||
```
|
||||
|
||||
Routes can skip authentication using the `SkipAuth` helper:
|
||||
|
||||
```go
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
||||
@@ -226,6 +227,122 @@ func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
|
||||
return applyColumnSecurity(secCtx, securityList)
|
||||
}
|
||||
|
||||
// checkModelUpdateAllowed returns an error if CanUpdate is false for the model.
|
||||
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||
func checkModelUpdateAllowed(secCtx SecurityContext) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
return nil // model not registered, allow by default
|
||||
}
|
||||
}
|
||||
if !rules.CanUpdate {
|
||||
return fmt.Errorf("update not allowed for %s", secCtx.GetEntity())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkModelDeleteAllowed returns an error if CanDelete is false for the model.
|
||||
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||
func checkModelDeleteAllowed(secCtx SecurityContext) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
return nil // model not registered, allow by default
|
||||
}
|
||||
}
|
||||
if !rules.CanDelete {
|
||||
return fmt.Errorf("delete not allowed for %s", secCtx.GetEntity())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckModelAuthAllowed checks whether the requested operation is permitted based on
|
||||
// model rules and the current user's authentication state. It is intended for use in
|
||||
// a BeforeHandle hook, fired after model resolution.
|
||||
//
|
||||
// Logic:
|
||||
// 1. Load model rules from context (set by NewModelAuthMiddleware) or fall back to registry.
|
||||
// 2. SecurityDisabled → allow.
|
||||
// 3. operation == "read" && CanPublicRead → allow.
|
||||
// 4. operation == "create" && CanPublicCreate → allow.
|
||||
// 5. operation == "update" && CanPublicUpdate → allow.
|
||||
// 6. operation == "delete" && CanPublicDelete → allow.
|
||||
// 7. Guest (UserID == 0) → return "authentication required".
|
||||
// 8. Authenticated user → allow (operation-specific checks remain in BeforeUpdate/BeforeDelete).
|
||||
func CheckModelAuthAllowed(secCtx SecurityContext, operation string) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
// Model not registered - fall through to auth check
|
||||
userID, _ := secCtx.GetUserID()
|
||||
if userID == 0 {
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if rules.SecurityDisabled {
|
||||
return nil
|
||||
}
|
||||
if operation == "read" && rules.CanPublicRead {
|
||||
return nil
|
||||
}
|
||||
if operation == "create" && rules.CanPublicCreate {
|
||||
return nil
|
||||
}
|
||||
if operation == "update" && rules.CanPublicUpdate {
|
||||
return nil
|
||||
}
|
||||
if operation == "delete" && rules.CanPublicDelete {
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, _ := secCtx.GetUserID()
|
||||
if userID == 0 {
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckModelUpdateAllowed is the public wrapper for checkModelUpdateAllowed.
|
||||
func CheckModelUpdateAllowed(secCtx SecurityContext) error {
|
||||
return checkModelUpdateAllowed(secCtx)
|
||||
}
|
||||
|
||||
// CheckModelDeleteAllowed is the public wrapper for checkModelDeleteAllowed.
|
||||
func CheckModelDeleteAllowed(secCtx SecurityContext) error {
|
||||
return checkModelDeleteAllowed(secCtx)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
@@ -23,6 +25,7 @@ const (
|
||||
UserMetaKey contextKey = "user_meta"
|
||||
SkipAuthKey contextKey = "skip_auth"
|
||||
OptionalAuthKey contextKey = "optional_auth"
|
||||
ModelRulesKey contextKey = "model_rules"
|
||||
)
|
||||
|
||||
// SkipAuth returns a context with skip auth flag set to true
|
||||
@@ -136,6 +139,31 @@ func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.
|
||||
})
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware creates authentication middleware that always continues.
|
||||
// On auth failure, a guest user context is set instead of returning 401.
|
||||
// Intended for spec routes where auth enforcement is deferred to a BeforeHandle hook
|
||||
// after model resolution.
|
||||
func NewOptionalAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
provider := securityList.Provider()
|
||||
if provider == nil {
|
||||
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userCtx, err := provider.Authenticate(r)
|
||||
if err != nil {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates an authentication middleware with the given security list
|
||||
// This middleware extracts user authentication from the request and adds it to context
|
||||
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
|
||||
@@ -182,6 +210,68 @@ func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handl
|
||||
}
|
||||
}
|
||||
|
||||
// NewModelAuthMiddleware creates authentication middleware that respects ModelRules for the given model name.
|
||||
// It first checks if ModelRules are set for the model:
|
||||
// - If SecurityDisabled is true, authentication is skipped and a guest context is set.
|
||||
// - Otherwise, all checks from NewAuthMiddleware apply (SkipAuthKey, provider check, OptionalAuthKey, Authenticate).
|
||||
//
|
||||
// If the model is not found in any registry, the middleware falls back to standard NewAuthMiddleware behaviour.
|
||||
func NewModelAuthMiddleware(securityList *SecurityList, modelName string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check ModelRules first
|
||||
if rules, err := modelregistry.GetModelRulesByName(modelName); err == nil {
|
||||
// Store rules in context for downstream use (e.g., security hooks)
|
||||
r = r.WithContext(context.WithValue(r.Context(), ModelRulesKey, rules))
|
||||
|
||||
if rules.SecurityDisabled {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
isRead := r.Method == http.MethodGet || r.Method == http.MethodHead
|
||||
isUpdate := r.Method == http.MethodPut || r.Method == http.MethodPatch
|
||||
if (isRead && rules.CanPublicRead) || (isUpdate && rules.CanPublicUpdate) {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this route should skip authentication
|
||||
if skip, ok := r.Context().Value(SkipAuthKey).(bool); ok && skip {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
|
||||
// Get the security provider
|
||||
provider := securityList.Provider()
|
||||
if provider == nil {
|
||||
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this route has optional authentication
|
||||
optional, _ := r.Context().Value(OptionalAuthKey).(bool)
|
||||
|
||||
// Try to authenticate
|
||||
userCtx, err := provider.Authenticate(r)
|
||||
if err != nil {
|
||||
if optional {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SetSecurityMiddleware adds security context to requests
|
||||
// This middleware should be applied after AuthMiddleware
|
||||
func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||
@@ -366,6 +456,12 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
|
||||
return meta, ok
|
||||
}
|
||||
|
||||
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
|
||||
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
|
||||
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)
|
||||
return rules, ok
|
||||
}
|
||||
|
||||
// // Handler adapters for resolvespec/restheadspec compatibility
|
||||
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions
|
||||
|
||||
|
||||
@@ -330,6 +330,7 @@ Hooks allow you to intercept and modify operations at various points in the life
|
||||
|
||||
### Available Hook Types
|
||||
|
||||
- **BeforeHandle** — fires after model resolution, before operation dispatch (auth checks)
|
||||
- **BeforeRead** / **AfterRead**
|
||||
- **BeforeCreate** / **AfterCreate**
|
||||
- **BeforeUpdate** / **AfterUpdate**
|
||||
@@ -337,6 +338,8 @@ Hooks allow you to intercept and modify operations at various points in the life
|
||||
- **BeforeSubscribe** / **AfterSubscribe**
|
||||
- **BeforeConnect** / **AfterConnect**
|
||||
|
||||
`HookContext` includes `Operation string` (`"read"`, `"create"`, `"update"`, `"delete"`) and `Abort bool`, `AbortMessage string`, `AbortCode int` for abort signaling.
|
||||
|
||||
### Hook Example
|
||||
|
||||
```go
|
||||
@@ -599,7 +602,19 @@ asyncio.run(main())
|
||||
|
||||
## Authentication
|
||||
|
||||
Implement authentication using hooks:
|
||||
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList := security.NewSecurityList(provider)
|
||||
websocketspec.RegisterSecurityHooks(handler, securityList)
|
||||
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||
```
|
||||
|
||||
Or implement custom authentication using hooks directly:
|
||||
|
||||
```go
|
||||
handler := websocketspec.NewHandlerWithGORM(db)
|
||||
|
||||
@@ -177,6 +177,16 @@ func (h *Handler) handleRequest(conn *Connection, msg *Message) {
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
hookCtx.Operation = string(msg.Operation)
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
if hookCtx.Abort {
|
||||
errResp := NewErrorResponse(msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||
_ = conn.SendJSON(errResp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Route to operation handler
|
||||
switch msg.Operation {
|
||||
case OperationRead:
|
||||
@@ -618,7 +628,10 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
if hookCtx.Options != nil {
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
cond, args := h.buildFilterCondition(filter)
|
||||
if cond != "" {
|
||||
countQuery = countQuery.Where(cond, args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
count, _ := countQuery.Count(hookCtx.Context)
|
||||
@@ -790,14 +803,12 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
|
||||
|
||||
// buildFilterCondition builds a filter condition and returns it with args
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
||||
var condition string
|
||||
var args []interface{}
|
||||
|
||||
if strings.EqualFold(filter.Operator, "in") {
|
||||
cond, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return cond, args
|
||||
}
|
||||
operatorSQL := h.getOperatorSQL(filter.Operator)
|
||||
condition = fmt.Sprintf("%s %s ?", filter.Column, operatorSQL)
|
||||
args = []interface{}{filter.Value}
|
||||
|
||||
return condition, args
|
||||
return fmt.Sprintf("%s %s ?", filter.Column, operatorSQL), []interface{}{filter.Value}
|
||||
}
|
||||
|
||||
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||
|
||||
@@ -2,6 +2,7 @@ package websocketspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
@@ -10,6 +11,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// BeforeRead is called before a read operation
|
||||
BeforeRead HookType = "before_read"
|
||||
// AfterRead is called after a read operation
|
||||
@@ -83,6 +88,9 @@ type HookContext struct {
|
||||
// Options contains the parsed request options
|
||||
Options *common.RequestOptions
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// ID is the record ID for single-record operations
|
||||
ID string
|
||||
|
||||
@@ -98,6 +106,11 @@ type HookContext struct {
|
||||
// Error is any error that occurred (for after hooks)
|
||||
Error error
|
||||
|
||||
// Allow hooks to abort the operation
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
|
||||
// Metadata is additional context data
|
||||
Metadata map[string]interface{}
|
||||
}
|
||||
@@ -171,6 +184,11 @@ func (hr *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
if err := hook(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if hook requested abort
|
||||
if ctx.Abort {
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
108
pkg/websocketspec/security_hooks.go
Normal file
108
pkg/websocketspec/security_hooks.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package websocketspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LoadSecurityRules(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 3 (Optional): Audit logging
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for websocketspec handler")
|
||||
}
|
||||
|
||||
// securityContext adapts websocketspec.HookContext to security.SecurityContext interface
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
// GetQuery retrieves a stored query from hook metadata (websocketspec has no Query field)
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
if s.ctx.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
return s.ctx.Metadata["query"]
|
||||
}
|
||||
|
||||
// SetQuery stores the query in hook metadata
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if s.ctx.Metadata == nil {
|
||||
s.ctx.Metadata = make(map[string]interface{})
|
||||
}
|
||||
s.ctx.Metadata["query"] = query
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
Reference in New Issue
Block a user