diff --git a/pkg/funcspec/security_adapter.go b/pkg/funcspec/security_adapter.go new file mode 100644 index 0000000..49f7fdc --- /dev/null +++ b/pkg/funcspec/security_adapter.go @@ -0,0 +1,83 @@ +package funcspec + +import ( + "context" + + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// RegisterSecurityHooks registers security hooks for funcspec handlers +// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable +// We provide audit logging for data access tracking +func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) { + // Hook 1: BeforeQueryList - Audit logging before query list execution + handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error { + secCtx := newFuncSpecSecurityContext(hookCtx) + return security.LogDataAccess(secCtx) + }) + + // Hook 2: BeforeQuery - Audit logging before single query execution + handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error { + secCtx := newFuncSpecSecurityContext(hookCtx) + return security.LogDataAccess(secCtx) + }) + + // Note: Row-level security and column masking are challenging in funcspec + // because the SQL query is fully user-defined. Security should be implemented + // at the SQL function level or through database policies (RLS). +} + +// funcSpecSecurityContext adapts funcspec.HookContext to security.SecurityContext interface +type funcSpecSecurityContext struct { + ctx *HookContext +} + +func newFuncSpecSecurityContext(ctx *HookContext) security.SecurityContext { + return &funcSpecSecurityContext{ctx: ctx} +} + +func (f *funcSpecSecurityContext) GetContext() context.Context { + return f.ctx.Context +} + +func (f *funcSpecSecurityContext) GetUserID() (int, bool) { + if f.ctx.UserContext == nil { + return 0, false + } + return int(f.ctx.UserContext.UserID), true +} + +func (f *funcSpecSecurityContext) GetSchema() string { + // funcspec doesn't have a schema concept, extract from SQL query or use default + return "public" +} + +func (f *funcSpecSecurityContext) GetEntity() string { + // funcspec doesn't have an entity concept, could parse from SQL or use a placeholder + return "sql_query" +} + +func (f *funcSpecSecurityContext) GetModel() interface{} { + // funcspec doesn't use models in the same way as restheadspec + return nil +} + +func (f *funcSpecSecurityContext) GetQuery() interface{} { + // In funcspec, the query is a string, not a query builder object + return f.ctx.SQLQuery +} + +func (f *funcSpecSecurityContext) SetQuery(query interface{}) { + // In funcspec, we could modify the SQL string, but this should be done cautiously + if sqlQuery, ok := query.(string); ok { + f.ctx.SQLQuery = sqlQuery + } +} + +func (f *funcSpecSecurityContext) GetResult() interface{} { + return f.ctx.Result +} + +func (f *funcSpecSecurityContext) SetResult(result interface{}) { + f.ctx.Result = result +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 6c3853a..3beb261 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -21,6 +21,7 @@ type Handler struct { db common.Database registry common.ModelRegistry nestedProcessor *common.NestedCUDProcessor + hooks *HookRegistry } // NewHandler creates a new API handler with database and registry abstractions @@ -28,12 +29,19 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler { handler := &Handler{ db: db, registry: registry, + hooks: NewHookRegistry(), } // Initialize nested processor handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler) return handler } +// Hooks returns the hook registry for this handler +// Use this to register custom hooks for operations +func (h *Handler) Hooks() *HookRegistry { + return h.hooks +} + // GetDatabase returns the underlying database connection // Implements common.SpecHandler interface func (h *Handler) GetDatabase() common.Database { diff --git a/pkg/resolvespec/hooks.go b/pkg/resolvespec/hooks.go new file mode 100644 index 0000000..d269b5c --- /dev/null +++ b/pkg/resolvespec/hooks.go @@ -0,0 +1,152 @@ +package resolvespec + +import ( + "context" + "fmt" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +// HookType defines the type of hook to execute +type HookType string + +const ( + // Read operation hooks + BeforeRead HookType = "before_read" + AfterRead HookType = "after_read" + + // Create operation hooks + BeforeCreate HookType = "before_create" + AfterCreate HookType = "after_create" + + // Update operation hooks + BeforeUpdate HookType = "before_update" + AfterUpdate HookType = "after_update" + + // Delete operation hooks + BeforeDelete HookType = "before_delete" + AfterDelete HookType = "after_delete" + + // Scan/Execute operation hooks (for query building) + BeforeScan HookType = "before_scan" +) + +// HookContext contains all the data available to a hook +type HookContext struct { + Context context.Context + Handler *Handler // Reference to the handler for accessing database, registry, etc. + Schema string + Entity string + Model interface{} + Options common.RequestOptions + Writer common.ResponseWriter + Request common.Request + + // Operation-specific fields + ID string + Data interface{} // For create/update operations + Result interface{} // For after hooks + Error error // For after hooks + + // Query chain - allows hooks to modify the query before execution + Query common.SelectQuery + + // Allow hooks to abort the operation + Abort bool // If set to true, the operation will be aborted + AbortMessage string // Message to return if aborted + AbortCode int // HTTP status code if aborted +} + +// HookFunc is the signature for hook functions +// It receives a HookContext and can modify it or return an error +// If an error is returned, the operation will be aborted +type HookFunc func(*HookContext) error + +// HookRegistry manages all registered hooks +type HookRegistry struct { + hooks map[HookType][]HookFunc +} + +// NewHookRegistry creates a new hook registry +func NewHookRegistry() *HookRegistry { + return &HookRegistry{ + hooks: make(map[HookType][]HookFunc), + } +} + +// Register adds a new hook for the specified hook type +func (r *HookRegistry) Register(hookType HookType, hook HookFunc) { + if r.hooks == nil { + r.hooks = make(map[HookType][]HookFunc) + } + r.hooks[hookType] = append(r.hooks[hookType], hook) + logger.Info("Registered resolvespec hook for %s (total: %d)", hookType, len(r.hooks[hookType])) +} + +// RegisterMultiple registers a hook for multiple hook types +func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) { + for _, hookType := range hookTypes { + r.Register(hookType, hook) + } +} + +// Execute runs all hooks for the specified type in order +// If any hook returns an error, execution stops and the error is returned +func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error { + hooks, exists := r.hooks[hookType] + if !exists || len(hooks) == 0 { + return nil + } + + logger.Debug("Executing %d resolvespec hook(s) for %s", len(hooks), hookType) + + for i, hook := range hooks { + if err := hook(ctx); err != nil { + logger.Error("Resolvespec hook %d for %s failed: %v", i+1, hookType, err) + return fmt.Errorf("hook execution failed: %w", err) + } + + // Check if hook requested abort + if ctx.Abort { + logger.Warn("Resolvespec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage) + return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage) + } + } + + return nil +} + +// Clear removes all hooks for the specified type +func (r *HookRegistry) Clear(hookType HookType) { + delete(r.hooks, hookType) + logger.Info("Cleared all resolvespec hooks for %s", hookType) +} + +// ClearAll removes all registered hooks +func (r *HookRegistry) ClearAll() { + r.hooks = make(map[HookType][]HookFunc) + logger.Info("Cleared all resolvespec hooks") +} + +// Count returns the number of hooks registered for a specific type +func (r *HookRegistry) Count(hookType HookType) int { + if hooks, exists := r.hooks[hookType]; exists { + return len(hooks) + } + return 0 +} + +// HasHooks returns true if there are any hooks registered for the specified type +func (r *HookRegistry) HasHooks(hookType HookType) bool { + return r.Count(hookType) > 0 +} + +// GetAllHookTypes returns all hook types that have registered hooks +func (r *HookRegistry) GetAllHookTypes() []HookType { + types := make([]HookType, 0, len(r.hooks)) + for hookType := range r.hooks { + types = append(types, hookType) + } + return types +} diff --git a/pkg/resolvespec/security_hooks.go b/pkg/resolvespec/security_hooks.go new file mode 100644 index 0000000..629c8e3 --- /dev/null +++ b/pkg/resolvespec/security_hooks.go @@ -0,0 +1,85 @@ +package resolvespec + +import ( + "context" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// RegisterSecurityHooks registers all security-related hooks with the handler +func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) { + // Hook 1: BeforeRead - Load security rules + handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.LoadSecurityRules(secCtx, securityList) + }) + + // Hook 2: BeforeScan - Apply row-level security filters + handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.ApplyRowSecurity(secCtx, securityList) + }) + + // Hook 3: AfterRead - Apply column-level security (masking) + handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.ApplyColumnSecurity(secCtx, securityList) + }) + + // Hook 4 (Optional): Audit logging + handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.LogDataAccess(secCtx) + }) + + logger.Info("Security hooks registered for resolvespec handler") +} + +// securityContext adapts resolvespec.HookContext to security.SecurityContext interface +type securityContext struct { + ctx *HookContext +} + +func newSecurityContext(ctx *HookContext) security.SecurityContext { + return &securityContext{ctx: ctx} +} + +func (s *securityContext) GetContext() context.Context { + return s.ctx.Context +} + +func (s *securityContext) GetUserID() (int, bool) { + return security.GetUserID(s.ctx.Context) +} + +func (s *securityContext) GetSchema() string { + return s.ctx.Schema +} + +func (s *securityContext) GetEntity() string { + return s.ctx.Entity +} + +func (s *securityContext) GetModel() interface{} { + return s.ctx.Model +} + +func (s *securityContext) GetQuery() interface{} { + return s.ctx.Query +} + +func (s *securityContext) SetQuery(query interface{}) { + if q, ok := query.(common.SelectQuery); ok { + s.ctx.Query = q + } +} + +func (s *securityContext) GetResult() interface{} { + return s.ctx.Result +} + +func (s *securityContext) SetResult(result interface{}) { + s.ctx.Result = result +} diff --git a/pkg/restheadspec/security_hooks.go b/pkg/restheadspec/security_hooks.go new file mode 100644 index 0000000..62b5663 --- /dev/null +++ b/pkg/restheadspec/security_hooks.go @@ -0,0 +1,82 @@ +package restheadspec + +import ( + "context" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// RegisterSecurityHooks registers all security-related hooks with the handler +func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) { + // Hook 1: BeforeRead - Load security rules + handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.LoadSecurityRules(secCtx, securityList) + }) + + // Hook 2: BeforeScan - Apply row-level security filters + handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.ApplyRowSecurity(secCtx, securityList) + }) + + // Hook 3: AfterRead - Apply column-level security (masking) + handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.ApplyColumnSecurity(secCtx, securityList) + }) + + // Hook 4 (Optional): Audit logging + handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error { + secCtx := newSecurityContext(hookCtx) + return security.LogDataAccess(secCtx) + }) + + logger.Info("Security hooks registered for restheadspec handler") +} + +// securityContext adapts restheadspec.HookContext to security.SecurityContext interface +type securityContext struct { + ctx *HookContext +} + +func newSecurityContext(ctx *HookContext) security.SecurityContext { + return &securityContext{ctx: ctx} +} + +func (s *securityContext) GetContext() context.Context { + return s.ctx.Context +} + +func (s *securityContext) GetUserID() (int, bool) { + return security.GetUserID(s.ctx.Context) +} + +func (s *securityContext) GetSchema() string { + return s.ctx.Schema +} + +func (s *securityContext) GetEntity() string { + return s.ctx.Entity +} + +func (s *securityContext) GetModel() interface{} { + return s.ctx.Model +} + +func (s *securityContext) GetQuery() interface{} { + return s.ctx.Query +} + +func (s *securityContext) SetQuery(query interface{}) { + s.ctx.Query = query +} + +func (s *securityContext) GetResult() interface{} { + return s.ctx.Result +} + +func (s *securityContext) SetResult(result interface{}) { + s.ctx.Result = result +} diff --git a/pkg/security/README.md b/pkg/security/README.md index 8ad2e12..46c349d 100644 --- a/pkg/security/README.md +++ b/pkg/security/README.md @@ -56,9 +56,10 @@ rowSec := security.NewDatabaseRowSecurityProvider(db) // 2. Combine providers provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) -// 3. Setup security +// 3. Create handler and register security hooks handler := restheadspec.NewHandlerWithGORM(db) -securityList := security.SetupSecurityProvider(handler, provider) +securityList := security.NewSecurityList(provider) +restheadspec.RegisterSecurityHooks(handler, securityList) // 4. Apply middleware router := mux.NewRouter() @@ -69,6 +70,38 @@ router.Use(security.SetSecurityMiddleware(securityList)) ## Architecture +### Spec-Agnostic Design + +The security system is **completely spec-agnostic** - it doesn't depend on any specific spec implementation. Instead, each spec (restheadspec, funcspec, resolvespec) implements its own security integration by adapting to the `SecurityContext` interface. + +``` +┌─────────────────────────────────────┐ +│ Security Package (Generic) │ +│ - SecurityContext interface │ +│ - Security providers │ +│ - Core security logic │ +└─────────────────────────────────────┘ + ▲ ▲ ▲ + │ │ │ + ┌──────┘ │ └──────┐ + │ │ │ +┌───▼────┐ ┌────▼─────┐ ┌────▼──────┐ +│RestHead│ │ FuncSpec │ │ResolveSpec│ +│ Spec │ │ │ │ │ +│ │ │ │ │ │ +│Adapts │ │ Adapts │ │ Adapts │ +│to │ │ to │ │ to │ +│Security│ │ Security │ │ Security │ +│Context │ │ Context │ │ Context │ +└────────┘ └──────────┘ └───────────┘ +``` + +**Benefits:** +- ✅ No circular dependencies +- ✅ Each spec can customize security integration +- ✅ Easy to add new specs +- ✅ Security logic is reusable across all specs + ### Core Interfaces The security system is built on three main interfaces: @@ -113,6 +146,28 @@ type SecurityProvider interface { } ``` +#### 4. SecurityContext (Spec Integration Interface) +Each spec implements this interface to integrate with the security system: + +```go +type SecurityContext interface { + GetContext() context.Context + GetUserID() (int, bool) + GetSchema() string + GetEntity() string + GetModel() interface{} + GetQuery() interface{} + SetQuery(interface{}) + GetResult() interface{} + SetResult(interface{}) +} +``` + +**Implementation Examples:** +- `restheadspec`: Adapts `restheadspec.HookContext` → `SecurityContext` +- `funcspec`: Adapts `funcspec.HookContext` → `SecurityContext` +- `resolvespec`: Adapts `resolvespec.HookContext` → `SecurityContext` + ### UserContext Enhanced user context with complete user information: @@ -197,7 +252,7 @@ rowSec := security.NewConfigRowSecurityProvider(templates, blocked) ## Usage Examples -### Example 1: Complete Database-Backed Security with Sessions +### Example 1: Complete Database-Backed Security with Sessions (restheadspec) ```go func main() { @@ -207,16 +262,20 @@ func main() { // db.Exec("CREATE TABLE users ...") // db.Exec("CREATE TABLE user_sessions ...") + // Create handler handler := restheadspec.NewHandlerWithGORM(db) - // Create providers + // Create security providers auth := security.NewDatabaseAuthenticator(db) // Session-based auth colSec := security.NewDatabaseColumnSecurityProvider(db) rowSec := security.NewDatabaseRowSecurityProvider(db) - // Combine + // Combine providers provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) - securityList := security.SetupSecurityProvider(handler, provider) + securityList := security.NewSecurityList(provider) + + // Register security hooks for this spec + restheadspec.RegisterSecurityHooks(handler, securityList) // Setup routes router := mux.NewRouter() @@ -309,14 +368,85 @@ func main() { colSec := security.NewConfigColumnSecurityProvider(columnRules) rowSec := security.NewConfigRowSecurityProvider(rowTemplates, nil) + // Combine providers and register hooks provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) - securityList := security.SetupSecurityProvider(handler, provider) + securityList := security.NewSecurityList(provider) + restheadspec.RegisterSecurityHooks(handler, securityList) // Setup routes... } ``` -### Example 3: Custom Provider +### Example 3: FuncSpec Security (SQL Query API) + +```go +import ( + "github.com/bitechdev/ResolveSpec/pkg/funcspec" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +func main() { + db := setupDatabase() + + // Create funcspec handler + handler := funcspec.NewHandler(db) + + // Create security providers + auth := security.NewJWTAuthenticator("secret-key", db) + colSec := security.NewDatabaseColumnSecurityProvider(db) + rowSec := security.NewDatabaseRowSecurityProvider(db) + + // Combine providers + provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) + securityList := security.NewSecurityList(provider) + + // Register security hooks (audit logging) + funcspec.RegisterSecurityHooks(handler, securityList) + + // Note: funcspec operates on raw SQL queries, so row/column + // security is limited. Security should be enforced at the + // SQL function level or via database policies. + + // Setup routes... +} +``` + +### Example 4: ResolveSpec Security (REST API) + +```go +import ( + "github.com/bitechdev/ResolveSpec/pkg/resolvespec" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +func main() { + db := setupDatabase() + registry := common.NewModelRegistry() + + // Register models + registry.RegisterModel("public.users", &User{}) + registry.RegisterModel("public.orders", &Order{}) + + // Create resolvespec handler + handler := resolvespec.NewHandler(db, registry) + + // Create security providers + auth := security.NewDatabaseAuthenticator(db) + colSec := security.NewDatabaseColumnSecurityProvider(db) + rowSec := security.NewDatabaseRowSecurityProvider(db) + + // Combine providers + provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) + securityList := security.NewSecurityList(provider) + + // Register security hooks for resolvespec + resolvespec.RegisterSecurityHooks(handler, securityList) + + // Setup routes... +} +``` + +### Example 5: Custom Provider Implement your own provider for complete control: @@ -345,9 +475,18 @@ func (p *MySecurityProvider) GetRowSecurity(ctx context.Context, userID int, sch // Your custom row security logic } -// Use it +// Use it with any spec provider := &MySecurityProvider{db: db} -securityList := security.SetupSecurityProvider(handler, provider) +securityList := security.NewSecurityList(provider) + +// Register with restheadspec +restheadspec.RegisterSecurityHooks(restHandler, securityList) + +// Or with funcspec +funcspec.RegisterSecurityHooks(funcHandler, securityList) + +// Or with resolvespec +resolvespec.RegisterSecurityHooks(resolveHandler, securityList) ``` ## Security Features @@ -419,30 +558,45 @@ securityList := security.SetupSecurityProvider(handler, provider) ``` HTTP Request ↓ -NewAuthMiddleware +NewAuthMiddleware (security package) ├─ Calls provider.Authenticate(request) └─ Adds UserContext to context ↓ -SetSecurityMiddleware +SetSecurityMiddleware (security package) └─ Adds SecurityList to context ↓ -Handler.Handle() +Spec Handler (restheadspec/funcspec/resolvespec) ↓ -BeforeRead Hook - ├─ Calls provider.GetColumnSecurity() - └─ Calls provider.GetRowSecurity() +BeforeRead Hook (registered by spec) + ├─ Adapts spec's HookContext → SecurityContext + ├─ Calls security.LoadSecurityRules(secCtx, securityList) + │ ├─ Calls provider.GetColumnSecurity() + │ └─ Calls provider.GetRowSecurity() + └─ Caches security rules ↓ -BeforeScan Hook - └─ Applies row security (adds WHERE clause) +BeforeScan Hook (registered by spec) + ├─ Adapts spec's HookContext → SecurityContext + ├─ Calls security.ApplyRowSecurity(secCtx, securityList) + └─ Applies row security (adds WHERE clause to query) ↓ Database Query (with security filters) ↓ -AfterRead Hook - └─ Applies column security (masks/hides fields) +AfterRead Hook (registered by spec) + ├─ Adapts spec's HookContext → SecurityContext + ├─ Calls security.ApplyColumnSecurity(secCtx, securityList) + ├─ Applies column security (masks/hides fields) + └─ Calls security.LogDataAccess(secCtx) ↓ HTTP Response (secured data) ``` +**Key Points:** +- Security package is spec-agnostic and provides core logic +- Each spec registers its own hooks that adapt to SecurityContext +- Security rules are loaded once and cached for the request +- Row security is applied to the query (database level) +- Column security is applied to results (application level) + ## Testing The interface-based design makes testing straightforward: @@ -475,7 +629,9 @@ func TestMyHandler(t *testing.T) { } ``` -## Migration from Callbacks +## Migration Guide + +### From Old Callback System If you're upgrading from the old callback-based system: @@ -489,7 +645,7 @@ security.SetupSecurityProvider(handler, &security.GlobalSecurity) **New:** ```go -// Wrap your functions in a provider +// 1. Wrap your functions in a provider type MyProvider struct{} func (p *MyProvider) Authenticate(r *http.Request) (*security.UserContext, error) { @@ -513,11 +669,34 @@ func (p *MyProvider) Logout(ctx context.Context, req security.LogoutRequest) err return nil } -// Use it +// 2. Create security list and register hooks provider := &MyProvider{} +securityList := security.NewSecurityList(provider) + +// 3. Register with your spec +restheadspec.RegisterSecurityHooks(handler, securityList) +``` + +### From Old SetupSecurityProvider API + +If you're upgrading from the previous interface-based system: + +**Old:** +```go securityList := security.SetupSecurityProvider(handler, provider) ``` +**New:** +```go +securityList := security.NewSecurityList(provider) +restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec +``` + +The main changes: +1. Security package no longer knows about specific spec types +2. Each spec registers its own security hooks +3. More flexible - same security provider works with all specs + ## Documentation | File | Description | diff --git a/pkg/security/hooks.go b/pkg/security/hooks.go index 6592875..20ad259 100644 --- a/pkg/security/hooks.go +++ b/pkg/security/hooks.go @@ -1,51 +1,43 @@ package security import ( + "context" "fmt" "reflect" "github.com/bitechdev/ResolveSpec/pkg/logger" - "github.com/bitechdev/ResolveSpec/pkg/restheadspec" ) -// RegisterSecurityHooks registers all security-related hooks with the handler -func RegisterSecurityHooks(handler *restheadspec.Handler, securityList *SecurityList) { - - // Hook 1: BeforeRead - Load security rules - handler.Hooks().Register(restheadspec.BeforeRead, func(hookCtx *restheadspec.HookContext) error { - return LoadSecurityRules(hookCtx, securityList) - }) - - // Hook 2: BeforeScan - Apply row-level security filters - handler.Hooks().Register(restheadspec.BeforeScan, func(hookCtx *restheadspec.HookContext) error { - return ApplyRowSecurity(hookCtx, securityList) - }) - - // Hook 3: AfterRead - Apply column-level security (masking) - handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error { - return ApplyColumnSecurity(hookCtx, securityList) - }) - - // Hook 4 (Optional): Audit logging - handler.Hooks().Register(restheadspec.AfterRead, LogDataAccess) +// SecurityContext is a generic interface that any spec can implement to integrate with security features +// This interface abstracts the common security context needs across different specs +type SecurityContext interface { + GetContext() context.Context + GetUserID() (int, bool) + GetSchema() string + GetEntity() string + GetModel() interface{} + GetQuery() interface{} + SetQuery(interface{}) + GetResult() interface{} + SetResult(interface{}) } -// LoadSecurityRules loads security configuration for the user and entity -func LoadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { +// loadSecurityRules loads security configuration for the user and entity (generic version) +func loadSecurityRules(secCtx SecurityContext, securityList *SecurityList) error { // Extract user ID from context - userID, ok := GetUserID(hookCtx.Context) + userID, ok := secCtx.GetUserID() if !ok { logger.Warn("No user ID in context for security check") return fmt.Errorf("authentication required") } - schema := hookCtx.Schema - tablename := hookCtx.Entity + schema := secCtx.GetSchema() + tablename := secCtx.GetEntity() logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename) // Load column security rules using the provider - err := securityList.LoadColumnSecurity(hookCtx.Context, userID, schema, tablename, false) + err := securityList.LoadColumnSecurity(secCtx.GetContext(), userID, schema, tablename, false) if err != nil { logger.Warn("Failed to load column security: %v", err) // Don't fail the request if no security rules exist @@ -53,7 +45,7 @@ func LoadSecurityRules(hookCtx *restheadspec.HookContext, securityList *Security } // Load row security rules using the provider - _, err = securityList.LoadRowSecurity(hookCtx.Context, userID, schema, tablename, false) + _, err = securityList.LoadRowSecurity(secCtx.GetContext(), userID, schema, tablename, false) if err != nil { logger.Warn("Failed to load row security: %v", err) // Don't fail the request if no security rules exist @@ -63,15 +55,15 @@ func LoadSecurityRules(hookCtx *restheadspec.HookContext, securityList *Security return nil } -// ApplyRowSecurity applies row-level security filters to the query -func ApplyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { - userID, ok := GetUserID(hookCtx.Context) +// applyRowSecurity applies row-level security filters to the query (generic version) +func applyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error { + userID, ok := secCtx.GetUserID() if !ok { return nil // No user context, skip } - schema := hookCtx.Schema - tablename := hookCtx.Entity + schema := secCtx.GetSchema() + tablename := secCtx.GetEntity() // Get row security template rowSec, err := securityList.GetRowSecurityTemplate(userID, schema, tablename) @@ -89,8 +81,14 @@ func ApplyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityL // If there's a security template, apply it as a WHERE clause if rowSec.Template != "" { + model := secCtx.GetModel() + if model == nil { + logger.Debug("No model available for row security on %s.%s", schema, tablename) + return nil + } + // Get primary key name from model - modelType := reflect.TypeOf(hookCtx.Model) + modelType := reflect.TypeOf(model) if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } @@ -117,39 +115,45 @@ func ApplyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityL userID, schema, tablename, whereClause) // Apply the WHERE clause to the query - // The query is in hookCtx.Query - if selectQuery, ok := hookCtx.Query.(interface { + query := secCtx.GetQuery() + if selectQuery, ok := query.(interface { Where(string, ...interface{}) interface{} }); ok { - hookCtx.Query = selectQuery.Where(whereClause) + secCtx.SetQuery(selectQuery.Where(whereClause)) } else { - logger.Error("Unable to apply WHERE clause - query doesn't support Where method") + logger.Debug("Query doesn't support Where method, skipping row security") } } return nil } -// ApplyColumnSecurity applies column-level security (masking/hiding) to results -func ApplyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { - userID, ok := GetUserID(hookCtx.Context) +// applyColumnSecurity applies column-level security (masking/hiding) to results (generic version) +func applyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) error { + userID, ok := secCtx.GetUserID() if !ok { return nil // No user context, skip } - schema := hookCtx.Schema - tablename := hookCtx.Entity + schema := secCtx.GetSchema() + tablename := secCtx.GetEntity() // Get result data - result := hookCtx.Result + result := secCtx.GetResult() if result == nil { return nil } logger.Debug("Applying column security for user=%d, schema=%s, table=%s", userID, schema, tablename) + model := secCtx.GetModel() + if model == nil { + logger.Debug("No model available for column security on %s.%s", schema, tablename) + return nil + } + // Get model type - modelType := reflect.TypeOf(hookCtx.Model) + modelType := reflect.TypeOf(model) if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } @@ -169,37 +173,59 @@ func ApplyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *Securi // Update the result with masked data if maskedResult.IsValid() && maskedResult.CanInterface() { - hookCtx.Result = maskedResult.Interface() + secCtx.SetResult(maskedResult.Interface()) } return nil } -// LogDataAccess logs all data access for audit purposes -func LogDataAccess(hookCtx *restheadspec.HookContext) error { - userID, _ := GetUserID(hookCtx.Context) +// logDataAccess logs all data access for audit purposes (generic version) +func logDataAccess(secCtx SecurityContext) error { + userID, _ := secCtx.GetUserID() - logger.Info("AUDIT: User %d accessed %s.%s with filters: %+v", + logger.Info("AUDIT: User %d accessed %s.%s", userID, - hookCtx.Schema, - hookCtx.Entity, - hookCtx.Options.Filters, + secCtx.GetSchema(), + secCtx.GetEntity(), ) // TODO: Write to audit log table or external audit service // auditLog := AuditLog{ // UserID: userID, - // Schema: hookCtx.Schema, - // Entity: hookCtx.Entity, + // Schema: secCtx.GetSchema(), + // Entity: secCtx.GetEntity(), // Action: "READ", // Timestamp: time.Now(), - // Filters: hookCtx.Options.Filters, // } // db.Create(&auditLog) return nil } +// LogDataAccess is a public wrapper for logDataAccess that accepts a SecurityContext +// This allows other packages to use the audit logging functionality +func LogDataAccess(secCtx SecurityContext) error { + return logDataAccess(secCtx) +} + +// LoadSecurityRules is a public wrapper for loadSecurityRules that accepts a SecurityContext +// This allows other packages to load security rules using the generic interface +func LoadSecurityRules(secCtx SecurityContext, securityList *SecurityList) error { + return loadSecurityRules(secCtx, securityList) +} + +// ApplyRowSecurity is a public wrapper for applyRowSecurity that accepts a SecurityContext +// This allows other packages to apply row-level security using the generic interface +func ApplyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error { + return applyRowSecurity(secCtx, securityList) +} + +// ApplyColumnSecurity is a public wrapper for applyColumnSecurity that accepts a SecurityContext +// This allows other packages to apply column-level security using the generic interface +func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) error { + return applyColumnSecurity(secCtx, securityList) +} + // Helper functions func contains(s, substr string) bool { diff --git a/pkg/security/setup_example.go b/pkg/security/setup_example.go deleted file mode 100644 index a5ded81..0000000 --- a/pkg/security/setup_example.go +++ /dev/null @@ -1,300 +0,0 @@ -package security - -import ( - "context" - "database/sql" - "fmt" - "net/http" - - "github.com/gorilla/mux" - - "github.com/bitechdev/ResolveSpec/pkg/restheadspec" -) - -// SetupSecurityProvider initializes and configures the security provider -// This function creates a SecurityList with the given provider and registers hooks -// -// Example usage: -// -// // Create your security provider (use composite or single provider) -// auth := security.NewJWTAuthenticator("your-secret-key", db) -// colSec := security.NewDatabaseColumnSecurityProvider(db) -// rowSec := security.NewDatabaseRowSecurityProvider(db) -// provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) -// -// // Setup security with the provider -// handler := restheadspec.NewHandlerWithGORM(db) -// securityList := security.SetupSecurityProvider(handler, provider) -// -// // Apply middleware -// router.Use(security.NewAuthMiddleware(securityList)) -// router.Use(security.SetSecurityMiddleware(securityList)) -func SetupSecurityProvider(handler *restheadspec.Handler, provider SecurityProvider) *SecurityList { - if provider == nil { - panic("security provider cannot be nil") - } - - // Create security list with the provider - securityList := NewSecurityList(provider) - - // Register all security hooks - RegisterSecurityHooks(handler, securityList) - - return securityList -} - -// Example 1: Complete Setup with Composite Provider and Database-Backed Security -// =============================================================================== -// Note: Security providers use *sql.DB, but restheadspec.Handler may use *gorm.DB -// You can get *sql.DB from gorm.DB using: sqlDB, _ := gormDB.DB() - -func ExampleDatabaseSecurity(gormDB interface{}, sqlDB *sql.DB) (http.Handler, error) { - // Step 1: Create the ResolveSpec handler - // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) - handler := &restheadspec.Handler{} // Placeholder - use your handler initialization - - // Step 2: Register your models - // handler.RegisterModel("public", "users", User{}) - // handler.RegisterModel("public", "orders", Order{}) - - // Step 3: Create security provider components (using sql.DB) - auth := NewJWTAuthenticator("your-secret-key", sqlDB) - colSec := NewDatabaseColumnSecurityProvider(sqlDB) - rowSec := NewDatabaseRowSecurityProvider(sqlDB) - - // Step 4: Combine into composite provider - provider := NewCompositeSecurityProvider(auth, colSec, rowSec) - - // Step 5: Setup security - securityList := SetupSecurityProvider(handler, provider) - - // Step 6: Create router and setup routes with authentication - router := mux.NewRouter() - authMiddleware := func(h http.Handler) http.Handler { - return NewAuthHandler(securityList, h) - } - restheadspec.SetupMuxRoutes(router, handler, authMiddleware) - - // Step 7: Apply additional security middleware - router.Use(SetSecurityMiddleware(securityList)) - - return router, nil -} - -// Example 2: Simple Header-Based Authentication -// ============================================== - -func ExampleHeaderAuthentication(gormDB interface{}, sqlDB *sql.DB) (*mux.Router, error) { - // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) - handler := &restheadspec.Handler{} // Placeholder - use your handler initialization - - // Use header-based auth with database security providers - auth := NewHeaderAuthenticatorExample() - colSec := NewDatabaseColumnSecurityProvider(sqlDB) - rowSec := NewDatabaseRowSecurityProvider(sqlDB) - - provider := NewCompositeSecurityProvider(auth, colSec, rowSec) - securityList := SetupSecurityProvider(handler, provider) - - router := mux.NewRouter() - authMiddleware := func(h http.Handler) http.Handler { - return NewAuthHandler(securityList, h) - } - restheadspec.SetupMuxRoutes(router, handler, authMiddleware) - - router.Use(SetSecurityMiddleware(securityList)) - - return router, nil -} - -// Example 3: Config-Based Security (No Database for Security) -// =========================================================== - -func ExampleConfigSecurity(gormDB interface{}) (*mux.Router, error) { - // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) - handler := &restheadspec.Handler{} // Placeholder - use your handler initialization - - // Define column security rules in code - columnRules := map[string][]ColumnSecurity{ - "public.employees": { - { - Schema: "public", - Tablename: "employees", - Path: []string{"ssn"}, - Accesstype: "mask", - MaskStart: 5, - MaskChar: "*", - }, - { - Schema: "public", - Tablename: "employees", - Path: []string{"salary"}, - Accesstype: "hide", - }, - }, - } - - // Define row security templates - rowTemplates := map[string]string{ - "public.orders": "user_id = {UserID}", - "public.documents": "user_id = {UserID} OR is_public = true", - } - - // Define blocked tables - blockedTables := map[string]bool{ - "public.admin_logs": true, - } - - // Create providers - auth := NewHeaderAuthenticatorExample() - colSec := NewConfigColumnSecurityProvider(columnRules) - rowSec := NewConfigRowSecurityProvider(rowTemplates, blockedTables) - - provider := NewCompositeSecurityProvider(auth, colSec, rowSec) - securityList := SetupSecurityProvider(handler, provider) - - router := mux.NewRouter() - authMiddleware := func(h http.Handler) http.Handler { - return NewAuthHandler(securityList, h) - } - restheadspec.SetupMuxRoutes(router, handler, authMiddleware) - - router.Use(SetSecurityMiddleware(securityList)) - - return router, nil -} - -// Example 4: Custom Security Provider -// ==================================== - -// You can implement your own SecurityProvider by implementing all three interfaces -type CustomSecurityProvider struct { - // Your custom fields -} - -func (p *CustomSecurityProvider) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { - // Your custom login logic - return nil, fmt.Errorf("not implemented") -} - -func (p *CustomSecurityProvider) Logout(ctx context.Context, req LogoutRequest) error { - // Your custom logout logic - return nil -} - -func (p *CustomSecurityProvider) Authenticate(r *http.Request) (*UserContext, error) { - // Your custom authentication logic - return nil, fmt.Errorf("not implemented") -} - -func (p *CustomSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { - // Your custom column security logic - return []ColumnSecurity{}, nil -} - -func (p *CustomSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { - // Your custom row security logic - return RowSecurity{ - Schema: schema, - Tablename: table, - UserID: userID, - }, nil -} - -// Example 5: Adding Login/Logout Endpoints -// ========================================= - -func SetupAuthRoutes(router *mux.Router, securityList *SecurityList) { - // Login endpoint - router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) { - // Parse login request - var loginReq LoginRequest - // json.NewDecoder(r.Body).Decode(&loginReq) - - // Call provider's Login method - resp, err := securityList.Provider().Login(r.Context(), loginReq) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - - // Return token - w.Header().Set("Content-Type", "application/json") - // json.NewEncoder(w).Encode(resp) - fmt.Fprintf(w, `{"token": "%s", "expires_in": %d}`, resp.Token, resp.ExpiresIn) - }).Methods("POST") - - // Logout endpoint - router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) { - // Extract token from header - token := r.Header.Get("Authorization") - - // Get user ID from context (if authenticated) - userID, _ := GetUserID(r.Context()) - - // Call provider's Logout method - err := securityList.Provider().Logout(r.Context(), LogoutRequest{ - Token: token, - UserID: userID, - }) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, `{"success": true}`) - }).Methods("POST") - - // Optional: Token refresh endpoint - router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { - refreshToken := r.Header.Get("X-Refresh-Token") - - // Check if provider supports refresh - if refreshable, ok := securityList.Provider().(Refreshable); ok { - resp, err := refreshable.RefreshToken(r.Context(), refreshToken) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"token": "%s", "expires_in": %d}`, resp.Token, resp.ExpiresIn) - } else { - http.Error(w, "Token refresh not supported", http.StatusNotImplemented) - } - }).Methods("POST") -} - -// Example 6: Complete Server Setup -// ================================= - -func CompleteServerExample(gormDB interface{}, sqlDB *sql.DB) http.Handler { - // Create handler and register models - // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) - handler := &restheadspec.Handler{} // Placeholder - use your handler initialization - // handler.RegisterModel("public", "users", User{}) - - // Setup security (using sql.DB for security providers) - auth := NewJWTAuthenticator("secret-key", sqlDB) - colSec := NewDatabaseColumnSecurityProvider(sqlDB) - rowSec := NewDatabaseRowSecurityProvider(sqlDB) - provider := NewCompositeSecurityProvider(auth, colSec, rowSec) - securityList := SetupSecurityProvider(handler, provider) - - // Create router - router := mux.NewRouter() - - // Add auth routes (login/logout) - SetupAuthRoutes(router, securityList) - - // Add API routes with authentication - apiRouter := router.PathPrefix("/api").Subrouter() - authMiddleware := func(h http.Handler) http.Handler { - return NewAuthHandler(securityList, h) - } - restheadspec.SetupMuxRoutes(apiRouter, handler, authMiddleware) - apiRouter.Use(SetSecurityMiddleware(securityList)) - - return router -}