diff --git a/pkg/security/CALLBACKS_GUIDE.md b/pkg/security/CALLBACKS_GUIDE.md deleted file mode 100644 index 2170277..0000000 --- a/pkg/security/CALLBACKS_GUIDE.md +++ /dev/null @@ -1,662 +0,0 @@ -# Security Provider Callbacks Guide - -## Overview - -The ResolveSpec security provider uses a **callback-based architecture** that requires you to implement three functions: - -1. **AuthenticateCallback** - Extract user credentials from HTTP requests -2. **LoadColumnSecurityCallback** - Load column security rules for masking/hiding -3. **LoadRowSecurityCallback** - Load row security filters (WHERE clauses) - -This design allows you to integrate the security provider with **any** authentication system and database schema. - ---- - -## Why Callbacks? - -The callback-based design provides: - -✅ **Flexibility** - Works with any auth system (JWT, session, OAuth, custom) -✅ **Database Agnostic** - No assumptions about your security table schema -✅ **Testability** - Easy to mock for unit tests -✅ **Extensibility** - Add custom logic without modifying core code - ---- - -## Quick Start - -### Step 1: Implement the Three Callbacks - -```go -package main - -import ( - "fmt" - "net/http" - "github.com/bitechdev/ResolveSpec/pkg/security" -) - -// 1. Authentication: Extract user from request -func myAuthFunction(r *http.Request) (userID int, roles string, err error) { - // Your auth logic here (JWT, session, header, etc.) - token := r.Header.Get("Authorization") - userID, roles, err = validateToken(token) - return userID, roles, err -} - -// 2. Column Security: Load column masking rules -func myLoadColumnSecurity(userID int, schema, tablename string) ([]security.ColumnSecurity, error) { - // Your database query or config lookup here - return loadColumnRulesFromDatabase(userID, schema, tablename) -} - -// 3. Row Security: Load row filtering rules -func myLoadRowSecurity(userID int, schema, tablename string) (security.RowSecurity, error) { - // Your database query or config lookup here - return loadRowRulesFromDatabase(userID, schema, tablename) -} -``` - -### Step 2: Configure the Callbacks - -```go -func main() { - db := setupDatabase() - handler := restheadspec.NewHandlerWithGORM(db) - - // Configure callbacks BEFORE SetupSecurityProvider - security.GlobalSecurity.AuthenticateCallback = myAuthFunction - security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurity - security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurity - - // Setup security provider (validates callbacks are set) - if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil { - log.Fatal(err) // Fails if callbacks not configured - } - - // Apply middleware - router := mux.NewRouter() - restheadspec.SetupMuxRoutes(router, handler) - router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) - router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) - - http.ListenAndServe(":8080", router) -} -``` - ---- - -## Callback 1: AuthenticateCallback - -### Function Signature - -```go -func(r *http.Request) (userID int, roles string, err error) -``` - -### Parameters -- `r *http.Request` - The incoming HTTP request - -### Returns -- `userID int` - The authenticated user's ID -- `roles string` - User's roles (comma-separated, e.g., "admin,manager") -- `err error` - Return error to reject the request (HTTP 401) - -### Example Implementations - -#### Simple Header-Based Auth -```go -func authenticateFromHeader(r *http.Request) (int, string, error) { - userIDStr := r.Header.Get("X-User-ID") - if userIDStr == "" { - return 0, "", fmt.Errorf("X-User-ID header required") - } - - userID, err := strconv.Atoi(userIDStr) - if err != nil { - return 0, "", fmt.Errorf("invalid user ID") - } - - roles := r.Header.Get("X-User-Roles") // Optional - return userID, roles, nil -} -``` - -#### JWT Token Auth -```go -import "github.com/golang-jwt/jwt/v5" - -func authenticateFromJWT(r *http.Request) (int, string, error) { - authHeader := r.Header.Get("Authorization") - tokenString := strings.TrimPrefix(authHeader, "Bearer ") - - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte(os.Getenv("JWT_SECRET")), nil - }) - - if err != nil || !token.Valid { - return 0, "", fmt.Errorf("invalid token") - } - - claims := token.Claims.(jwt.MapClaims) - userID := int(claims["user_id"].(float64)) - roles := claims["roles"].(string) - - return userID, roles, nil -} -``` - -#### Session Cookie Auth -```go -func authenticateFromSession(r *http.Request) (int, string, error) { - cookie, err := r.Cookie("session_id") - if err != nil { - return 0, "", fmt.Errorf("no session cookie") - } - - session, err := sessionStore.Get(cookie.Value) - if err != nil { - return 0, "", fmt.Errorf("invalid session") - } - - return session.UserID, session.Roles, nil -} -``` - ---- - -## Callback 2: LoadColumnSecurityCallback - -### Function Signature - -```go -func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) -``` - -### Parameters -- `pUserID int` - The authenticated user's ID -- `pSchema string` - Database schema (e.g., "public") -- `pTablename string` - Table name (e.g., "employees") - -### Returns -- `[]ColumnSecurity` - List of column security rules -- `error` - Return error if loading fails - -### ColumnSecurity Structure - -```go -type ColumnSecurity struct { - Schema string // "public" - Tablename string // "employees" - Path []string // ["ssn"] or ["address", "street"] - Accesstype string // "mask" or "hide" - - // Masking configuration (for Accesstype = "mask") - MaskStart int // Mask first N characters - MaskEnd int // Mask last N characters - MaskInvert bool // true = mask middle, false = mask edges - MaskChar string // Character to use for masking (default "*") - - // Optional fields - ExtraFilters map[string]string - Control string - ID int - UserID int -} -``` - -### Example Implementations - -#### Load from Database -```go -func loadColumnSecurityFromDB(userID int, schema, tablename string) ([]security.ColumnSecurity, error) { - var rules []security.ColumnSecurity - - query := ` - SELECT control, accesstype, jsonvalue - FROM core.secacces - WHERE rid_hub IN ( - SELECT rid_hub_parent FROM core.hub_link - WHERE rid_hub_child = ? AND parent_hubtype = 'secgroup' - ) - AND control ILIKE ? - ` - - rows, err := db.Query(query, userID, fmt.Sprintf("%s.%s%%", schema, tablename)) - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var control, accesstype, jsonValue string - rows.Scan(&control, &accesstype, &jsonValue) - - // Parse control: "schema.table.column" - parts := strings.Split(control, ".") - if len(parts) < 3 { - continue - } - - rule := security.ColumnSecurity{ - Schema: schema, - Tablename: tablename, - Path: parts[2:], - Accesstype: accesstype, - } - - // Parse JSON configuration - var config map[string]interface{} - json.Unmarshal([]byte(jsonValue), &config) - if start, ok := config["start"].(float64); ok { - rule.MaskStart = int(start) - } - if end, ok := config["end"].(float64); ok { - rule.MaskEnd = int(end) - } - if char, ok := config["char"].(string); ok { - rule.MaskChar = char - } - - rules = append(rules, rule) - } - - return rules, nil -} -``` - -#### Load from Static Config -```go -func loadColumnSecurityFromConfig(userID int, schema, tablename string) ([]security.ColumnSecurity, error) { - // Define security rules in code - allRules := map[string][]security.ColumnSecurity{ - "public.employees": { - { - Schema: "public", - Tablename: "employees", - Path: []string{"ssn"}, - Accesstype: "mask", - MaskStart: 5, - MaskChar: "*", - }, - { - Schema: "public", - Tablename: "employees", - Path: []string{"salary"}, - Accesstype: "hide", - }, - }, - } - - key := fmt.Sprintf("%s.%s", schema, tablename) - rules, ok := allRules[key] - if !ok { - return []security.ColumnSecurity{}, nil // No rules - } - - return rules, nil -} -``` - -### Column Security Examples - -**Mask SSN (show last 4 digits):** -```go -ColumnSecurity{ - Path: []string{"ssn"}, - Accesstype: "mask", - MaskStart: 5, // Mask first 5 characters - MaskEnd: 0, // Keep last 4 visible - MaskChar: "*", -} -// Result: "123-45-6789" → "*****6789" -``` - -**Hide entire field:** -```go -ColumnSecurity{ - Path: []string{"salary"}, - Accesstype: "hide", -} -// Result: salary field returns 0 or empty -``` - -**Mask credit card (show last 4 digits):** -```go -ColumnSecurity{ - Path: []string{"credit_card"}, - Accesstype: "mask", - MaskStart: 12, - MaskChar: "*", -} -// Result: "1234-5678-9012-3456" → "************3456" -``` - ---- - -## Callback 3: LoadRowSecurityCallback - -### Function Signature - -```go -func(pUserID int, pSchema, pTablename string) (RowSecurity, error) -``` - -### Parameters -- `pUserID int` - The authenticated user's ID -- `pSchema string` - Database schema -- `pTablename string` - Table name - -### Returns -- `RowSecurity` - Row security configuration -- `error` - Return error if loading fails - -### RowSecurity Structure - -```go -type RowSecurity struct { - Schema string // "public" - Tablename string // "orders" - UserID int // Current user ID - Template string // WHERE clause template (e.g., "user_id = {UserID}") - HasBlock bool // If true, block ALL access to this table -} -``` - -### Template Variables - -You can use these placeholders in the `Template` string: -- `{UserID}` - Current user's ID -- `{PrimaryKeyName}` - Primary key column name -- `{TableName}` - Table name -- `{SchemaName}` - Schema name - -### Example Implementations - -#### Load from Database Function -```go -func loadRowSecurityFromDB(userID int, schema, tablename string) (security.RowSecurity, error) { - var record security.RowSecurity - - query := ` - SELECT p_template, p_block - FROM core.api_sec_rowtemplate(?, ?, ?) - ` - - row := db.QueryRow(query, schema, tablename, userID) - err := row.Scan(&record.Template, &record.HasBlock) - if err != nil { - return security.RowSecurity{}, err - } - - record.Schema = schema - record.Tablename = tablename - record.UserID = userID - - return record, nil -} -``` - -#### Load from Static Config -```go -func loadRowSecurityFromConfig(userID int, schema, tablename string) (security.RowSecurity, error) { - key := fmt.Sprintf("%s.%s", schema, tablename) - - // Define templates for each table - templates := map[string]string{ - "public.orders": "user_id = {UserID}", - "public.documents": "user_id = {UserID} OR is_public = true", - } - - // Define blocked tables - blocked := map[string]bool{ - "public.admin_logs": true, - } - - if blocked[key] { - return security.RowSecurity{ - Schema: schema, - Tablename: tablename, - UserID: userID, - HasBlock: true, - }, nil - } - - template, ok := templates[key] - if !ok { - // No row security - allow all rows - return security.RowSecurity{ - Schema: schema, - Tablename: tablename, - UserID: userID, - Template: "", - HasBlock: false, - }, nil - } - - return security.RowSecurity{ - Schema: schema, - Tablename: tablename, - UserID: userID, - Template: template, - HasBlock: false, - }, nil -} -``` - -### Row Security Examples - -**Users see only their own records:** -```go -RowSecurity{ - Template: "user_id = {UserID}", -} -// Query: SELECT * FROM orders WHERE user_id = 123 -``` - -**Users see their records OR public records:** -```go -RowSecurity{ - Template: "user_id = {UserID} OR is_public = true", -} -``` - -**Complex filter with subquery:** -```go -RowSecurity{ - Template: "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})", -} -``` - -**Block all access:** -```go -RowSecurity{ - HasBlock: true, -} -// All queries to this table will be rejected -``` - ---- - -## Complete Integration Example - -```go -package main - -import ( - "fmt" - "log" - "net/http" - "strconv" - - "github.com/bitechdev/ResolveSpec/pkg/restheadspec" - "github.com/bitechdev/ResolveSpec/pkg/security" - "github.com/gorilla/mux" - "gorm.io/gorm" -) - -func main() { - db := setupDatabase() - handler := restheadspec.NewHandlerWithGORM(db) - handler.RegisterModel("public", "orders", Order{}) - - // ===== CONFIGURE CALLBACKS ===== - security.GlobalSecurity.AuthenticateCallback = authenticateUser - security.GlobalSecurity.LoadColumnSecurityCallback = loadColumnSec - security.GlobalSecurity.LoadRowSecurityCallback = loadRowSec - - // ===== SETUP SECURITY ===== - if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil { - log.Fatal("Security setup failed:", err) - } - - // ===== SETUP ROUTES ===== - router := mux.NewRouter() - restheadspec.SetupMuxRoutes(router, handler) - router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) - router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) - - log.Println("Server starting on :8080") - http.ListenAndServe(":8080", router) -} - -// Callback implementations -func authenticateUser(r *http.Request) (int, string, error) { - userIDStr := r.Header.Get("X-User-ID") - if userIDStr == "" { - return 0, "", fmt.Errorf("authentication required") - } - userID, err := strconv.Atoi(userIDStr) - return userID, "", err -} - -func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { - // Your implementation here - return []security.ColumnSecurity{}, nil -} - -func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { - return security.RowSecurity{ - Schema: schema, - Tablename: table, - UserID: userID, - Template: "user_id = " + strconv.Itoa(userID), - }, nil -} -``` - ---- - -## Testing Your Callbacks - -### Unit Test Example - -```go -func TestAuthCallback(t *testing.T) { - req := httptest.NewRequest("GET", "/api/orders", nil) - req.Header.Set("X-User-ID", "123") - - userID, roles, err := myAuthFunction(req) - - assert.Nil(t, err) - assert.Equal(t, 123, userID) -} - -func TestColumnSecurityCallback(t *testing.T) { - rules, err := myLoadColumnSecurity(123, "public", "employees") - - assert.Nil(t, err) - assert.Greater(t, len(rules), 0) - assert.Equal(t, "mask", rules[0].Accesstype) -} -``` - ---- - -## Common Patterns - -### Pattern 1: Role-Based Security - -```go -func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { - roles := getUserRoles(userID) - - if contains(roles, "admin") { - // Admins see everything - return []security.ColumnSecurity{}, nil - } - - // Non-admins have restrictions - return []security.ColumnSecurity{ - {Path: []string{"ssn"}, Accesstype: "mask"}, - }, nil -} -``` - -### Pattern 2: Tenant Isolation - -```go -func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { - tenantID := getUserTenant(userID) - - return security.RowSecurity{ - Template: fmt.Sprintf("tenant_id = %d", tenantID), - }, nil -} -``` - -### Pattern 3: Caching Security Rules - -```go -var securityCache = cache.New(5*time.Minute, 10*time.Minute) - -func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { - cacheKey := fmt.Sprintf("%d:%s.%s", userID, schema, table) - - if cached, found := securityCache.Get(cacheKey); found { - return cached.([]security.ColumnSecurity), nil - } - - rules := loadFromDatabase(userID, schema, table) - securityCache.Set(cacheKey, rules, cache.DefaultExpiration) - - return rules, nil -} -``` - ---- - -## Troubleshooting - -### Error: "AuthenticateCallback not set" -**Solution:** Configure all three callbacks before calling `SetupSecurityProvider`: -```go -security.GlobalSecurity.AuthenticateCallback = myAuthFunc -security.GlobalSecurity.LoadColumnSecurityCallback = myColSecFunc -security.GlobalSecurity.LoadRowSecurityCallback = myRowSecFunc -``` - -### Error: "Authentication failed" -**Solution:** Check your `AuthenticateCallback` implementation. Ensure it returns valid user ID or proper error. - -### Security rules not applying -**Solution:** -1. Check callbacks are returning data -2. Enable debug logging -3. Verify database queries return results -4. Check user has security groups assigned - ---- - -## Next Steps - -1. ✅ Implement the three callbacks for your system -2. ✅ Configure `GlobalSecurity` with your callbacks -3. ✅ Call `SetupSecurityProvider` -4. ✅ Test with different users and verify isolation -5. ✅ Review `callbacks_example.go` for more examples - -For complete working examples, see: -- `pkg/security/callbacks_example.go` - 7 example implementations -- `examples/secure_server/main.go` - Full server example -- `pkg/security/README.md` - Comprehensive documentation diff --git a/pkg/security/QUICK_REFERENCE.md b/pkg/security/QUICK_REFERENCE.md index 9530d85..b1c6d4f 100644 --- a/pkg/security/QUICK_REFERENCE.md +++ b/pkg/security/QUICK_REFERENCE.md @@ -3,35 +3,96 @@ ## 3-Step Setup ```go -// Step 1: Implement callbacks -func myAuth(r *http.Request) (int, string, error) { /* ... */ } -func myColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { /* ... */ } -func myRowSec(userID int, schema, table string) (security.RowSecurity, error) { /* ... */ } +// Step 1: Create security providers +auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended) +// OR: auth := security.NewJWTAuthenticator("secret-key", db) +// OR: auth := security.NewHeaderAuthenticator() -// Step 2: Configure callbacks -security.GlobalSecurity.AuthenticateCallback = myAuth -security.GlobalSecurity.LoadColumnSecurityCallback = myColSec -security.GlobalSecurity.LoadRowSecurityCallback = myRowSec +colSec := security.NewDatabaseColumnSecurityProvider(db) +rowSec := security.NewDatabaseRowSecurityProvider(db) + +// Step 2: Combine providers +provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) // Step 3: Setup and apply middleware -security.SetupSecurityProvider(handler, &security.GlobalSecurity) -router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) -router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) +securityList := security.SetupSecurityProvider(handler, provider) +router.Use(security.NewAuthMiddleware(securityList)) +router.Use(security.SetSecurityMiddleware(securityList)) ``` --- -## Callback Signatures +## Stored Procedures + +**All database operations use PostgreSQL stored procedures** with `resolvespec_*` naming: + +### Database Authenticators +```go +// DatabaseAuthenticator uses these stored procedures: +resolvespec_login(jsonb) // Login with credentials +resolvespec_logout(jsonb) // Invalidate session +resolvespec_session(text, text) // Validate session token +resolvespec_session_update(text, jsonb) // Update activity timestamp +resolvespec_refresh_token(text, jsonb) // Generate new session + +// JWTAuthenticator uses these stored procedures: +resolvespec_jwt_login(text, text) // Validate credentials +resolvespec_jwt_logout(text, int) // Blacklist token +``` + +### Security Providers +```go +// DatabaseColumnSecurityProvider: +resolvespec_column_security(int, text, text) // Load column rules + +// DatabaseRowSecurityProvider: +resolvespec_row_security(text, text, int) // Load row template +``` + +All stored procedures return structured results: +- Session/Login: `(p_success bool, p_error text, p_data jsonb)` +- Security: `(p_success bool, p_error text, p_rules jsonb)` + +See `database_schema.sql` for complete definitions. + +--- + +## Interface Signatures ```go -// 1. Authentication -func(r *http.Request) (userID int, roles string, err error) +// Authenticator interface +type Authenticator interface { + Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) + Logout(ctx context.Context, req LogoutRequest) error + Authenticate(r *http.Request) (*UserContext, error) +} -// 2. Column Security -func(userID int, schema, tablename string) ([]ColumnSecurity, error) +// ColumnSecurityProvider interface +type ColumnSecurityProvider interface { + GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) +} -// 3. Row Security -func(userID int, schema, tablename string) (RowSecurity, error) +// RowSecurityProvider interface +type RowSecurityProvider interface { + GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) +} +``` + +--- + +## UserContext Structure + +```go +security.UserContext{ + UserID: 123, // User's unique ID + UserName: "john_doe", // Username + UserLevel: 5, // User privilege level + SessionID: "sess_abc123", // Current session ID + RemoteID: "remote_xyz", // Remote system ID + Roles: []string{"admin"}, // User roles + Email: "john@example.com", // User email + Claims: map[string]any{}, // Additional metadata +} ``` --- @@ -109,70 +170,204 @@ HasBlock: true ## Example Implementations -### Simple Header Auth +### Database Session Authenticator (Recommended) ```go -func authFromHeader(r *http.Request) (int, string, error) { +// Create authenticator +auth := security.NewDatabaseAuthenticator(db) + +// Requires these tables: +// - users (id, username, email, password, user_level, roles, is_active) +// - user_sessions (session_token, user_id, expires_at, created_at, last_activity_at) +// See database_schema.sql for full schema + +// Features: +// - Login with username/password +// - Session management in database +// - Token refresh support (implements Refreshable) +// - Automatic session expiration +// - Tracks IP address and user agent +// - Works with Authorization header or cookie +``` + +### Simple Header Authenticator + +```go +type HeaderAuthenticator struct{} + +func NewHeaderAuthenticator() *HeaderAuthenticator { + return &HeaderAuthenticator{} +} + +func (a *HeaderAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) { + return nil, fmt.Errorf("not supported") +} + +func (a *HeaderAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error { + return nil +} + +func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { userIDStr := r.Header.Get("X-User-ID") if userIDStr == "" { - return 0, "", fmt.Errorf("X-User-ID required") + return nil, fmt.Errorf("X-User-ID required") } - userID, err := strconv.Atoi(userIDStr) - return userID, "", err + userID, _ := strconv.Atoi(userIDStr) + return &security.UserContext{ + UserID: userID, + UserName: r.Header.Get("X-User-Name"), + }, nil } ``` -### JWT Auth +### JWT Authenticator ```go -func authFromJWT(r *http.Request) (int, string, error) { - token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - claims, err := jwt.Parse(token, secret) +type JWTAuthenticator struct { + secretKey []byte + db *gorm.DB +} + +func NewJWTAuthenticator(secret string, db *gorm.DB) *JWTAuthenticator { + return &JWTAuthenticator{secretKey: []byte(secret), db: db} +} + +func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) { + // Validate credentials against database + var user User + err := a.db.WithContext(ctx).Where("username = ?", req.Username).First(&user).Error if err != nil { - return 0, "", err + return nil, fmt.Errorf("invalid credentials") } - return claims.UserID, claims.Roles, nil + + // Generate JWT token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": user.ID, + "exp": time.Now().Add(24 * time.Hour).Unix(), + }) + tokenString, _ := token.SignedString(a.secretKey) + + return &security.LoginResponse{ + Token: tokenString, + User: &security.UserContext{UserID: user.ID}, + ExpiresIn: 86400, + }, nil +} + +func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error { + // Add to blacklist + return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{ + "token": req.Token, + "user_id": req.UserID, + }).Error +} + +func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { + tokenString := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) { + return a.secretKey, nil + }) + if err != nil || !token.Valid { + return nil, fmt.Errorf("invalid token") + } + claims := token.Claims.(jwt.MapClaims) + return &security.UserContext{ + UserID: int(claims["user_id"].(float64)), + }, nil } ``` ### Static Column Security ```go -func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { - if table == "employees" { - return []security.ColumnSecurity{ - {Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}, - {Path: []string{"salary"}, Accesstype: "hide"}, - }, nil - } - return []security.ColumnSecurity{}, nil +type ConfigColumnSecurityProvider struct { + rules map[string][]security.ColumnSecurity +} + +func NewConfigColumnSecurityProvider(rules map[string][]security.ColumnSecurity) *ConfigColumnSecurityProvider { + return &ConfigColumnSecurityProvider{rules: rules} +} + +func (p *ConfigColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { + key := fmt.Sprintf("%s.%s", schema, table) + return p.rules[key], nil } ``` ### Database Column Security ```go -func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { - rows, err := db.Query(` +type DatabaseColumnSecurityProvider struct { + db *gorm.DB +} + +func NewDatabaseColumnSecurityProvider(db *gorm.DB) *DatabaseColumnSecurityProvider { + return &DatabaseColumnSecurityProvider{db: db} +} + +func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { + var records []struct { + Control string + Accesstype string + JSONValue string + } + + query := ` SELECT control, accesstype, jsonvalue - FROM core.secacces - WHERE rid_hub IN (...) + FROM core.secaccess + WHERE rid_hub IN ( + SELECT rid_hub_parent FROM core.hub_link + WHERE rid_hub_child = ? AND parent_hubtype = 'secgroup' + ) AND control ILIKE ? - `, fmt.Sprintf("%s.%s%%", schema, table)) - // ... parse and return + ` + + err := p.db.WithContext(ctx).Raw(query, userID, fmt.Sprintf("%s.%s%%", schema, table)).Scan(&records).Error + if err != nil { + return nil, err + } + + var rules []security.ColumnSecurity + for _, rec := range records { + parts := strings.Split(rec.Control, ".") + if len(parts) < 3 { + continue + } + rules = append(rules, security.ColumnSecurity{ + Schema: schema, + Tablename: table, + Path: parts[2:], + Accesstype: rec.Accesstype, + }) + } + return rules, nil } ``` ### Static Row Security ```go -func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { - templates := map[string]string{ - "orders": "user_id = {UserID}", - "documents": "user_id = {UserID} OR is_public = true", +type ConfigRowSecurityProvider struct { + templates map[string]string + blocked map[string]bool +} + +func NewConfigRowSecurityProvider(templates map[string]string, blocked map[string]bool) *ConfigRowSecurityProvider { + return &ConfigRowSecurityProvider{templates: templates, blocked: blocked} +} + +func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) { + key := fmt.Sprintf("%s.%s", schema, table) + + if p.blocked[key] { + return security.RowSecurity{HasBlock: true}, nil } + return security.RowSecurity{ - Template: templates[table], + Schema: schema, + Tablename: table, + UserID: userID, + Template: p.templates[key], }, nil } ``` @@ -182,19 +377,22 @@ func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) ## Testing ```go -// Test auth callback +// Test Authenticator +auth := security.NewHeaderAuthenticator() req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-User-ID", "123") -userID, roles, err := myAuth(req) -assert.Equal(t, 123, userID) +userCtx, err := auth.Authenticate(req) +assert.Equal(t, 123, userCtx.UserID) -// Test column security callback -rules, err := myColSec(123, "public", "employees") -assert.Equal(t, "mask", rules[0].Accesstype) +// Test ColumnSecurityProvider +colSec := security.NewConfigColumnSecurityProvider(rules) +cols, err := colSec.GetColumnSecurity(context.Background(), 123, "public", "employees") +assert.Equal(t, "mask", cols[0].Accesstype) -// Test row security callback -rowSec, err := myRowSec(123, "public", "orders") -assert.Equal(t, "user_id = {UserID}", rowSec.Template) +// Test RowSecurityProvider +rowSec := security.NewConfigRowSecurityProvider(templates, blocked) +row, err := rowSec.GetRowSecurity(context.Background(), 123, "public", "orders") +assert.Equal(t, "user_id = {UserID}", row.Template) ``` --- @@ -204,13 +402,13 @@ assert.Equal(t, "user_id = {UserID}", rowSec.Template) ``` HTTP Request ↓ -AuthMiddleware → calls AuthenticateCallback - ↓ (adds userID to context) -SetSecurityMiddleware → adds GlobalSecurity to context +NewAuthMiddleware → calls provider.Authenticate() + ↓ (adds UserContext to context) +SetSecurityMiddleware → adds SecurityList to context ↓ Handler.Handle() ↓ -BeforeRead Hook → calls LoadColumnSecurityCallback + LoadRowSecurityCallback +BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity() ↓ BeforeScan Hook → applies row security (WHERE clause) ↓ @@ -228,10 +426,13 @@ HTTP Response ### Role-Based Security ```go -func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { - if isAdmin(userID) { +func (p *MyColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { + userCtx, _ := security.GetUserContext(ctx) + + if contains(userCtx.Roles, "admin") { return []security.ColumnSecurity{}, nil // No restrictions } + return loadRestrictions(userID, schema, table), nil } ``` @@ -239,7 +440,7 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er ### Tenant Isolation ```go -func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { +func (p *MyRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) { tenantID := getUserTenant(userID) return security.RowSecurity{ Template: fmt.Sprintf("tenant_id = %d", tenantID), @@ -247,19 +448,26 @@ func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) } ``` -### Caching +### Caching with Decorator ```go -var cache = make(map[string][]security.ColumnSecurity) +type CachedColumnSecurityProvider struct { + inner security.ColumnSecurityProvider + cache *cache.Cache +} -func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { +func (p *CachedColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { key := fmt.Sprintf("%d:%s.%s", userID, schema, table) - if cached, ok := cache[key]; ok { - return cached, nil + + if cached, found := p.cache.Get(key); found { + return cached.([]security.ColumnSecurity), nil } - rules := loadFromDB(userID, schema, table) - cache[key] = rules - return rules, nil + + rules, err := p.inner.GetColumnSecurity(ctx, userID, schema, table) + if err == nil { + p.cache.Set(key, rules, cache.DefaultExpiration) + } + return rules, err } ``` @@ -268,21 +476,20 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er ## Error Handling ```go -// Setup will fail if callbacks not configured -if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil { - log.Fatal("Security setup failed:", err) -} +// Panic if provider is nil +provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) +// panics if any parameter is nil -// Auth middleware rejects if callback returns error -func myAuth(r *http.Request) (int, string, error) { +// Auth middleware returns 401 if Authenticate fails +func (a *MyAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { if invalid { - return 0, "", fmt.Errorf("invalid credentials") // Returns HTTP 401 + return nil, fmt.Errorf("invalid credentials") // Returns HTTP 401 } - return userID, roles, nil + return &security.UserContext{UserID: userID}, nil } // Security loading can fail gracefully -func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { +func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { rules, err := db.Load(...) if err != nil { log.Printf("Failed to load security: %v", err) @@ -294,6 +501,45 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er --- +## Login/Logout Endpoints + +```go +func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) { + // Login + router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) { + var req security.LoginRequest + json.NewDecoder(r.Body).Decode(&req) + + resp, err := securityList.Provider().Login(r.Context(), req) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + json.NewEncoder(w).Encode(resp) + }).Methods("POST") + + // Logout + router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + userID, _ := security.GetUserID(r.Context()) + + err := securityList.Provider().Logout(r.Context(), security.LogoutRequest{ + Token: token, + UserID: userID, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + }).Methods("POST") +} +``` + +--- + ## Debugging ```go @@ -301,15 +547,15 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er import "github.com/bitechdev/GoCore/pkg/cfg" cfg.SetLogLevel("DEBUG") -// Log in callbacks -func myAuth(r *http.Request) (int, string, error) { +// Log in provider methods +func (a *MyAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { token := r.Header.Get("Authorization") log.Printf("Auth: token=%s", token) // ... } -// Check if callbacks are called -func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { +// Check if methods are called +func (p *MyColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { log.Printf("Loading column security: user=%d, schema=%s, table=%s", userID, schema, table) // ... } @@ -323,6 +569,7 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er package main import ( + "context" "fmt" "net/http" "strconv" @@ -331,29 +578,42 @@ import ( "github.com/gorilla/mux" ) +// Simple all-in-one provider +type SimpleProvider struct{} + +func (p *SimpleProvider) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) { + return nil, fmt.Errorf("not implemented") +} + +func (p *SimpleProvider) Logout(ctx context.Context, req security.LogoutRequest) error { + return nil +} + +func (p *SimpleProvider) Authenticate(r *http.Request) (*security.UserContext, error) { + id, _ := strconv.Atoi(r.Header.Get("X-User-ID")) + return &security.UserContext{UserID: id}, nil +} + +func (p *SimpleProvider) GetColumnSecurity(ctx context.Context, u int, s, t string) ([]security.ColumnSecurity, error) { + return []security.ColumnSecurity{}, nil +} + +func (p *SimpleProvider) GetRowSecurity(ctx context.Context, u int, s, t string) (security.RowSecurity, error) { + return security.RowSecurity{Template: fmt.Sprintf("user_id = %d", u)}, nil +} + func main() { handler := restheadspec.NewHandlerWithGORM(db) - // Configure callbacks - security.GlobalSecurity.AuthenticateCallback = func(r *http.Request) (int, string, error) { - id, _ := strconv.Atoi(r.Header.Get("X-User-ID")) - return id, "", nil - } - security.GlobalSecurity.LoadColumnSecurityCallback = func(u int, s, t string) ([]security.ColumnSecurity, error) { - return []security.ColumnSecurity{}, nil - } - security.GlobalSecurity.LoadRowSecurityCallback = func(u int, s, t string) (security.RowSecurity, error) { - return security.RowSecurity{Template: fmt.Sprintf("user_id = %d", u)}, nil - } + // Setup security + provider := &SimpleProvider{} + securityList := security.SetupSecurityProvider(handler, provider) - // Setup - security.SetupSecurityProvider(handler, &security.GlobalSecurity) - - // Middleware + // Apply middleware router := mux.NewRouter() restheadspec.SetupMuxRoutes(router, handler) - router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) - router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) + router.Use(security.NewAuthMiddleware(securityList)) + router.Use(security.SetSecurityMiddleware(securityList)) http.ListenAndServe(":8080", router) } @@ -361,15 +621,32 @@ func main() { --- +## Context Helpers + +```go +// Get full user context +userCtx, ok := security.GetUserContext(ctx) + +// Get individual fields +userID, ok := security.GetUserID(ctx) +userName, ok := security.GetUserName(ctx) +userLevel, ok := security.GetUserLevel(ctx) +sessionID, ok := security.GetSessionID(ctx) +remoteID, ok := security.GetRemoteID(ctx) +roles, ok := security.GetUserRoles(ctx) +email, ok := security.GetUserEmail(ctx) +``` + +--- + ## Resources | File | Description | |------|-------------| -| `CALLBACKS_GUIDE.md` | **Start here** - Complete implementation guide | -| `callbacks_example.go` | 7 working examples to copy | -| `CALLBACKS_SUMMARY.md` | Architecture overview | -| `README.md` | Full documentation | -| `setup_example.go` | Integration examples | +| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide | +| `examples.go` | Working provider implementations to copy | +| `setup_example.go` | 6 complete integration examples | +| `README.md` | Architecture overview and migration guide | --- @@ -377,22 +654,22 @@ func main() { ```go // ===== REQUIRED SETUP ===== -security.GlobalSecurity.AuthenticateCallback = myAuthFunc -security.GlobalSecurity.LoadColumnSecurityCallback = myColFunc -security.GlobalSecurity.LoadRowSecurityCallback = myRowFunc -security.SetupSecurityProvider(handler, &security.GlobalSecurity) +auth := security.NewJWTAuthenticator("secret", db) +colSec := security.NewDatabaseColumnSecurityProvider(db) +rowSec := security.NewDatabaseRowSecurityProvider(db) +provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) +securityList := security.SetupSecurityProvider(handler, provider) -// ===== CALLBACK SIGNATURES ===== -func(r *http.Request) (int, string, error) // Auth -func(int, string, string) ([]security.ColumnSecurity, error) // Column -func(int, string, string) (security.RowSecurity, error) // Row +// ===== INTERFACE METHODS ===== +Authenticate(r *http.Request) (*UserContext, error) +Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) +Logout(ctx context.Context, req LogoutRequest) error +GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) +GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) // ===== QUICK EXAMPLES ===== // Header auth -func(r *http.Request) (int, string, error) { - id, _ := strconv.Atoi(r.Header.Get("X-User-ID")) - return id, "", nil -} +&UserContext{UserID: 123, UserName: "john"} // Mask SSN {Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5} diff --git a/pkg/security/README.md b/pkg/security/README.md new file mode 100644 index 0000000..054a878 --- /dev/null +++ b/pkg/security/README.md @@ -0,0 +1,634 @@ +# ResolveSpec Security Provider + +Type-safe, composable security system for ResolveSpec with support for authentication, column-level security (masking), and row-level security (filtering). + +## Features + +- ✅ **Interface-Based** - Type-safe providers instead of callbacks +- ✅ **Login/Logout Support** - Built-in authentication lifecycle +- ✅ **Composable** - Mix and match different providers +- ✅ **No Global State** - Each handler has its own security configuration +- ✅ **Testable** - Easy to mock and test +- ✅ **Extensible** - Implement custom providers for your needs +- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability + +## Stored Procedure Architecture + +**All database-backed security providers use PostgreSQL stored procedures exclusively.** No raw SQL queries are executed from Go code. + +### Benefits + +- **Security**: Database logic is centralized and protected +- **Maintainability**: Update database logic without recompiling Go code +- **Performance**: Stored procedures are pre-compiled and optimized +- **Testability**: Test database logic independently +- **Consistency**: Standardized `resolvespec_*` naming convention + +### Available Stored Procedures + +| Procedure | Purpose | Used By | +|-----------|---------|---------| +| `resolvespec_login` | Session-based login | DatabaseAuthenticator | +| `resolvespec_logout` | Session invalidation | DatabaseAuthenticator | +| `resolvespec_session` | Session validation | DatabaseAuthenticator | +| `resolvespec_session_update` | Update session activity | DatabaseAuthenticator | +| `resolvespec_refresh_token` | Token refresh | DatabaseAuthenticator | +| `resolvespec_jwt_login` | JWT user validation | JWTAuthenticator | +| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator | +| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider | +| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider | + +See `database_schema.sql` for complete stored procedure definitions and examples. + +## Quick Start + +```go +import ( + "github.com/bitechdev/ResolveSpec/pkg/security" + "github.com/bitechdev/ResolveSpec/pkg/restheadspec" +) + +// 1. Create security providers +auth := security.NewJWTAuthenticator("your-secret-key", db) +colSec := security.NewDatabaseColumnSecurityProvider(db) +rowSec := security.NewDatabaseRowSecurityProvider(db) + +// 2. Combine providers +provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) + +// 3. Setup security +handler := restheadspec.NewHandlerWithGORM(db) +securityList := security.SetupSecurityProvider(handler, provider) + +// 4. Apply middleware +router := mux.NewRouter() +restheadspec.SetupMuxRoutes(router, handler) +router.Use(security.NewAuthMiddleware(securityList)) +router.Use(security.SetSecurityMiddleware(securityList)) +``` + +## Architecture + +### Core Interfaces + +The security system is built on three main interfaces: + +#### 1. Authenticator +Handles user authentication lifecycle: + +```go +type Authenticator interface { + Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) + Logout(ctx context.Context, req LogoutRequest) error + Authenticate(r *http.Request) (*UserContext, error) +} +``` + +#### 2. ColumnSecurityProvider +Manages column-level security (masking/hiding): + +```go +type ColumnSecurityProvider interface { + GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) +} +``` + +#### 3. RowSecurityProvider +Manages row-level security (WHERE clause filtering): + +```go +type RowSecurityProvider interface { + GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) +} +``` + +### SecurityProvider +The main interface that combines all three: + +```go +type SecurityProvider interface { + Authenticator + ColumnSecurityProvider + RowSecurityProvider +} +``` + +### UserContext +Enhanced user context with complete user information: + +```go +type UserContext struct { + UserID int // User's unique ID + UserName string // Username + UserLevel int // User privilege level + SessionID string // Current session ID + RemoteID string // Remote system ID + Roles []string // User roles + Email string // User email + Claims map[string]any // Additional metadata +} +``` + +## Available Implementations + +### Authenticators + +**HeaderAuthenticator** - Simple header-based authentication: +```go +auth := security.NewHeaderAuthenticator() +// Expects: X-User-ID, X-User-Name, X-User-Level, etc. +``` + +**DatabaseAuthenticator** - Database session-based authentication (Recommended): +```go +auth := security.NewDatabaseAuthenticator(db) +// Supports: Login, Logout, Session management, Token refresh +// All operations use stored procedures: resolvespec_login, resolvespec_logout, +// resolvespec_session, resolvespec_session_update, resolvespec_refresh_token +// Requires: users and user_sessions tables + stored procedures (see database_schema.sql) +``` + +**JWTAuthenticator** - JWT token authentication with login/logout: +```go +auth := security.NewJWTAuthenticator("secret-key", db) +// Supports: Login, Logout, JWT token validation +// All operations use stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout +// Note: Requires JWT library installation for token signing/verification +``` + +### Column Security Providers + +**DatabaseColumnSecurityProvider** - Loads rules from database: +```go +colSec := security.NewDatabaseColumnSecurityProvider(db) +// Uses stored procedure: resolvespec_column_security +// Queries core.secaccess and core.hub_link tables +``` + +**ConfigColumnSecurityProvider** - Static configuration: +```go +rules := map[string][]security.ColumnSecurity{ + "public.employees": { + {Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}, + }, +} +colSec := security.NewConfigColumnSecurityProvider(rules) +``` + +### Row Security Providers + +**DatabaseRowSecurityProvider** - Loads filters from database: +```go +rowSec := security.NewDatabaseRowSecurityProvider(db) +// Uses stored procedure: resolvespec_row_security +``` + +**ConfigRowSecurityProvider** - Static templates: +```go +templates := map[string]string{ + "public.orders": "user_id = {UserID}", +} +blocked := map[string]bool{ + "public.admin_logs": true, +} +rowSec := security.NewConfigRowSecurityProvider(templates, blocked) +``` + +## Usage Examples + +### Example 1: Complete Database-Backed Security with Sessions + +```go +func main() { + db := setupDatabase() + + // Run migrations (see database_schema.sql) + // db.Exec("CREATE TABLE users ...") + // db.Exec("CREATE TABLE user_sessions ...") + + handler := restheadspec.NewHandlerWithGORM(db) + + // Create providers + auth := security.NewDatabaseAuthenticator(db) // Session-based auth + colSec := security.NewDatabaseColumnSecurityProvider(db) + rowSec := security.NewDatabaseRowSecurityProvider(db) + + // Combine + provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) + securityList := security.SetupSecurityProvider(handler, provider) + + // Setup routes + router := mux.NewRouter() + + // Add auth endpoints + router.HandleFunc("/auth/login", handleLogin(securityList)).Methods("POST") + router.HandleFunc("/auth/logout", handleLogout(securityList)).Methods("POST") + router.HandleFunc("/auth/refresh", handleRefresh(securityList)).Methods("POST") + + // Setup API with security + apiRouter := router.PathPrefix("/api").Subrouter() + restheadspec.SetupMuxRoutes(apiRouter, handler) + apiRouter.Use(security.NewAuthMiddleware(securityList)) + apiRouter.Use(security.SetSecurityMiddleware(securityList)) + + http.ListenAndServe(":8080", router) +} + +func handleLogin(securityList *security.SecurityList) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req security.LoginRequest + json.NewDecoder(r.Body).Decode(&req) + + // Add client info to claims + req.Claims = map[string]any{ + "ip_address": r.RemoteAddr, + "user_agent": r.UserAgent(), + } + + resp, err := securityList.Provider().Login(r.Context(), req) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + // Set session cookie (optional) + http.SetCookie(w, &http.Cookie{ + Name: "session_token", + Value: resp.Token, + Expires: time.Now().Add(24 * time.Hour), + HttpOnly: true, + Secure: true, // Use in production with HTTPS + SameSite: http.SameSiteStrictMode, + }) + + json.NewEncoder(w).Encode(resp) + } +} + +func handleRefresh(securityList *security.SecurityList) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("X-Refresh-Token") + + if refreshable, ok := securityList.Provider().(security.Refreshable); ok { + resp, err := refreshable.RefreshToken(r.Context(), token) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + json.NewEncoder(w).Encode(resp) + } else { + http.Error(w, "Refresh not supported", http.StatusNotImplemented) + } + } +} +``` + +### Example 2: Config-Based Security (No Database) + +```go +func main() { + db := setupDatabase() + handler := restheadspec.NewHandlerWithGORM(db) + + // Static column security rules + columnRules := map[string][]security.ColumnSecurity{ + "public.employees": { + {Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}, + {Path: []string{"salary"}, Accesstype: "hide"}, + }, + } + + // Static row security templates + rowTemplates := map[string]string{ + "public.orders": "user_id = {UserID}", + } + + // Create providers + auth := security.NewHeaderAuthenticator() + colSec := security.NewConfigColumnSecurityProvider(columnRules) + rowSec := security.NewConfigRowSecurityProvider(rowTemplates, nil) + + provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) + securityList := security.SetupSecurityProvider(handler, provider) + + // Setup routes... +} +``` + +### Example 3: Custom Provider + +Implement your own provider for complete control: + +```go +type MySecurityProvider struct { + db *gorm.DB +} + +func (p *MySecurityProvider) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) { + // Your custom login logic +} + +func (p *MySecurityProvider) Logout(ctx context.Context, req security.LogoutRequest) error { + // Your custom logout logic +} + +func (p *MySecurityProvider) Authenticate(r *http.Request) (*security.UserContext, error) { + // Your custom authentication logic +} + +func (p *MySecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { + // Your custom column security logic +} + +func (p *MySecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) { + // Your custom row security logic +} + +// Use it +provider := &MySecurityProvider{db: db} +securityList := security.SetupSecurityProvider(handler, provider) +``` + +## Security Features + +### Column Security (Masking/Hiding) + +**Mask SSN (show last 4 digits):** +```go +{ + Path: []string{"ssn"}, + Accesstype: "mask", + MaskStart: 5, + MaskChar: "*", +} +// "123-45-6789" → "*****6789" +``` + +**Hide entire field:** +```go +{ + Path: []string{"salary"}, + Accesstype: "hide", +} +// Field returns 0 or empty +``` + +**Nested JSON field masking:** +```go +{ + Path: []string{"address", "street"}, + Accesstype: "mask", + MaskStart: 10, +} +``` + +### Row Security (Filtering) + +**User isolation:** +```go +{ + Template: "user_id = {UserID}", +} +// Users only see their own records +``` + +**Tenant isolation:** +```go +{ + Template: "tenant_id = {TenantID} AND user_id = {UserID}", +} +``` + +**Block all access:** +```go +{ + HasBlock: true, +} +// Completely blocks access to the table +``` + +**Template variables:** +- `{UserID}` - Current user's ID +- `{PrimaryKeyName}` - Primary key column +- `{TableName}` - Table name +- `{SchemaName}` - Schema name + +## Request Flow + +``` +HTTP Request + ↓ +NewAuthMiddleware + ├─ Calls provider.Authenticate(request) + └─ Adds UserContext to context + ↓ +SetSecurityMiddleware + └─ Adds SecurityList to context + ↓ +Handler.Handle() + ↓ +BeforeRead Hook + ├─ Calls provider.GetColumnSecurity() + └─ Calls provider.GetRowSecurity() + ↓ +BeforeScan Hook + └─ Applies row security (adds WHERE clause) + ↓ +Database Query (with security filters) + ↓ +AfterRead Hook + └─ Applies column security (masks/hides fields) + ↓ +HTTP Response (secured data) +``` + +## Testing + +The interface-based design makes testing straightforward: + +```go +// Mock authenticator for tests +type MockAuthenticator struct { + UserToReturn *security.UserContext + ErrorToReturn error +} + +func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { + return m.UserToReturn, m.ErrorToReturn +} + +// Use in tests +func TestMyHandler(t *testing.T) { + mockAuth := &MockAuthenticator{ + UserToReturn: &security.UserContext{UserID: 123}, + } + + provider := security.NewCompositeSecurityProvider( + mockAuth, + &MockColumnSecurity{}, + &MockRowSecurity{}, + ) + + securityList := security.SetupSecurityProvider(handler, provider) + // ... test your handler +} +``` + +## Migration from Callbacks + +If you're upgrading from the old callback-based system: + +**Old:** +```go +security.GlobalSecurity.AuthenticateCallback = myAuthFunc +security.GlobalSecurity.LoadColumnSecurityCallback = myColSecFunc +security.GlobalSecurity.LoadRowSecurityCallback = myRowSecFunc +security.SetupSecurityProvider(handler, &security.GlobalSecurity) +``` + +**New:** +```go +// Wrap your functions in a provider +type MyProvider struct{} + +func (p *MyProvider) Authenticate(r *http.Request) (*security.UserContext, error) { + userID, roles, err := myAuthFunc(r) + return &security.UserContext{UserID: userID, Roles: strings.Split(roles, ",")}, err +} + +func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { + return myColSecFunc(userID, schema, table) +} + +func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) { + return myRowSecFunc(userID, schema, table) +} + +func (p *MyProvider) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) { + return nil, fmt.Errorf("not implemented") +} + +func (p *MyProvider) Logout(ctx context.Context, req security.LogoutRequest) error { + return nil +} + +// Use it +provider := &MyProvider{} +securityList := security.SetupSecurityProvider(handler, provider) +``` + +## Documentation + +| File | Description | +|------|-------------| +| **QUICK_REFERENCE.md** | Quick reference guide with examples | +| **INTERFACE_GUIDE.md** | Complete implementation guide | +| **examples.go** | Working provider implementations | +| **setup_example.go** | 6 complete integration examples | + +## API Reference + +### Context Helpers + +Get user information from request context: + +```go +userCtx, ok := security.GetUserContext(ctx) +userID, ok := security.GetUserID(ctx) +userName, ok := security.GetUserName(ctx) +userLevel, ok := security.GetUserLevel(ctx) +sessionID, ok := security.GetSessionID(ctx) +remoteID, ok := security.GetRemoteID(ctx) +roles, ok := security.GetUserRoles(ctx) +email, ok := security.GetUserEmail(ctx) +``` + +### Optional Interfaces + +Implement these for additional features: + +**Refreshable** - Token refresh support: +```go +type Refreshable interface { + RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) +} +``` + +**Validatable** - Token validation: +```go +type Validatable interface { + ValidateToken(ctx context.Context, token string) (bool, error) +} +``` + +**Cacheable** - Cache management: +```go +type Cacheable interface { + ClearCache(ctx context.Context, userID int, schema, table string) error +} +``` + +## Benefits Over Callbacks + +| Feature | Old (Callbacks) | New (Interfaces) | +|---------|----------------|------------------| +| Type Safety | ❌ Callbacks can be nil | ✅ Compile-time verification | +| Global State | ❌ GlobalSecurity variable | ✅ Dependency injection | +| Testability | ⚠️ Need to set globals | ✅ Easy to mock | +| Composability | ❌ Single provider only | ✅ Mix and match | +| Login/Logout | ❌ Not supported | ✅ Built-in | +| Extensibility | ⚠️ Limited | ✅ Optional interfaces | + +## Common Patterns + +### Caching Security Rules + +```go +type CachedProvider struct { + inner security.ColumnSecurityProvider + cache *cache.Cache +} + +func (p *CachedProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { + key := fmt.Sprintf("%d:%s.%s", userID, schema, table) + if cached, found := p.cache.Get(key); found { + return cached.([]security.ColumnSecurity), nil + } + + rules, err := p.inner.GetColumnSecurity(ctx, userID, schema, table) + if err == nil { + p.cache.Set(key, rules, cache.DefaultExpiration) + } + return rules, err +} +``` + +### Role-Based Security + +```go +func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) { + userCtx, _ := security.GetUserContext(ctx) + + if contains(userCtx.Roles, "admin") { + return []security.ColumnSecurity{}, nil // No restrictions + } + + return loadRestrictionsForUser(userID, schema, table), nil +} +``` + +### Multi-Tenant Isolation + +```go +func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) { + tenantID := getUserTenant(userID) + + return security.RowSecurity{ + Template: fmt.Sprintf("tenant_id = %d AND user_id = {UserID}", tenantID), + }, nil +} +``` + +## License + +Part of the ResolveSpec project. diff --git a/pkg/security/callbacks_example.go b/pkg/security/callbacks_example.go deleted file mode 100644 index 1935652..0000000 --- a/pkg/security/callbacks_example.go +++ /dev/null @@ -1,414 +0,0 @@ -package security - -import ( - "fmt" - "net/http" - "strconv" - "strings" -) - -// This file provides example implementations of the required security callbacks. -// Copy these functions and modify them to match your authentication and database schema. - -// ============================================================================= -// EXAMPLE 1: Simple Header-Based Authentication -// ============================================================================= - -// ExampleAuthenticateFromHeader extracts user ID from X-User-ID header -func ExampleAuthenticateFromHeader(r *http.Request) (userID int, roles string, err error) { - userIDStr := r.Header.Get("X-User-ID") - if userIDStr == "" { - return 0, "", fmt.Errorf("X-User-ID header not provided") - } - - userID, err = strconv.Atoi(userIDStr) - if err != nil { - return 0, "", fmt.Errorf("invalid user ID format: %v", err) - } - - // Optionally extract roles - roles = r.Header.Get("X-User-Roles") // comma-separated: "admin,manager" - - return userID, roles, nil -} - -// ============================================================================= -// EXAMPLE 2: JWT Token Authentication -// ============================================================================= - -// ExampleAuthenticateFromJWT parses a JWT token and extracts user info -// You'll need to import a JWT library like github.com/golang-jwt/jwt/v5 -func ExampleAuthenticateFromJWT(r *http.Request) (userID int, roles string, err error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return 0, "", fmt.Errorf("authorization header not provided") - } - - // Extract Bearer token - tokenString := strings.TrimPrefix(authHeader, "Bearer ") - if tokenString == authHeader { - return 0, "", fmt.Errorf("invalid authorization header format") - } - - // TODO: Parse and validate JWT token - // Example using github.com/golang-jwt/jwt/v5: - // - // token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // return []byte(os.Getenv("JWT_SECRET")), nil - // }) - // - // if err != nil || !token.Valid { - // return 0, "", fmt.Errorf("invalid token: %v", err) - // } - // - // claims := token.Claims.(jwt.MapClaims) - // userID = int(claims["user_id"].(float64)) - // roles = claims["roles"].(string) - - return 0, "", fmt.Errorf("JWT parsing not implemented - see example above") -} - -// ============================================================================= -// EXAMPLE 3: Session Cookie Authentication -// ============================================================================= - -// ExampleAuthenticateFromSession validates a session cookie -func ExampleAuthenticateFromSession(r *http.Request) (userID int, roles string, err error) { - sessionCookie, err := r.Cookie("session_id") - if err != nil { - return 0, "", fmt.Errorf("session cookie not found") - } - - // TODO: Validate session against your session store (Redis, database, etc.) - // Example: - // - // session, err := sessionStore.Get(sessionCookie.Value) - // if err != nil { - // return 0, "", fmt.Errorf("invalid session") - // } - // - // userID = session.UserID - // roles = session.Roles - - _ = sessionCookie // Suppress unused warning until implemented - return 0, "", fmt.Errorf("session validation not implemented - see example above") -} - -// ============================================================================= -// EXAMPLE 4: Column Security - Database Implementation -// ============================================================================= - -// ExampleLoadColumnSecurityFromDatabase loads column security rules from database -// This implementation assumes the following database schema: -// -// CREATE TABLE core.secacces ( -// rid_secacces SERIAL PRIMARY KEY, -// rid_hub INTEGER, -// control TEXT, -- Format: "schema.table.column" -// accesstype TEXT, -- "mask" or "hide" -// jsonvalue JSONB -- Masking configuration -// ); -// -// CREATE TABLE core.hub_link ( -// rid_hub_parent INTEGER, -- Security group ID -// rid_hub_child INTEGER, -- User ID -// parent_hubtype TEXT -- 'secgroup' -// ); -func ExampleLoadColumnSecurityFromDatabase(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) { - colSecList := make([]ColumnSecurity, 0) - - // getExtraFilters := func(pStr string) map[string]string { - // mp := make(map[string]string, 0) - // for i, val := range strings.Split(pStr, ",") { - // if i <= 1 { - // continue - // } - // vals := strings.Split(val, ":") - // if len(vals) > 1 { - // mp[vals[0]] = vals[1] - // } - // } - // return mp - // } - - // rows, err := DBM.DBConn.Raw(fmt.Sprintf(` - // SELECT a.rid_secacces, a.control, a.accesstype, a.jsonvalue - // FROM core.secacces a - // WHERE a.rid_hub IN ( - // SELECT l.rid_hub_parent - // FROM core.hub_link l - // WHERE l.parent_hubtype = 'secgroup' - // AND l.rid_hub_child = ? - // ) - // AND control ILIKE '%s.%s%%' - // `, pSchema, pTablename), pUserID).Rows() - - // defer func() { - // if rows != nil { - // rows.Close() - // } - // }() - - // if err != nil { - // return colSecList, fmt.Errorf("failed to fetch column security from SQL: %v", err) - // } - - // for rows.Next() { - // var rid int - // var jsondata []byte - // var control, accesstype string - - // err = rows.Scan(&rid, &control, &accesstype, &jsondata) - // if err != nil { - // return colSecList, fmt.Errorf("failed to scan column security: %v", err) - // } - - // parts := strings.Split(control, ",") - // ids := strings.Split(parts[0], ".") - // if len(ids) < 3 { - // continue - // } - - // jsonvalue := make(map[string]interface{}) - // if len(jsondata) > 1 { - // err = json.Unmarshal(jsondata, &jsonvalue) - // if err != nil { - // logger.Error("Failed to parse json: %v", err) - // } - // } - - // colsec := ColumnSecurity{ - // Schema: pSchema, - // Tablename: pTablename, - // UserID: pUserID, - // Path: ids[2:], - // ExtraFilters: getExtraFilters(control), - // Accesstype: accesstype, - // Control: control, - // ID: int(rid), - // } - - // // Parse masking configuration from JSON - // if v, ok := jsonvalue["start"]; ok { - // if value, ok := v.(float64); ok { - // colsec.MaskStart = int(value) - // } - // } - - // if v, ok := jsonvalue["end"]; ok { - // if value, ok := v.(float64); ok { - // colsec.MaskEnd = int(value) - // } - // } - - // if v, ok := jsonvalue["invert"]; ok { - // if value, ok := v.(bool); ok { - // colsec.MaskInvert = value - // } - // } - - // if v, ok := jsonvalue["char"]; ok { - // if value, ok := v.(string); ok { - // colsec.MaskChar = value - // } - // } - - // colSecList = append(colSecList, colsec) - // } - - return colSecList, nil -} - -// ============================================================================= -// EXAMPLE 5: Column Security - In-Memory/Static Configuration -// ============================================================================= - -// ExampleLoadColumnSecurityFromConfig loads column security from static config -func ExampleLoadColumnSecurityFromConfig(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) { - // Example: Define security rules in code or load from config file - securityRules := map[string][]ColumnSecurity{ - "public.employees": { - { - Schema: "public", - Tablename: "employees", - Path: []string{"ssn"}, - Accesstype: "mask", - MaskStart: 5, - MaskEnd: 0, - MaskChar: "*", - }, - { - Schema: "public", - Tablename: "employees", - Path: []string{"salary"}, - Accesstype: "hide", - }, - }, - "public.customers": { - { - Schema: "public", - Tablename: "customers", - Path: []string{"credit_card"}, - Accesstype: "mask", - MaskStart: 12, - MaskEnd: 0, - MaskChar: "*", - }, - }, - } - - key := fmt.Sprintf("%s.%s", pSchema, pTablename) - rules, ok := securityRules[key] - if !ok { - return []ColumnSecurity{}, nil // No rules for this table - } - - // Filter by user ID if needed - // For this example, all rules apply to all users - return rules, nil -} - -// ============================================================================= -// EXAMPLE 6: Row Security - Database Implementation -// ============================================================================= - -// ExampleLoadRowSecurityFromDatabase loads row security rules from database -// This implementation assumes a PostgreSQL function: -// -// CREATE FUNCTION core.api_sec_rowtemplate( -// p_schema TEXT, -// p_table TEXT, -// p_userid INTEGER -// ) RETURNS TABLE ( -// p_retval INTEGER, -// p_errmsg TEXT, -// p_template TEXT, -// p_block BOOLEAN -// ); -func ExampleLoadRowSecurityFromDatabase(pUserID int, pSchema, pTablename string) (RowSecurity, error) { - record := RowSecurity{ - Schema: pSchema, - Tablename: pTablename, - UserID: pUserID, - } - - // rows, err := DBM.DBConn.Raw(` - // SELECT r.p_retval, r.p_errmsg, r.p_template, r.p_block - // FROM core.api_sec_rowtemplate(?, ?, ?) r - // `, pSchema, pTablename, pUserID).Rows() - - // defer func() { - // if rows != nil { - // rows.Close() - // } - // }() - - // if err != nil { - // return record, fmt.Errorf("failed to fetch row security from SQL: %v", err) - // } - - // for rows.Next() { - // var retval int - // var errmsg string - - // err = rows.Scan(&retval, &errmsg, &record.Template, &record.HasBlock) - // if err != nil { - // return record, fmt.Errorf("failed to scan row security: %v", err) - // } - - // if retval != 0 { - // return RowSecurity{}, fmt.Errorf("api_sec_rowtemplate error: %s", errmsg) - // } - // } - - return record, nil -} - -// ============================================================================= -// EXAMPLE 7: Row Security - Static Configuration -// ============================================================================= - -// ExampleLoadRowSecurityFromConfig loads row security from static config -func ExampleLoadRowSecurityFromConfig(pUserID int, pSchema, pTablename string) (RowSecurity, error) { - // Define row security templates based on entity - templates := map[string]string{ - "public.orders": "user_id = {UserID}", // Users see only their orders - "public.documents": "user_id = {UserID} OR is_public = true", // Users see their docs + public docs - "public.employees": "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})", // Complex filter - } - - // Define blocked entities (no access at all) - blockedEntities := map[string][]int{ - "public.admin_logs": {}, // All users blocked (empty list = block all) - "public.audit_logs": {1, 2, 3}, // Block users 1, 2, 3 - } - - key := fmt.Sprintf("%s.%s", pSchema, pTablename) - - // Check if entity is blocked for this user - if blockedUsers, ok := blockedEntities[key]; ok { - if len(blockedUsers) == 0 { - // Block all users - return RowSecurity{ - Schema: pSchema, - Tablename: pTablename, - UserID: pUserID, - HasBlock: true, - }, nil - } - // Check if specific user is blocked - for _, blockedUserID := range blockedUsers { - if blockedUserID == pUserID { - return RowSecurity{ - Schema: pSchema, - Tablename: pTablename, - UserID: pUserID, - HasBlock: true, - }, nil - } - } - } - - // Get template for this entity - template, ok := templates[key] - if !ok { - // No row security defined - allow all rows - return RowSecurity{ - Schema: pSchema, - Tablename: pTablename, - UserID: pUserID, - Template: "", - HasBlock: false, - }, nil - } - - return RowSecurity{ - Schema: pSchema, - Tablename: pTablename, - UserID: pUserID, - Template: template, - HasBlock: false, - }, nil -} - -// ============================================================================= -// SETUP HELPER: Configure All Callbacks -// ============================================================================= - -// SetupCallbacksExample shows how to configure all callbacks -func SetupCallbacksExample() { - // Option 1: Use database-backed security (production) - GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT - GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase - GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase - - // Option 2: Use static configuration (development/testing) - // GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader - // GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig - // GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig - - // Option 3: Mix and match - // GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT - // GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig - // GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase -} diff --git a/pkg/security/composite.go b/pkg/security/composite.go new file mode 100644 index 0000000..8eafcd2 --- /dev/null +++ b/pkg/security/composite.go @@ -0,0 +1,105 @@ +package security + +import ( + "context" + "fmt" + "net/http" +) + +// CompositeSecurityProvider combines multiple security providers +// Allows separating authentication, column security, and row security concerns +type CompositeSecurityProvider struct { + auth Authenticator + colSec ColumnSecurityProvider + rowSec RowSecurityProvider +} + +// NewCompositeSecurityProvider creates a composite provider +// All parameters are required +func NewCompositeSecurityProvider( + auth Authenticator, + colSec ColumnSecurityProvider, + rowSec RowSecurityProvider, +) *CompositeSecurityProvider { + if auth == nil { + panic("authenticator cannot be nil") + } + if colSec == nil { + panic("column security provider cannot be nil") + } + if rowSec == nil { + panic("row security provider cannot be nil") + } + + return &CompositeSecurityProvider{ + auth: auth, + colSec: colSec, + rowSec: rowSec, + } +} + +// Login delegates to the authenticator +func (c *CompositeSecurityProvider) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + return c.auth.Login(ctx, req) +} + +// Logout delegates to the authenticator +func (c *CompositeSecurityProvider) Logout(ctx context.Context, req LogoutRequest) error { + return c.auth.Logout(ctx, req) +} + +// Authenticate delegates to the authenticator +func (c *CompositeSecurityProvider) Authenticate(r *http.Request) (*UserContext, error) { + return c.auth.Authenticate(r) +} + +// GetColumnSecurity delegates to the column security provider +func (c *CompositeSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { + return c.colSec.GetColumnSecurity(ctx, userID, schema, table) +} + +// GetRowSecurity delegates to the row security provider +func (c *CompositeSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { + return c.rowSec.GetRowSecurity(ctx, userID, schema, table) +} + +// Optional interface implementations (if wrapped providers support them) + +// RefreshToken implements Refreshable if the authenticator supports it +func (c *CompositeSecurityProvider) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { + if refreshable, ok := c.auth.(Refreshable); ok { + return refreshable.RefreshToken(ctx, refreshToken) + } + return nil, fmt.Errorf("authenticator does not support token refresh") +} + +// ValidateToken implements Validatable if the authenticator supports it +func (c *CompositeSecurityProvider) ValidateToken(ctx context.Context, token string) (bool, error) { + if validatable, ok := c.auth.(Validatable); ok { + return validatable.ValidateToken(ctx, token) + } + return false, fmt.Errorf("authenticator does not support token validation") +} + +// ClearCache implements Cacheable if any provider supports it +func (c *CompositeSecurityProvider) ClearCache(ctx context.Context, userID int, schema, table string) error { + var errs []error + + if cacheable, ok := c.colSec.(Cacheable); ok { + if err := cacheable.ClearCache(ctx, userID, schema, table); err != nil { + errs = append(errs, fmt.Errorf("column security cache clear failed: %w", err)) + } + } + + if cacheable, ok := c.rowSec.(Cacheable); ok { + if err := cacheable.ClearCache(ctx, userID, schema, table); err != nil { + errs = append(errs, fmt.Errorf("row security cache clear failed: %w", err)) + } + } + + if len(errs) > 0 { + return fmt.Errorf("cache clear errors: %v", errs) + } + + return nil +} diff --git a/pkg/security/database_schema.sql b/pkg/security/database_schema.sql new file mode 100644 index 0000000..a66a787 --- /dev/null +++ b/pkg/security/database_schema.sql @@ -0,0 +1,428 @@ +-- Database Schema for DatabaseAuthenticator +-- ============================================ + +-- Users table +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + email VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255) NOT NULL, -- bcrypt hashed password + user_level INTEGER DEFAULT 0, + roles VARCHAR(500), -- Comma-separated roles: "admin,manager,user" + is_active BOOLEAN DEFAULT true, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_login_at TIMESTAMP +); + +-- User sessions table for DatabaseAuthenticator +CREATE TABLE IF NOT EXISTS user_sessions ( + id SERIAL PRIMARY KEY, + session_token VARCHAR(500) NOT NULL UNIQUE, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + ip_address VARCHAR(45), -- IPv4 or IPv6 + user_agent TEXT +); + +CREATE INDEX IF NOT EXISTS idx_session_token ON user_sessions(session_token); +CREATE INDEX IF NOT EXISTS idx_user_id ON user_sessions(user_id); +CREATE INDEX IF NOT EXISTS idx_expires_at ON user_sessions(expires_at); + +-- Optional: Token blacklist for logout tracking (useful for JWT too) +CREATE TABLE IF NOT EXISTS token_blacklist ( + id SERIAL PRIMARY KEY, + token VARCHAR(500) NOT NULL, + user_id INTEGER REFERENCES users(id) ON DELETE CASCADE, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_token ON token_blacklist(token); +CREATE INDEX IF NOT EXISTS idx_blacklist_expires_at ON token_blacklist(expires_at); + +-- Example: Seed admin user (password should be hashed with bcrypt) +-- INSERT INTO users (username, email, password, user_level, roles, is_active) +-- VALUES ('admin', 'admin@example.com', '$2a$10$...', 10, 'admin,user', true); + +-- Cleanup expired sessions (run periodically) +-- DELETE FROM user_sessions WHERE expires_at < NOW(); + +-- Cleanup expired blacklisted tokens (run periodically) +-- DELETE FROM token_blacklist WHERE expires_at < NOW(); + +-- ============================================ +-- Stored Procedures for DatabaseAuthenticator +-- ============================================ + +-- 1. resolvespec_login - Authenticates user and creates session +-- Input: LoginRequest as jsonb {username: string, password: string, claims: object} +-- Output: p_success (bool), p_error (text), p_data (LoginResponse as jsonb) +CREATE OR REPLACE FUNCTION resolvespec_login(p_request jsonb) +RETURNS TABLE(p_success boolean, p_error text, p_data jsonb) AS $$ +DECLARE + v_user_id INTEGER; + v_username TEXT; + v_email TEXT; + v_user_level INTEGER; + v_roles TEXT; + v_password_hash TEXT; + v_session_token TEXT; + v_expires_at TIMESTAMP; + v_ip_address TEXT; + v_user_agent TEXT; +BEGIN + -- Extract login request fields + v_username := p_request->>'username'; + v_ip_address := p_request->'claims'->>'ip_address'; + v_user_agent := p_request->'claims'->>'user_agent'; + + -- Validate user credentials + SELECT id, username, email, password, user_level, roles + INTO v_user_id, v_username, v_email, v_password_hash, v_user_level, v_roles + FROM users + WHERE username = v_username AND is_active = true; + + IF NOT FOUND THEN + RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb; + RETURN; + END IF; + + -- TODO: Verify password hash using pgcrypto extension + -- Enable pgcrypto: CREATE EXTENSION IF NOT EXISTS pgcrypto; + -- IF NOT (crypt(p_request->>'password', v_password_hash) = v_password_hash) THEN + -- RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb; + -- RETURN; + -- END IF; + + -- Generate session token + v_session_token := 'sess_' || encode(gen_random_bytes(32), 'hex') || '_' || extract(epoch from now())::bigint::text; + v_expires_at := now() + interval '24 hours'; + + -- Create session + INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at) + VALUES (v_session_token, v_user_id, v_expires_at, v_ip_address, v_user_agent, now()); + + -- Update last login time + UPDATE users SET last_login_at = now() WHERE id = v_user_id; + + -- Return success with LoginResponse + RETURN QUERY SELECT + true, + NULL::text, + jsonb_build_object( + 'token', v_session_token, + 'user', jsonb_build_object( + 'user_id', v_user_id, + 'user_name', v_username, + 'email', v_email, + 'user_level', v_user_level, + 'roles', string_to_array(COALESCE(v_roles, ''), ','), + 'session_id', v_session_token + ), + 'expires_in', 86400 -- 24 hours in seconds + ); +END; +$$ LANGUAGE plpgsql; + +-- 2. resolvespec_logout - Invalidates session +-- Input: LogoutRequest as jsonb {token: string, user_id: int} +-- Output: p_success (bool), p_error (text), p_data (jsonb) +CREATE OR REPLACE FUNCTION resolvespec_logout(p_request jsonb) +RETURNS TABLE(p_success boolean, p_error text, p_data jsonb) AS $$ +DECLARE + v_token TEXT; + v_user_id INTEGER; + v_deleted INTEGER; +BEGIN + v_token := p_request->>'token'; + v_user_id := (p_request->>'user_id')::integer; + + -- Remove Bearer prefix if present + v_token := regexp_replace(v_token, '^Bearer ', '', 'i'); + + -- Delete the session + DELETE FROM user_sessions + WHERE session_token = v_token AND user_id = v_user_id; + + GET DIAGNOSTICS v_deleted = ROW_COUNT; + + IF v_deleted = 0 THEN + RETURN QUERY SELECT false, 'Session not found'::text, NULL::jsonb; + ELSE + RETURN QUERY SELECT true, NULL::text, jsonb_build_object('success', true); + END IF; +END; +$$ LANGUAGE plpgsql; + +-- 3. resolvespec_session - Validates session and returns user context +-- Input: sessionid (text), reference (text) +-- Output: p_success (bool), p_error (text), p_user (UserContext as jsonb) +CREATE OR REPLACE FUNCTION resolvespec_session(p_session_token text, p_reference text) +RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$ +DECLARE + v_user_id INTEGER; + v_username TEXT; + v_email TEXT; + v_user_level INTEGER; + v_roles TEXT; + v_session_id TEXT; +BEGIN + -- Query session and user data + SELECT + s.user_id, u.username, u.email, u.user_level, u.roles, s.session_token + INTO + v_user_id, v_username, v_email, v_user_level, v_roles, v_session_id + FROM user_sessions s + JOIN users u ON s.user_id = u.id + WHERE s.session_token = p_session_token + AND s.expires_at > now() + AND u.is_active = true; + + IF NOT FOUND THEN + RETURN QUERY SELECT false, 'Invalid or expired session'::text, NULL::jsonb; + RETURN; + END IF; + + -- Return UserContext + RETURN QUERY SELECT + true, + NULL::text, + jsonb_build_object( + 'user_id', v_user_id, + 'user_name', v_username, + 'email', v_email, + 'user_level', v_user_level, + 'session_id', v_session_id, + 'roles', string_to_array(COALESCE(v_roles, ''), ',') + ); +END; +$$ LANGUAGE plpgsql; + +-- 4. resolvespec_session_update - Updates session activity timestamp +-- Input: sessionid (text), user_context (jsonb) +-- Output: p_success (bool), p_error (text), p_user (UserContext as jsonb) +CREATE OR REPLACE FUNCTION resolvespec_session_update(p_session_token text, p_user_context jsonb) +RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$ +DECLARE + v_updated INTEGER; +BEGIN + -- Update last activity timestamp + UPDATE user_sessions + SET last_activity_at = now() + WHERE session_token = p_session_token AND expires_at > now(); + + GET DIAGNOSTICS v_updated = ROW_COUNT; + + IF v_updated = 0 THEN + RETURN QUERY SELECT false, 'Session not found or expired'::text, NULL::jsonb; + ELSE + -- Return the user context as-is + RETURN QUERY SELECT true, NULL::text, p_user_context; + END IF; +END; +$$ LANGUAGE plpgsql; + +-- 5. resolvespec_refresh_token - Generates new session from existing one +-- Input: sessionid (text), user_context (jsonb) +-- Output: p_success (bool), p_error (text), p_user (UserContext as jsonb with new session_id) +CREATE OR REPLACE FUNCTION resolvespec_refresh_token(p_old_session_token text, p_user_context jsonb) +RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$ +DECLARE + v_user_id INTEGER; + v_username TEXT; + v_email TEXT; + v_user_level INTEGER; + v_roles TEXT; + v_new_session_token TEXT; + v_expires_at TIMESTAMP; + v_ip_address TEXT; + v_user_agent TEXT; +BEGIN + -- Verify old session exists and is valid + SELECT s.user_id, u.username, u.email, u.user_level, u.roles, s.ip_address, s.user_agent + INTO v_user_id, v_username, v_email, v_user_level, v_roles, v_ip_address, v_user_agent + FROM user_sessions s + JOIN users u ON s.user_id = u.id + WHERE s.session_token = p_old_session_token + AND s.expires_at > now() + AND u.is_active = true; + + IF NOT FOUND THEN + RETURN QUERY SELECT false, 'Invalid or expired refresh token'::text, NULL::jsonb; + RETURN; + END IF; + + -- Generate new session token + v_new_session_token := 'sess_' || encode(gen_random_bytes(32), 'hex') || '_' || extract(epoch from now())::bigint::text; + v_expires_at := now() + interval '24 hours'; + + -- Create new session + INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at) + VALUES (v_new_session_token, v_user_id, v_expires_at, v_ip_address, v_user_agent, now()); + + -- Delete old session + DELETE FROM user_sessions WHERE session_token = p_old_session_token; + + -- Return UserContext with new session_id + RETURN QUERY SELECT + true, + NULL::text, + jsonb_build_object( + 'user_id', v_user_id, + 'user_name', v_username, + 'email', v_email, + 'user_level', v_user_level, + 'session_id', v_new_session_token, + 'roles', string_to_array(COALESCE(v_roles, ''), ',') + ); +END; +$$ LANGUAGE plpgsql; + +-- 6. resolvespec_jwt_login - JWT-based login (queries user and returns data for JWT token generation) +-- Input: username (text), password (text) +-- Output: p_success (bool), p_error (text), p_user (user data as jsonb) +CREATE OR REPLACE FUNCTION resolvespec_jwt_login(p_username text, p_password text) +RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$ +DECLARE + v_user_id INTEGER; + v_username TEXT; + v_email TEXT; + v_password TEXT; + v_user_level INTEGER; + v_roles TEXT; +BEGIN + -- Query user data + SELECT id, username, email, password, user_level, roles + INTO v_user_id, v_username, v_email, v_password, v_user_level, v_roles + FROM users + WHERE username = p_username AND is_active = true; + + IF NOT FOUND THEN + RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb; + RETURN; + END IF; + + -- TODO: Verify password hash + -- IF NOT (crypt(p_password, v_password) = v_password) THEN + -- RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb; + -- RETURN; + -- END IF; + + -- Return user data for JWT token generation + RETURN QUERY SELECT + true, + NULL::text, + jsonb_build_object( + 'id', v_user_id, + 'username', v_username, + 'email', v_email, + 'password', v_password, + 'user_level', v_user_level, + 'roles', v_roles + ); +END; +$$ LANGUAGE plpgsql; + +-- 7. resolvespec_jwt_logout - Adds token to blacklist +-- Input: token (text), user_id (int) +-- Output: p_success (bool), p_error (text) +CREATE OR REPLACE FUNCTION resolvespec_jwt_logout(p_token text, p_user_id integer) +RETURNS TABLE(p_success boolean, p_error text) AS $$ +BEGIN + -- Add token to blacklist + INSERT INTO token_blacklist (token, user_id, expires_at) + VALUES (p_token, p_user_id, now() + interval '24 hours'); + + RETURN QUERY SELECT true, NULL::text; +EXCEPTION + WHEN OTHERS THEN + RETURN QUERY SELECT false, SQLERRM::text; +END; +$$ LANGUAGE plpgsql; + +-- 8. resolvespec_column_security - Loads column security rules for user +-- Input: user_id (int), schema (text), table_name (text) +-- Output: p_success (bool), p_error (text), p_rules (array of security rules as jsonb) +CREATE OR REPLACE FUNCTION resolvespec_column_security(p_user_id integer, p_schema text, p_table_name text) +RETURNS TABLE(p_success boolean, p_error text, p_rules jsonb) AS $$ +DECLARE + v_rules jsonb; +BEGIN + -- Query column security rules from core.secaccess + SELECT jsonb_agg( + jsonb_build_object( + 'control', control, + 'accesstype', accesstype, + 'jsonvalue', jsonvalue + ) + ) + INTO v_rules + FROM core.secaccess + WHERE rid_hub IN ( + SELECT rid_hub_parent + FROM core.hub_link + WHERE rid_hub_child = p_user_id AND parent_hubtype = 'secgroup' + ) + AND control ILIKE (p_schema || '.' || p_table_name || '%'); + + IF v_rules IS NULL THEN + v_rules := '[]'::jsonb; + END IF; + + RETURN QUERY SELECT true, NULL::text, v_rules; +EXCEPTION + WHEN OTHERS THEN + RETURN QUERY SELECT false, SQLERRM::text, '[]'::jsonb; +END; +$$ LANGUAGE plpgsql; + +-- 9. resolvespec_row_security - Loads row security template for user (replaces core.api_sec_rowtemplate) +-- Input: schema (text), table_name (text), user_id (int) +-- Output: p_template (text), p_block (bool) +CREATE OR REPLACE FUNCTION resolvespec_row_security(p_schema text, p_table_name text, p_user_id integer) +RETURNS TABLE(p_template text, p_block boolean) AS $$ +BEGIN + -- Call the existing core function if it exists, or implement your own logic + -- This is a placeholder that you should customize based on your core.api_sec_rowtemplate logic + RETURN QUERY SELECT ''::text, false; + + -- Example implementation: + -- RETURN QUERY SELECT template, has_block + -- FROM core.row_security_config + -- WHERE schema_name = p_schema AND table_name = p_table_name AND user_id = p_user_id; +END; +$$ LANGUAGE plpgsql; + +-- ============================================ +-- Example: Test stored procedures +-- ============================================ + +-- Test login +-- SELECT * FROM resolvespec_login('{"username": "admin", "password": "test123", "claims": {"ip_address": "127.0.0.1", "user_agent": "test"}}'::jsonb); + +-- Test session validation +-- SELECT * FROM resolvespec_session('sess_abc123', 'test_reference'); + +-- Test session update +-- SELECT * FROM resolvespec_session_update('sess_abc123', '{"user_id": 1, "user_name": "admin"}'::jsonb); + +-- Test token refresh +-- SELECT * FROM resolvespec_refresh_token('sess_abc123', '{"user_id": 1, "user_name": "admin"}'::jsonb); + +-- Test logout +-- SELECT * FROM resolvespec_logout('{"token": "sess_abc123", "user_id": 1}'::jsonb); + +-- Test JWT login +-- SELECT * FROM resolvespec_jwt_login('admin', 'password123'); + +-- Test JWT logout +-- SELECT * FROM resolvespec_jwt_logout('jwt_token_here', 1); + +-- Test column security +-- SELECT * FROM resolvespec_column_security(1, 'public', 'users'); + +-- Test row security +-- SELECT * FROM resolvespec_row_security('public', 'users', 1); diff --git a/pkg/security/examples.go b/pkg/security/examples.go new file mode 100644 index 0000000..1260361 --- /dev/null +++ b/pkg/security/examples.go @@ -0,0 +1,380 @@ +package security + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + // Optional: Uncomment if you want to use JWT authentication + // "github.com/golang-jwt/jwt/v5" + "gorm.io/gorm" +) + +// Example 1: Simple Header-Based Authenticator +// ============================================= + +type HeaderAuthenticatorExample struct { + // Optional: Add any dependencies here (e.g., database, cache) +} + +func NewHeaderAuthenticatorExample() *HeaderAuthenticatorExample { + return &HeaderAuthenticatorExample{} +} + +func (a *HeaderAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + // For header-based auth, login might not be used + // Could validate credentials against a database here + return nil, fmt.Errorf("header authentication does not support login") +} + +func (a *HeaderAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error { + // For header-based auth, logout is a no-op + return nil +} + +func (a *HeaderAuthenticatorExample) Authenticate(r *http.Request) (*UserContext, error) { + userIDStr := r.Header.Get("X-User-ID") + if userIDStr == "" { + return nil, fmt.Errorf("X-User-ID header required") + } + + userID, err := strconv.Atoi(userIDStr) + if err != nil { + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + return &UserContext{ + UserID: userID, + UserName: r.Header.Get("X-User-Name"), + UserLevel: parseIntHeader(r, "X-User-Level", 0), + SessionID: r.Header.Get("X-Session-ID"), + RemoteID: r.Header.Get("X-Remote-ID"), + Email: r.Header.Get("X-User-Email"), + Roles: parseRoles(r.Header.Get("X-User-Roles")), + }, nil +} + +// Example 2: JWT Token Authenticator +// ==================================== +// NOTE: To use this, uncomment the jwt import and install: go get github.com/golang-jwt/jwt/v5 + +type JWTAuthenticatorExample struct { + secretKey []byte + db *gorm.DB +} + +func NewJWTAuthenticatorExample(secretKey string, db *gorm.DB) *JWTAuthenticatorExample { + return &JWTAuthenticatorExample{ + secretKey: []byte(secretKey), + db: db, + } +} + +func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + // Validate credentials against database + var user struct { + ID int + Username string + Email string + Password string // Should be hashed + UserLevel int + Roles string + } + + err := a.db.WithContext(ctx). + Table("users"). + Where("username = ?", req.Username). + First(&user).Error + if err != nil { + return nil, fmt.Errorf("invalid credentials") + } + + // TODO: Verify password hash + // if !verifyPassword(user.Password, req.Password) { + // return nil, fmt.Errorf("invalid credentials") + // } + + // Create JWT token + expiresAt := time.Now().Add(24 * time.Hour) + + // Uncomment when using JWT: + // token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + // "user_id": user.ID, + // "username": user.Username, + // "email": user.Email, + // "user_level": user.UserLevel, + // "roles": user.Roles, + // "exp": expiresAt.Unix(), + // }) + // tokenString, err := token.SignedString(a.secretKey) + // if err != nil { + // return nil, fmt.Errorf("failed to generate token: %w", err) + // } + + // Placeholder token for example (replace with actual JWT) + tokenString := fmt.Sprintf("token_%d_%d", user.ID, expiresAt.Unix()) + + return &LoginResponse{ + Token: tokenString, + User: &UserContext{ + UserID: user.ID, + UserName: user.Username, + Email: user.Email, + UserLevel: user.UserLevel, + Roles: parseRoles(user.Roles), + }, + ExpiresIn: int64(24 * time.Hour.Seconds()), + }, nil +} + +func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error { + // For JWT, logout could involve token blacklisting + // Add token to blacklist table + // err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{ + // "token": req.Token, + // "expires_at": time.Now().Add(24 * time.Hour), + // }).Error + return nil +} + +func (a *JWTAuthenticatorExample) Authenticate(r *http.Request) (*UserContext, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("authorization header required") + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == authHeader { + return nil, fmt.Errorf("bearer token required") + } + + // Uncomment when using JWT: + // token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + // return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + // } + // return a.secretKey, nil + // }) + // + // if err != nil || !token.Valid { + // return nil, fmt.Errorf("invalid token: %w", err) + // } + // + // claims, ok := token.Claims.(jwt.MapClaims) + // if !ok { + // return nil, fmt.Errorf("invalid token claims") + // } + // + // return &UserContext{ + // UserID: int(claims["user_id"].(float64)), + // UserName: getString(claims, "username"), + // Email: getString(claims, "email"), + // UserLevel: getInt(claims, "user_level"), + // Roles: parseRoles(getString(claims, "roles")), + // Claims: claims, + // }, nil + + // Placeholder implementation (replace with actual JWT parsing) + return nil, fmt.Errorf("JWT parsing not implemented - uncomment JWT code above") +} + +// Example 3: Database Session Authenticator +// ========================================== + +type DatabaseAuthenticatorExample struct { + db *gorm.DB +} + +func NewDatabaseAuthenticatorExample(db *gorm.DB) *DatabaseAuthenticatorExample { + return &DatabaseAuthenticatorExample{db: db} +} + +func (a *DatabaseAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + // Query user from database + var user struct { + ID int + Username string + Email string + Password string // Should be hashed with bcrypt + UserLevel int + Roles string + IsActive bool + } + + err := a.db.WithContext(ctx). + Table("users"). + Where("username = ? AND is_active = true", req.Username). + First(&user).Error + if err != nil { + return nil, fmt.Errorf("invalid credentials") + } + + // TODO: Verify password with bcrypt + // if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil { + // return nil, fmt.Errorf("invalid credentials") + // } + + // Generate session token + sessionToken := fmt.Sprintf("sess_%s_%d", generateRandomString(32), time.Now().Unix()) + expiresAt := time.Now().Add(24 * time.Hour) + + // Create session in database + err = a.db.WithContext(ctx).Table("user_sessions").Create(map[string]any{ + "session_token": sessionToken, + "user_id": user.ID, + "expires_at": expiresAt, + "created_at": time.Now(), + "ip_address": req.Claims["ip_address"], + "user_agent": req.Claims["user_agent"], + }).Error + if err != nil { + return nil, fmt.Errorf("failed to create session: %w", err) + } + + return &LoginResponse{ + Token: sessionToken, + User: &UserContext{ + UserID: user.ID, + UserName: user.Username, + Email: user.Email, + UserLevel: user.UserLevel, + Roles: parseRoles(user.Roles), + }, + ExpiresIn: int64(24 * time.Hour.Seconds()), + }, nil +} + +func (a *DatabaseAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error { + // Delete session from database + err := a.db.WithContext(ctx). + Table("user_sessions"). + Where("session_token = ? AND user_id = ?", req.Token, req.UserID). + Delete(nil).Error + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + + return nil +} + +func (a *DatabaseAuthenticatorExample) Authenticate(r *http.Request) (*UserContext, error) { + // Extract session token from header or cookie + sessionToken := r.Header.Get("Authorization") + if sessionToken == "" { + // Try cookie + cookie, err := r.Cookie("session_token") + if err == nil { + sessionToken = cookie.Value + } + } else { + // Remove "Bearer " prefix if present + sessionToken = strings.TrimPrefix(sessionToken, "Bearer ") + } + + if sessionToken == "" { + return nil, fmt.Errorf("session token required") + } + + // Query session and user from database + var session struct { + SessionToken string + UserID int + ExpiresAt time.Time + Username string + Email string + UserLevel int + Roles string + } + + query := ` + SELECT + s.session_token, + s.user_id, + s.expires_at, + u.username, + u.email, + u.user_level, + u.roles + FROM user_sessions s + JOIN users u ON s.user_id = u.id + WHERE s.session_token = ? + AND s.expires_at > ? + AND u.is_active = true + ` + + err := a.db.Raw(query, sessionToken, time.Now()).Scan(&session).Error + if err != nil { + return nil, fmt.Errorf("invalid or expired session") + } + + // Update last activity timestamp + go a.updateSessionActivity(sessionToken) + + return &UserContext{ + UserID: session.UserID, + UserName: session.Username, + Email: session.Email, + UserLevel: session.UserLevel, + SessionID: sessionToken, + Roles: parseRoles(session.Roles), + }, nil +} + +// updateSessionActivity updates the last activity timestamp for the session +func (a *DatabaseAuthenticatorExample) updateSessionActivity(sessionToken string) { + a.db.Table("user_sessions"). + Where("session_token = ?", sessionToken). + Update("last_activity_at", time.Now()) +} + +// Optional: Implement Refreshable interface +func (a *DatabaseAuthenticatorExample) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { + // Query the refresh token + var session struct { + UserID int + Username string + Email string + } + + err := a.db.WithContext(ctx).Raw(` + SELECT u.id as user_id, u.username, u.email + FROM user_sessions s + JOIN users u ON s.user_id = u.id + WHERE s.session_token = ? AND s.expires_at > ? + `, refreshToken, time.Now()).Scan(&session).Error + if err != nil { + return nil, fmt.Errorf("invalid refresh token") + } + + // Generate new session token + newSessionToken := fmt.Sprintf("sess_%s_%d", generateRandomString(32), time.Now().Unix()) + expiresAt := time.Now().Add(24 * time.Hour) + + // Create new session + err = a.db.WithContext(ctx).Table("user_sessions").Create(map[string]any{ + "session_token": newSessionToken, + "user_id": session.UserID, + "expires_at": expiresAt, + "created_at": time.Now(), + }).Error + if err != nil { + return nil, fmt.Errorf("failed to create new session: %w", err) + } + + // Delete old session + a.db.WithContext(ctx).Table("user_sessions").Where("session_token = ?", refreshToken).Delete(nil) + + return &LoginResponse{ + Token: newSessionToken, + User: &UserContext{ + UserID: session.UserID, + UserName: session.Username, + Email: session.Email, + }, + ExpiresIn: int64(24 * time.Hour.Seconds()), + }, nil +} + diff --git a/pkg/security/hooks.go b/pkg/security/hooks.go index 1a2485a..6592875 100644 --- a/pkg/security/hooks.go +++ b/pkg/security/hooks.go @@ -13,25 +13,25 @@ func RegisterSecurityHooks(handler *restheadspec.Handler, securityList *Security // Hook 1: BeforeRead - Load security rules handler.Hooks().Register(restheadspec.BeforeRead, func(hookCtx *restheadspec.HookContext) error { - return loadSecurityRules(hookCtx, securityList) + return LoadSecurityRules(hookCtx, securityList) }) // Hook 2: BeforeScan - Apply row-level security filters handler.Hooks().Register(restheadspec.BeforeScan, func(hookCtx *restheadspec.HookContext) error { - return applyRowSecurity(hookCtx, securityList) + return ApplyRowSecurity(hookCtx, securityList) }) // Hook 3: AfterRead - Apply column-level security (masking) handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error { - return applyColumnSecurity(hookCtx, securityList) + return ApplyColumnSecurity(hookCtx, securityList) }) // Hook 4 (Optional): Audit logging - handler.Hooks().Register(restheadspec.AfterRead, logDataAccess) + handler.Hooks().Register(restheadspec.AfterRead, LogDataAccess) } -// loadSecurityRules loads security configuration for the user and entity -func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { +// LoadSecurityRules loads security configuration for the user and entity +func LoadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { // Extract user ID from context userID, ok := GetUserID(hookCtx.Context) if !ok { @@ -44,16 +44,16 @@ func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *Security logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename) - // Load column security rules from database - err := securityList.LoadColumnSecurity(userID, schema, tablename, false) + // Load column security rules using the provider + err := securityList.LoadColumnSecurity(hookCtx.Context, userID, schema, tablename, false) if err != nil { logger.Warn("Failed to load column security: %v", err) // Don't fail the request if no security rules exist // return err } - // Load row security rules from database - _, err = securityList.LoadRowSecurity(userID, schema, tablename, false) + // Load row security rules using the provider + _, err = securityList.LoadRowSecurity(hookCtx.Context, userID, schema, tablename, false) if err != nil { logger.Warn("Failed to load row security: %v", err) // Don't fail the request if no security rules exist @@ -63,8 +63,8 @@ func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *Security return nil } -// applyRowSecurity applies row-level security filters to the query -func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { +// ApplyRowSecurity applies row-level security filters to the query +func ApplyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { userID, ok := GetUserID(hookCtx.Context) if !ok { return nil // No user context, skip @@ -130,8 +130,8 @@ func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityL return nil } -// applyColumnSecurity applies column-level security (masking/hiding) to results -func applyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { +// ApplyColumnSecurity applies column-level security (masking/hiding) to results +func ApplyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { userID, ok := GetUserID(hookCtx.Context) if !ok { return nil // No user context, skip @@ -175,8 +175,8 @@ func applyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *Securi return nil } -// logDataAccess logs all data access for audit purposes -func logDataAccess(hookCtx *restheadspec.HookContext) error { +// LogDataAccess logs all data access for audit purposes +func LogDataAccess(hookCtx *restheadspec.HookContext) error { userID, _ := GetUserID(hookCtx.Context) logger.Info("AUDIT: User %d accessed %s.%s with filters: %+v", diff --git a/pkg/security/interfaces.go b/pkg/security/interfaces.go new file mode 100644 index 0000000..675a8fc --- /dev/null +++ b/pkg/security/interfaces.go @@ -0,0 +1,91 @@ +package security + +import ( + "context" + "net/http" +) + +// UserContext holds authenticated user information +type UserContext struct { + UserID int + UserName string + UserLevel int + SessionID string + RemoteID string + Roles []string + Email string + Claims map[string]any +} + +// LoginRequest contains credentials for login +type LoginRequest struct { + Username string + Password string + Claims map[string]any // Additional login data +} + +// LoginResponse contains the result of a login attempt +type LoginResponse struct { + Token string + RefreshToken string + User *UserContext + ExpiresIn int64 // Token expiration in seconds +} + +// LogoutRequest contains information for logout +type LogoutRequest struct { + Token string + UserID int +} + +// Authenticator handles user authentication operations +type Authenticator interface { + // Login authenticates credentials and returns a token + Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) + + // Logout invalidates a user's session/token + Logout(ctx context.Context, req LogoutRequest) error + + // Authenticate extracts and validates user from HTTP request + // Returns UserContext or error if authentication fails + Authenticate(r *http.Request) (*UserContext, error) +} + +// ColumnSecurityProvider handles column-level security (masking/hiding) +type ColumnSecurityProvider interface { + // GetColumnSecurity loads column security rules for a user and entity + GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) +} + +// RowSecurityProvider handles row-level security (filtering) +type RowSecurityProvider interface { + // GetRowSecurity loads row security rules for a user and entity + GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) +} + +// SecurityProvider is the main interface combining all security concerns +type SecurityProvider interface { + Authenticator + ColumnSecurityProvider + RowSecurityProvider +} + +// Optional interfaces for advanced functionality + +// Refreshable allows providers to support token refresh +type Refreshable interface { + // RefreshToken exchanges a refresh token for a new access token + RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) +} + +// Validatable allows providers to validate tokens without full authentication +type Validatable interface { + // ValidateToken checks if a token is valid without extracting full user context + ValidateToken(ctx context.Context, token string) (bool, error) +} + +// Cacheable allows providers to support caching of security rules +type Cacheable interface { + // ClearCache clears cached security rules for a user/entity + ClearCache(ctx context.Context, userID int, schema, table string) error +} diff --git a/pkg/security/middleware.go b/pkg/security/middleware.go index a44ca19..3d5aaf9 100644 --- a/pkg/security/middleware.go +++ b/pkg/security/middleware.go @@ -10,38 +10,72 @@ type contextKey string const ( // Context keys for user information - UserIDKey contextKey = "user_id" - UserRolesKey contextKey = "user_roles" - UserTokenKey contextKey = "user_token" + UserIDKey contextKey = "user_id" + UserNameKey contextKey = "user_name" + UserLevelKey contextKey = "user_level" + SessionIDKey contextKey = "session_id" + RemoteIDKey contextKey = "remote_id" + UserRolesKey contextKey = "user_roles" + UserEmailKey contextKey = "user_email" + UserContextKey contextKey = "user_context" ) -// AuthMiddleware extracts user authentication from request and adds to context -// This should be applied before the ResolveSpec handler -// Uses GlobalSecurity.AuthenticateCallback if set, otherwise returns error -func AuthMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if callback is set - if GlobalSecurity.AuthenticateCallback == nil { - http.Error(w, "AuthenticateCallback not set - you must provide an authentication callback", http.StatusInternalServerError) - return - } +// NewAuthMiddleware creates an authentication middleware with the given security list +// This middleware extracts user authentication from the request and adds it to context +func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get the security provider + provider := securityList.Provider() + if provider == nil { + http.Error(w, "Security provider not configured", http.StatusInternalServerError) + return + } - // Call the user-provided authentication callback - userID, roles, err := GlobalSecurity.AuthenticateCallback(r) - if err != nil { - http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized) - return - } + // Call the provider's Authenticate method + userCtx, err := provider.Authenticate(r) + if err != nil { + http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized) + return + } - // Add user information to context - ctx := context.WithValue(r.Context(), UserIDKey, userID) - if roles != "" { - ctx = context.WithValue(ctx, UserRolesKey, roles) - } + // Add user information to context + ctx := r.Context() + ctx = context.WithValue(ctx, UserContextKey, userCtx) + ctx = context.WithValue(ctx, UserIDKey, userCtx.UserID) + ctx = context.WithValue(ctx, UserNameKey, userCtx.UserName) + ctx = context.WithValue(ctx, UserLevelKey, userCtx.UserLevel) + ctx = context.WithValue(ctx, SessionIDKey, userCtx.SessionID) + ctx = context.WithValue(ctx, RemoteIDKey, userCtx.RemoteID) - // Continue with authenticated context - next.ServeHTTP(w, r.WithContext(ctx)) - }) + if len(userCtx.Roles) > 0 { + ctx = context.WithValue(ctx, UserRolesKey, userCtx.Roles) + } + if userCtx.Email != "" { + ctx = context.WithValue(ctx, UserEmailKey, userCtx.Email) + } + + // Continue with authenticated context + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// SetSecurityMiddleware adds security context to requests +// This middleware should be applied after AuthMiddleware +func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, securityList) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// GetUserContext extracts the full user context from request context +func GetUserContext(ctx context.Context) (*UserContext, bool) { + userCtx, ok := ctx.Value(UserContextKey).(*UserContext) + return userCtx, ok } // GetUserID extracts the user ID from context @@ -50,8 +84,38 @@ func GetUserID(ctx context.Context) (int, bool) { return userID, ok } +// GetUserName extracts the user name from context +func GetUserName(ctx context.Context) (string, bool) { + userName, ok := ctx.Value(UserNameKey).(string) + return userName, ok +} + +// GetUserLevel extracts the user level from context +func GetUserLevel(ctx context.Context) (int, bool) { + userLevel, ok := ctx.Value(UserLevelKey).(int) + return userLevel, ok +} + +// GetSessionID extracts the session ID from context +func GetSessionID(ctx context.Context) (string, bool) { + sessionID, ok := ctx.Value(SessionIDKey).(string) + return sessionID, ok +} + +// GetRemoteID extracts the remote ID from context +func GetRemoteID(ctx context.Context) (string, bool) { + remoteID, ok := ctx.Value(RemoteIDKey).(string) + return remoteID, ok +} + // GetUserRoles extracts user roles from context -func GetUserRoles(ctx context.Context) (string, bool) { - roles, ok := ctx.Value(UserRolesKey).(string) +func GetUserRoles(ctx context.Context) ([]string, bool) { + roles, ok := ctx.Value(UserRolesKey).([]string) return roles, ok } + +// GetUserEmail extracts user email from context +func GetUserEmail(ctx context.Context) (string, bool) { + email, ok := ctx.Value(UserEmailKey).(string) + return email, ok +} diff --git a/pkg/security/provider.go b/pkg/security/provider.go index 4e05cab..788e603 100644 --- a/pkg/security/provider.go +++ b/pkg/security/provider.go @@ -3,7 +3,6 @@ package security import ( "context" "fmt" - "net/http" "reflect" "strings" "sync" @@ -47,46 +46,39 @@ func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Typ return str } -// Callback function types for customizing security behavior -type ( - // AuthenticateFunc extracts user ID and roles from HTTP request - // Return userID, roles, error. If error is not nil, request will be rejected. - AuthenticateFunc func(r *http.Request) (userID int, roles string, err error) - - // LoadColumnSecurityFunc loads column security rules for a user and entity - // Override this to customize how column security is loaded from your data source - LoadColumnSecurityFunc func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) - - // LoadRowSecurityFunc loads row security rules for a user and entity - // Override this to customize how row security is loaded from your data source - LoadRowSecurityFunc func(pUserID int, pSchema, pTablename string) (RowSecurity, error) -) - +// SecurityList manages security state and caching +// It wraps a SecurityProvider and provides caching and utility methods type SecurityList struct { + provider SecurityProvider + ColumnSecurityMutex sync.RWMutex ColumnSecurity map[string][]ColumnSecurity RowSecurityMutex sync.RWMutex RowSecurity map[string]RowSecurity - - // Overridable callbacks - AuthenticateCallback AuthenticateFunc - LoadColumnSecurityCallback LoadColumnSecurityFunc - LoadRowSecurityCallback LoadRowSecurityFunc } + +// NewSecurityList creates a new security list with the given provider +func NewSecurityList(provider SecurityProvider) *SecurityList { + if provider == nil { + panic("security provider cannot be nil") + } + + return &SecurityList{ + provider: provider, + ColumnSecurity: make(map[string][]ColumnSecurity), + RowSecurity: make(map[string]RowSecurity), + } +} + +// Provider returns the underlying security provider +func (m *SecurityList) Provider() SecurityProvider { + return m.provider +} + type CONTEXT_KEY string const SECURITY_CONTEXT_KEY CONTEXT_KEY = "SecurityList" -var GlobalSecurity SecurityList - -// SetSecurityMiddleware adds security context to requests -func SetSecurityMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, &GlobalSecurity) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - func maskString(pString string, maskStart, maskEnd int, maskChar string, invert bool) string { strLen := len(pString) middleIndex := (strLen / 2) @@ -372,10 +364,9 @@ func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType refl return records, nil } -func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) error { - // Use the callback if provided - if m.LoadColumnSecurityCallback == nil { - return fmt.Errorf("LoadColumnSecurityCallback not set - you must provide a callback function") +func (m *SecurityList) LoadColumnSecurity(ctx context.Context, pUserID int, pSchema, pTablename string, pOverwrite bool) error { + if m.provider == nil { + return fmt.Errorf("security provider not set") } m.ColumnSecurityMutex.Lock() @@ -390,10 +381,10 @@ func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename strin m.ColumnSecurity[secKey] = make([]ColumnSecurity, 0) } - // Call the user-provided callback to load security rules - colSecList, err := m.LoadColumnSecurityCallback(pUserID, pSchema, pTablename) + // Call the provider to load security rules + colSecList, err := m.provider.GetColumnSecurity(ctx, pUserID, pSchema, pTablename) if err != nil { - return fmt.Errorf("LoadColumnSecurityCallback failed: %v", err) + return fmt.Errorf("GetColumnSecurity failed: %v", err) } m.ColumnSecurity[secKey] = colSecList @@ -422,10 +413,9 @@ func (m *SecurityList) ClearSecurity(pUserID int, pSchema, pTablename string) er return nil } -func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) { - // Use the callback if provided - if m.LoadRowSecurityCallback == nil { - return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback not set - you must provide a callback function") +func (m *SecurityList) LoadRowSecurity(ctx context.Context, pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) { + if m.provider == nil { + return RowSecurity{}, fmt.Errorf("security provider not set") } m.RowSecurityMutex.Lock() @@ -436,10 +426,10 @@ func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string, } secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID) - // Call the user-provided callback to load security rules - record, err := m.LoadRowSecurityCallback(pUserID, pSchema, pTablename) + // Call the provider to load security rules + record, err := m.provider.GetRowSecurity(ctx, pUserID, pSchema, pTablename) if err != nil { - return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback failed: %v", err) + return RowSecurity{}, fmt.Errorf("GetRowSecurity failed: %v", err) } m.RowSecurity[secKey] = record diff --git a/pkg/security/providers.go b/pkg/security/providers.go new file mode 100644 index 0000000..ce00750 --- /dev/null +++ b/pkg/security/providers.go @@ -0,0 +1,552 @@ +package security + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +// Production-Ready Authenticators +// ================================= + +// HeaderAuthenticator provides simple header-based authentication +// Expects: X-User-ID, X-User-Name, X-User-Level, X-Session-ID, X-Remote-ID, X-User-Roles, X-User-Email +type HeaderAuthenticator struct{} + +func NewHeaderAuthenticator() *HeaderAuthenticator { + return &HeaderAuthenticator{} +} + +func (a *HeaderAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + return nil, fmt.Errorf("header authentication does not support login") +} + +func (a *HeaderAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { + return nil +} + +func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { + userIDStr := r.Header.Get("X-User-ID") + if userIDStr == "" { + return nil, fmt.Errorf("X-User-ID header required") + } + + userID, err := strconv.Atoi(userIDStr) + if err != nil { + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + return &UserContext{ + UserID: userID, + UserName: r.Header.Get("X-User-Name"), + UserLevel: parseIntHeader(r, "X-User-Level", 0), + SessionID: r.Header.Get("X-Session-ID"), + RemoteID: r.Header.Get("X-Remote-ID"), + Email: r.Header.Get("X-User-Email"), + Roles: parseRoles(r.Header.Get("X-User-Roles")), + }, nil +} + +// DatabaseAuthenticator provides session-based authentication with database storage +// All database operations go through stored procedures for security and consistency +// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session, +// resolvespec_session_update, resolvespec_refresh_token +// See database_schema.sql for procedure definitions +type DatabaseAuthenticator struct { + db *sql.DB +} + +func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { + return &DatabaseAuthenticator{db: db} +} + +func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + // Convert LoginRequest to JSON + reqJSON, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal login request: %w", err) + } + + // Call resolvespec_login stored procedure + var success bool + var errorMsg sql.NullString + var dataJSON []byte + + query := `SELECT p_success, p_error, p_data FROM resolvespec_login($1::jsonb)` + err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON) + if err != nil { + return nil, fmt.Errorf("login query failed: %w", err) + } + + if !success { + if errorMsg.Valid { + return nil, fmt.Errorf("%s", errorMsg.String) + } + return nil, fmt.Errorf("login failed") + } + + // Parse response + var response LoginResponse + if err := json.Unmarshal(dataJSON, &response); err != nil { + return nil, fmt.Errorf("failed to parse login response: %w", err) + } + + return &response, nil +} + +func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { + // Convert LogoutRequest to JSON + reqJSON, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal logout request: %w", err) + } + + // Call resolvespec_logout stored procedure + var success bool + var errorMsg sql.NullString + var dataJSON []byte + + query := `SELECT p_success, p_error, p_data FROM resolvespec_logout($1::jsonb)` + err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON) + if err != nil { + return fmt.Errorf("logout query failed: %w", err) + } + + if !success { + if errorMsg.Valid { + return fmt.Errorf("%s", errorMsg.String) + } + return fmt.Errorf("logout failed") + } + + return nil +} + +func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { + // Extract session token from header or cookie + sessionToken := r.Header.Get("Authorization") + if sessionToken == "" { + // Try cookie + cookie, err := r.Cookie("session_token") + if err == nil { + sessionToken = cookie.Value + } + } else { + // Remove "Bearer " prefix if present + sessionToken = strings.TrimPrefix(sessionToken, "Bearer ") + } + + if sessionToken == "" { + return nil, fmt.Errorf("session token required") + } + + // Call resolvespec_session stored procedure + // reference could be route, controller name, or any identifier + reference := "authenticate" + + var success bool + var errorMsg sql.NullString + var userJSON []byte + + query := `SELECT p_success, p_error, p_user FROM resolvespec_session($1, $2)` + err := a.db.QueryRowContext(r.Context(), query, sessionToken, reference).Scan(&success, &errorMsg, &userJSON) + if err != nil { + return nil, fmt.Errorf("session query failed: %w", err) + } + + if !success { + if errorMsg.Valid { + return nil, fmt.Errorf("%s", errorMsg.String) + } + return nil, fmt.Errorf("invalid or expired session") + } + + // Parse UserContext + var userCtx UserContext + if err := json.Unmarshal(userJSON, &userCtx); err != nil { + return nil, fmt.Errorf("failed to parse user context: %w", err) + } + + // Update last activity timestamp asynchronously + go a.updateSessionActivity(r.Context(), sessionToken, &userCtx) + + return &userCtx, nil +} + +// updateSessionActivity updates the last activity timestamp for the session +func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessionToken string, userCtx *UserContext) { + // Convert UserContext to JSON + userJSON, err := json.Marshal(userCtx) + if err != nil { + return + } + + // Call resolvespec_session_update stored procedure + var success bool + var errorMsg sql.NullString + var updatedUserJSON []byte + + query := `SELECT p_success, p_error, p_user FROM resolvespec_session_update($1, $2::jsonb)` + a.db.QueryRowContext(ctx, query, sessionToken, userJSON).Scan(&success, &errorMsg, &updatedUserJSON) +} + +// RefreshToken implements Refreshable interface +func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { + // Call api_refresh_token stored procedure + // First, we need to get the current user context for the refresh token + var success bool + var errorMsg sql.NullString + var userJSON []byte + + // Get current session to pass to refresh + query := `SELECT p_success, p_error, p_user FROM resolvespec_session($1, $2)` + err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) + if err != nil { + return nil, fmt.Errorf("refresh token query failed: %w", err) + } + + if !success { + if errorMsg.Valid { + return nil, fmt.Errorf("%s", errorMsg.String) + } + return nil, fmt.Errorf("invalid refresh token") + } + + // Call resolvespec_refresh_token to generate new token + var newSuccess bool + var newErrorMsg sql.NullString + var newUserJSON []byte + + refreshQuery := `SELECT p_success, p_error, p_user FROM resolvespec_refresh_token($1, $2::jsonb)` + err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) + if err != nil { + return nil, fmt.Errorf("refresh token generation failed: %w", err) + } + + if !newSuccess { + if newErrorMsg.Valid { + return nil, fmt.Errorf("%s", newErrorMsg.String) + } + return nil, fmt.Errorf("failed to refresh token") + } + + // Parse refreshed user context + var userCtx UserContext + if err := json.Unmarshal(newUserJSON, &userCtx); err != nil { + return nil, fmt.Errorf("failed to parse user context: %w", err) + } + + return &LoginResponse{ + Token: userCtx.SessionID, // New session token from stored procedure + User: &userCtx, + ExpiresIn: int64(24 * time.Hour.Seconds()), + }, nil +} + +// JWTAuthenticator provides JWT token-based authentication +// All database operations go through stored procedures +// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout +// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported +type JWTAuthenticator struct { + secretKey []byte + db *sql.DB +} + +func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator { + return &JWTAuthenticator{ + secretKey: []byte(secretKey), + db: db, + } +} + +func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + // Call resolvespec_jwt_login stored procedure + var success bool + var errorMsg sql.NullString + var userJSON []byte + + query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)` + err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) + if err != nil { + return nil, fmt.Errorf("login query failed: %w", err) + } + + if !success { + if errorMsg.Valid { + return nil, fmt.Errorf("%s", errorMsg.String) + } + return nil, fmt.Errorf("invalid credentials") + } + + // Parse user data + var user struct { + ID int `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + Password string `json:"password"` + UserLevel int `json:"user_level"` + Roles string `json:"roles"` + } + + if err := json.Unmarshal(userJSON, &user); err != nil { + return nil, fmt.Errorf("failed to parse user data: %w", err) + } + + // TODO: Verify password + // if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil { + // return nil, fmt.Errorf("invalid credentials") + // } + + // Generate token (placeholder - implement JWT signing when library is available) + expiresAt := time.Now().Add(24 * time.Hour) + tokenString := fmt.Sprintf("token_%d_%d", user.ID, expiresAt.Unix()) + + return &LoginResponse{ + Token: tokenString, + User: &UserContext{ + UserID: user.ID, + UserName: user.Username, + Email: user.Email, + UserLevel: user.UserLevel, + Roles: parseRoles(user.Roles), + }, + ExpiresIn: int64(24 * time.Hour.Seconds()), + }, nil +} + +func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { + // Call resolvespec_jwt_logout stored procedure + var success bool + var errorMsg sql.NullString + + query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)` + err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) + if err != nil { + return fmt.Errorf("logout query failed: %w", err) + } + + if !success { + if errorMsg.Valid { + return fmt.Errorf("%s", errorMsg.String) + } + return fmt.Errorf("logout failed") + } + + return nil +} + +func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("authorization header required") + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == authHeader { + return nil, fmt.Errorf("bearer token required") + } + + // TODO: Implement JWT parsing when library is available + return nil, fmt.Errorf("JWT parsing not implemented - install github.com/golang-jwt/jwt/v5") +} + +// Production-Ready Security Providers +// ==================================== + +// DatabaseColumnSecurityProvider loads column security from database +// All database operations go through stored procedures +// Requires stored procedure: resolvespec_column_security +type DatabaseColumnSecurityProvider struct { + db *sql.DB +} + +func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider { + return &DatabaseColumnSecurityProvider{db: db} +} + +func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { + var rules []ColumnSecurity + + // Call resolvespec_column_security stored procedure + var success bool + var errorMsg sql.NullString + var rulesJSON []byte + + query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)` + err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) + if err != nil { + return nil, fmt.Errorf("failed to load column security: %w", err) + } + + if !success { + if errorMsg.Valid { + return nil, fmt.Errorf("%s", errorMsg.String) + } + return nil, fmt.Errorf("failed to load column security") + } + + // Parse the JSON array of security records + type SecurityRecord struct { + Control string `json:"control"` + Accesstype string `json:"accesstype"` + JSONValue string `json:"jsonvalue"` + } + + var records []SecurityRecord + if err := json.Unmarshal(rulesJSON, &records); err != nil { + return nil, fmt.Errorf("failed to parse security rules: %w", err) + } + + // Convert records to ColumnSecurity rules + for _, rec := range records { + parts := strings.Split(rec.Control, ".") + if len(parts) < 3 { + continue + } + + rule := ColumnSecurity{ + Schema: schema, + Tablename: table, + Path: parts[2:], + Accesstype: rec.Accesstype, + UserID: userID, + } + + rules = append(rules, rule) + } + + return rules, nil +} + +// DatabaseRowSecurityProvider loads row security from database +// All database operations go through stored procedures +// Requires stored procedure: resolvespec_row_security +type DatabaseRowSecurityProvider struct { + db *sql.DB +} + +func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider { + return &DatabaseRowSecurityProvider{db: db} +} + +func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { + var template string + var hasBlock bool + + // Call resolvespec_row_security stored procedure + query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)` + + err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) + if err != nil { + return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err) + } + + return RowSecurity{ + Schema: schema, + Tablename: table, + UserID: userID, + Template: template, + HasBlock: hasBlock, + }, nil +} + +// ConfigColumnSecurityProvider provides static column security configuration +type ConfigColumnSecurityProvider struct { + rules map[string][]ColumnSecurity +} + +func NewConfigColumnSecurityProvider(rules map[string][]ColumnSecurity) *ConfigColumnSecurityProvider { + return &ConfigColumnSecurityProvider{rules: rules} +} + +func (p *ConfigColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { + key := fmt.Sprintf("%s.%s", schema, table) + rules, ok := p.rules[key] + if !ok { + return []ColumnSecurity{}, nil + } + return rules, nil +} + +// ConfigRowSecurityProvider provides static row security configuration +type ConfigRowSecurityProvider struct { + templates map[string]string + blocked map[string]bool +} + +func NewConfigRowSecurityProvider(templates map[string]string, blocked map[string]bool) *ConfigRowSecurityProvider { + return &ConfigRowSecurityProvider{ + templates: templates, + blocked: blocked, + } +} + +func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { + key := fmt.Sprintf("%s.%s", schema, table) + + if p.blocked[key] { + return RowSecurity{ + Schema: schema, + Tablename: table, + UserID: userID, + HasBlock: true, + }, nil + } + + template := p.templates[key] + return RowSecurity{ + Schema: schema, + Tablename: table, + UserID: userID, + Template: template, + HasBlock: false, + }, nil +} + +// Helper functions +// ================ + +func parseRoles(rolesStr string) []string { + if rolesStr == "" { + return []string{} + } + return strings.Split(rolesStr, ",") +} + +func parseIntHeader(r *http.Request, key string, defaultVal int) int { + val := r.Header.Get(key) + if val == "" { + return defaultVal + } + intVal, err := strconv.Atoi(val) + if err != nil { + return defaultVal + } + return intVal +} + +func generateRandomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, length) + for i := range b { + b[i] = charset[time.Now().UnixNano()%int64(len(charset))] + } + return string(b) +} + +func getClaimString(claims map[string]any, key string) string { + if claims == nil { + return "" + } + if val, ok := claims[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return "" +} diff --git a/pkg/security/setup_example.go b/pkg/security/setup_example.go index 7300e9d..49e503a 100644 --- a/pkg/security/setup_example.go +++ b/pkg/security/setup_example.go @@ -1,155 +1,292 @@ package security import ( + "context" + "database/sql" "fmt" "net/http" "github.com/gorilla/mux" - "gorm.io/gorm" "github.com/bitechdev/ResolveSpec/pkg/restheadspec" ) // SetupSecurityProvider initializes and configures the security provider -// This should be called when setting up your HTTP server +// This function creates a SecurityList with the given provider and registers hooks // -// IMPORTANT: You MUST configure the callbacks before calling this function: -// - GlobalSecurity.AuthenticateCallback -// - GlobalSecurity.LoadColumnSecurityCallback -// - GlobalSecurity.LoadRowSecurityCallback +// Example usage: // -// Example usage in your main.go or server setup: +// // Create your security provider (use composite or single provider) +// auth := security.NewJWTAuthenticator("your-secret-key", db) +// colSec := security.NewDatabaseColumnSecurityProvider(db) +// rowSec := security.NewDatabaseRowSecurityProvider(db) +// provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec) // -// // Step 1: Configure callbacks (REQUIRED) -// security.GlobalSecurity.AuthenticateCallback = myAuthFunction -// security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurityFunction -// security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurityFunction -// -// // Step 2: Setup security provider +// // Setup security with the provider // handler := restheadspec.NewHandlerWithGORM(db) -// security.SetupSecurityProvider(handler, &security.GlobalSecurity) +// securityList := security.SetupSecurityProvider(handler, provider) // -// // Step 3: Apply middleware -// router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) -// router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) -func SetupSecurityProvider(handler *restheadspec.Handler, securityList *SecurityList) error { - // Validate that required callbacks are configured - if securityList.AuthenticateCallback == nil { - return fmt.Errorf("AuthenticateCallback must be set before calling SetupSecurityProvider") - } - if securityList.LoadColumnSecurityCallback == nil { - return fmt.Errorf("LoadColumnSecurityCallback must be set before calling SetupSecurityProvider") - } - if securityList.LoadRowSecurityCallback == nil { - return fmt.Errorf("LoadRowSecurityCallback must be set before calling SetupSecurityProvider") +// // Apply middleware +// router.Use(security.NewAuthMiddleware(securityList)) +// router.Use(security.SetSecurityMiddleware(securityList)) +func SetupSecurityProvider(handler *restheadspec.Handler, provider SecurityProvider) *SecurityList { + if provider == nil { + panic("security provider cannot be nil") } - // Initialize security maps if needed - if securityList.ColumnSecurity == nil { - securityList.ColumnSecurity = make(map[string][]ColumnSecurity) - } - if securityList.RowSecurity == nil { - securityList.RowSecurity = make(map[string]RowSecurity) - } + // Create security list with the provider + securityList := NewSecurityList(provider) // Register all security hooks RegisterSecurityHooks(handler, securityList) - return nil + return securityList } -// Chain creates a middleware chain -func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler { - return func(final http.Handler) http.Handler { - for i := len(middlewares) - 1; i >= 0; i-- { - final = middlewares[i](final) - } - return final - } -} +// Example 1: Complete Setup with Composite Provider and Database-Backed Security +// =============================================================================== +// Note: Security providers use *sql.DB, but restheadspec.Handler may use *gorm.DB +// You can get *sql.DB from gorm.DB using: sqlDB, _ := gormDB.DB() -// CompleteExample shows a full integration example with Gorilla Mux -func CompleteExample(db *gorm.DB) (http.Handler, error) { +func ExampleDatabaseSecurity(gormDB interface{}, sqlDB *sql.DB) (http.Handler, error) { // Step 1: Create the ResolveSpec handler - handler := restheadspec.NewHandlerWithGORM(db) + // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) + handler := &restheadspec.Handler{} // Placeholder - use your handler initialization // Step 2: Register your models // handler.RegisterModel("public", "users", User{}) // handler.RegisterModel("public", "orders", Order{}) - // Step 3: Configure security callbacks (REQUIRED!) - // See callbacks_example.go for example implementations - GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader - GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase - GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase + // Step 3: Create security provider components (using sql.DB) + auth := NewJWTAuthenticator("your-secret-key", sqlDB) + colSec := NewDatabaseColumnSecurityProvider(sqlDB) + rowSec := NewDatabaseRowSecurityProvider(sqlDB) - // Step 4: Setup security provider - if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil { - return nil, fmt.Errorf("failed to setup security: %v", err) - } + // Step 4: Combine into composite provider + provider := NewCompositeSecurityProvider(auth, colSec, rowSec) - // Step 5: Create Mux router and setup routes + // Step 5: Setup security + securityList := SetupSecurityProvider(handler, provider) + + // Step 6: Create router and setup routes router := mux.NewRouter() - - // The routes are set up by restheadspec, which handles the conversion - // from http.Request to the internal request format restheadspec.SetupMuxRoutes(router, handler) - // Step 6: Apply middleware to the entire router - secureRouter := Chain( - AuthMiddleware, // Extract user from token - SetSecurityMiddleware, // Add security context - )(router) - - return secureRouter, nil -} - -// ExampleWithMux shows a simpler integration with Mux -func ExampleWithMux(db *gorm.DB) (*mux.Router, error) { - handler := restheadspec.NewHandlerWithGORM(db) - - // IMPORTANT: Configure callbacks BEFORE SetupSecurityProvider - GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader - GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig - GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig - - if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil { - return nil, fmt.Errorf("failed to setup security: %v", err) - } - - router := mux.NewRouter() - - // Setup API routes - restheadspec.SetupMuxRoutes(router, handler) - - // Apply middleware to router - router.Use(mux.MiddlewareFunc(AuthMiddleware)) - router.Use(mux.MiddlewareFunc(SetSecurityMiddleware)) + // Step 7: Apply middleware in correct order + router.Use(NewAuthMiddleware(securityList)) + router.Use(SetSecurityMiddleware(securityList)) return router, nil } -// Example with Gin -// import "github.com/gin-gonic/gin" -// -// func ExampleWithGin(db *gorm.DB) *gin.Engine { -// handler := restheadspec.NewHandlerWithGORM(db) -// SetupSecurityProvider(handler, &GlobalSecurity) -// -// router := gin.Default() -// -// // Convert middleware to Gin middleware -// router.Use(func(c *gin.Context) { -// AuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// c.Request = r -// c.Next() -// })).ServeHTTP(c.Writer, c.Request) -// }) -// -// // Setup routes -// api := router.Group("/api") -// api.Any("/:schema/:entity", gin.WrapH(http.HandlerFunc(handler.Handle))) -// api.Any("/:schema/:entity/:id", gin.WrapH(http.HandlerFunc(handler.Handle))) -// -// return router -// } +// Example 2: Simple Header-Based Authentication +// ============================================== + +func ExampleHeaderAuthentication(gormDB interface{}, sqlDB *sql.DB) (*mux.Router, error) { + // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) + handler := &restheadspec.Handler{} // Placeholder - use your handler initialization + + // Use header-based auth with database security providers + auth := NewHeaderAuthenticatorExample() + colSec := NewDatabaseColumnSecurityProvider(sqlDB) + rowSec := NewDatabaseRowSecurityProvider(sqlDB) + + provider := NewCompositeSecurityProvider(auth, colSec, rowSec) + securityList := SetupSecurityProvider(handler, provider) + + router := mux.NewRouter() + restheadspec.SetupMuxRoutes(router, handler) + + router.Use(NewAuthMiddleware(securityList)) + router.Use(SetSecurityMiddleware(securityList)) + + return router, nil +} + +// Example 3: Config-Based Security (No Database for Security) +// =========================================================== + +func ExampleConfigSecurity(gormDB interface{}) (*mux.Router, error) { + // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) + handler := &restheadspec.Handler{} // Placeholder - use your handler initialization + + // Define column security rules in code + columnRules := map[string][]ColumnSecurity{ + "public.employees": { + { + Schema: "public", + Tablename: "employees", + Path: []string{"ssn"}, + Accesstype: "mask", + MaskStart: 5, + MaskChar: "*", + }, + { + Schema: "public", + Tablename: "employees", + Path: []string{"salary"}, + Accesstype: "hide", + }, + }, + } + + // Define row security templates + rowTemplates := map[string]string{ + "public.orders": "user_id = {UserID}", + "public.documents": "user_id = {UserID} OR is_public = true", + } + + // Define blocked tables + blockedTables := map[string]bool{ + "public.admin_logs": true, + } + + // Create providers + auth := NewHeaderAuthenticatorExample() + colSec := NewConfigColumnSecurityProvider(columnRules) + rowSec := NewConfigRowSecurityProvider(rowTemplates, blockedTables) + + provider := NewCompositeSecurityProvider(auth, colSec, rowSec) + securityList := SetupSecurityProvider(handler, provider) + + router := mux.NewRouter() + restheadspec.SetupMuxRoutes(router, handler) + + router.Use(NewAuthMiddleware(securityList)) + router.Use(SetSecurityMiddleware(securityList)) + + return router, nil +} + +// Example 4: Custom Security Provider +// ==================================== + +// You can implement your own SecurityProvider by implementing all three interfaces +type CustomSecurityProvider struct { + // Your custom fields +} + +func (p *CustomSecurityProvider) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { + // Your custom login logic + return nil, fmt.Errorf("not implemented") +} + +func (p *CustomSecurityProvider) Logout(ctx context.Context, req LogoutRequest) error { + // Your custom logout logic + return nil +} + +func (p *CustomSecurityProvider) Authenticate(r *http.Request) (*UserContext, error) { + // Your custom authentication logic + return nil, fmt.Errorf("not implemented") +} + +func (p *CustomSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { + // Your custom column security logic + return []ColumnSecurity{}, nil +} + +func (p *CustomSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { + // Your custom row security logic + return RowSecurity{ + Schema: schema, + Tablename: table, + UserID: userID, + }, nil +} + +// Example 5: Adding Login/Logout Endpoints +// ========================================= + +func SetupAuthRoutes(router *mux.Router, securityList *SecurityList) { + // Login endpoint + router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) { + // Parse login request + var loginReq LoginRequest + // json.NewDecoder(r.Body).Decode(&loginReq) + + // Call provider's Login method + resp, err := securityList.Provider().Login(r.Context(), loginReq) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + // Return token + w.Header().Set("Content-Type", "application/json") + // json.NewEncoder(w).Encode(resp) + fmt.Fprintf(w, `{"token": "%s", "expires_in": %d}`, resp.Token, resp.ExpiresIn) + }).Methods("POST") + + // Logout endpoint + router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) { + // Extract token from header + token := r.Header.Get("Authorization") + + // Get user ID from context (if authenticated) + userID, _ := GetUserID(r.Context()) + + // Call provider's Logout method + err := securityList.Provider().Logout(r.Context(), LogoutRequest{ + Token: token, + UserID: userID, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"success": true}`) + }).Methods("POST") + + // Optional: Token refresh endpoint + router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) { + refreshToken := r.Header.Get("X-Refresh-Token") + + // Check if provider supports refresh + if refreshable, ok := securityList.Provider().(Refreshable); ok { + resp, err := refreshable.RefreshToken(r.Context(), refreshToken) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"token": "%s", "expires_in": %d}`, resp.Token, resp.ExpiresIn) + } else { + http.Error(w, "Token refresh not supported", http.StatusNotImplemented) + } + }).Methods("POST") +} + +// Example 6: Complete Server Setup +// ================================= + +func CompleteServerExample(gormDB interface{}, sqlDB *sql.DB) http.Handler { + // Create handler and register models + // handler := restheadspec.NewHandlerWithGORM(gormDB.(*gorm.DB)) + handler := &restheadspec.Handler{} // Placeholder - use your handler initialization + // handler.RegisterModel("public", "users", User{}) + + // Setup security (using sql.DB for security providers) + auth := NewJWTAuthenticator("secret-key", sqlDB) + colSec := NewDatabaseColumnSecurityProvider(sqlDB) + rowSec := NewDatabaseRowSecurityProvider(sqlDB) + provider := NewCompositeSecurityProvider(auth, colSec, rowSec) + securityList := SetupSecurityProvider(handler, provider) + + // Create router + router := mux.NewRouter() + + // Add auth routes (login/logout) + SetupAuthRoutes(router, securityList) + + // Add API routes with security middleware + apiRouter := router.PathPrefix("/api").Subrouter() + restheadspec.SetupMuxRoutes(apiRouter, handler) + apiRouter.Use(NewAuthMiddleware(securityList)) + apiRouter.Use(SetSecurityMiddleware(securityList)) + + return router +}