From 03086440757f5ea986657caa8b6144e638279f24 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 21 May 2026 11:27:51 +0200 Subject: [PATCH] feat(auth): add authenticate callback for fallback logic * Implement SetAuthenticateCallback in authenticators * Update Authenticate methods to use callback on failure --- pkg/security/chain.go | 18 ++++++++++++- pkg/security/chain_test.go | 2 ++ pkg/security/composite.go | 5 ++++ pkg/security/composite_test.go | 2 ++ pkg/security/interfaces.go | 4 +++ pkg/security/keystore_authenticator.go | 26 +++++++++++++++++-- pkg/security/provider_test.go | 2 ++ pkg/security/providers.go | 36 +++++++++++++++++++------- pkg/security/totp_integration_test.go | 2 ++ 9 files changed, 85 insertions(+), 12 deletions(-) diff --git a/pkg/security/chain.go b/pkg/security/chain.go index dbab0ac..8a3c158 100644 --- a/pkg/security/chain.go +++ b/pkg/security/chain.go @@ -9,7 +9,8 @@ import ( // ChainAuthenticator tries each authenticator in order, returning the first success. // Login and Logout are delegated to the primary authenticator. type ChainAuthenticator struct { - authenticators []Authenticator + authenticators []Authenticator + authenticateCallback func(r *http.Request) (*UserContext, error) } // NewChainAuthenticator creates a ChainAuthenticator from the given authenticators. @@ -29,13 +30,28 @@ func (c *ChainAuthenticator) Authenticate(r *http.Request) (*UserContext, error) lastErr = err } } + if c.authenticateCallback != nil { + return c.authenticateCallback(r) + } return nil, fmt.Errorf("all authenticators failed; last error: %w", lastErr) } +func (c *ChainAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) { + c.authenticateCallback = fn +} + func (c *ChainAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { return c.authenticators[0].Login(ctx, req) } +func (c *ChainAuthenticator) LoginWithCookie(ctx context.Context, req LoginRequest, w http.ResponseWriter) (*LoginResponse, error) { + return c.authenticators[0].LoginWithCookie(ctx, req, w) +} + func (c *ChainAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { return c.authenticators[0].Logout(ctx, req) } + +func (c *ChainAuthenticator) LogoutWithCookie(ctx context.Context, req LogoutRequest, w http.ResponseWriter) error { + return c.authenticators[0].LogoutWithCookie(ctx, req, w) +} diff --git a/pkg/security/chain_test.go b/pkg/security/chain_test.go index 5240fd7..48a397f 100644 --- a/pkg/security/chain_test.go +++ b/pkg/security/chain_test.go @@ -37,6 +37,8 @@ func (s *stubAuthenticator) LogoutWithCookie(ctx context.Context, req LogoutRequ return s.Logout(ctx, req) } +func (s *stubAuthenticator) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {} + func TestChainAuthenticator_Authenticate(t *testing.T) { successCtx := &UserContext{UserID: 42, UserName: "alice"} failStub := &stubAuthenticator{err: fmt.Errorf("no token")} diff --git a/pkg/security/composite.go b/pkg/security/composite.go index 6808020..088fe0c 100644 --- a/pkg/security/composite.go +++ b/pkg/security/composite.go @@ -63,6 +63,11 @@ func (c *CompositeSecurityProvider) Authenticate(r *http.Request) (*UserContext, return c.auth.Authenticate(r) } +// SetAuthenticateCallback delegates to the authenticator +func (c *CompositeSecurityProvider) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) { + c.auth.SetAuthenticateCallback(fn) +} + // GetColumnSecurity delegates to the column security provider func (c *CompositeSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { return c.colSec.GetColumnSecurity(ctx, userID, schema, table) diff --git a/pkg/security/composite_test.go b/pkg/security/composite_test.go index c9a3ea2..a4d376f 100644 --- a/pkg/security/composite_test.go +++ b/pkg/security/composite_test.go @@ -39,6 +39,8 @@ func (m *mockAuth) Authenticate(r *http.Request) (*UserContext, error) { return m.authUser, m.authErr } +func (m *mockAuth) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {} + // Optional interface implementations func (m *mockAuth) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { if !m.supportsRefresh { diff --git a/pkg/security/interfaces.go b/pkg/security/interfaces.go index 395b0e1..a85987e 100644 --- a/pkg/security/interfaces.go +++ b/pkg/security/interfaces.go @@ -99,6 +99,10 @@ type Authenticator interface { // Authenticate extracts and validates user from HTTP request // Returns UserContext or error if authentication fails Authenticate(r *http.Request) (*UserContext, error) + + // SetAuthenticateCallback registers a fallback called when primary authentication fails. + // If the callback returns a non-nil UserContext, that result is used instead of the error. + SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) } // Registrable allows providers to support user registration diff --git a/pkg/security/keystore_authenticator.go b/pkg/security/keystore_authenticator.go index dd8cae2..567261e 100644 --- a/pkg/security/keystore_authenticator.go +++ b/pkg/security/keystore_authenticator.go @@ -17,8 +17,9 @@ import ( // 2. Authorization: ApiKey // 3. X-API-Key header type KeyStoreAuthenticator struct { - keyStore KeyStore - keyType KeyType // empty = accept any type + keyStore KeyStore + keyType KeyType // empty = accept any type + authenticateCallback func(r *http.Request) (*UserContext, error) } // NewKeyStoreAuthenticator creates a KeyStoreAuthenticator. @@ -32,21 +33,42 @@ func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*Login return nil, fmt.Errorf("keystore authenticator does not support login") } +// LoginWithCookie is not supported for keystore authentication. +func (a *KeyStoreAuthenticator) LoginWithCookie(_ context.Context, _ LoginRequest, _ http.ResponseWriter) (*LoginResponse, error) { + return nil, fmt.Errorf("keystore authenticator does not support login") +} + // Logout is not supported for keystore authentication. func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error { return nil } +// LogoutWithCookie is not supported for keystore authentication. +func (a *KeyStoreAuthenticator) LogoutWithCookie(_ context.Context, _ LogoutRequest, _ http.ResponseWriter) error { + return nil +} + +// SetAuthenticateCallback registers a fallback called when key authentication fails. +func (a *KeyStoreAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) { + a.authenticateCallback = fn +} + // Authenticate extracts an API key from the request and validates it against the KeyStore. // Returns a UserContext built from the matching UserKey on success. func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { rawKey := extractAPIKey(r) if rawKey == "" { + if a.authenticateCallback != nil { + return a.authenticateCallback(r) + } return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey or X-API-Key header)") } userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType) if err != nil { + if a.authenticateCallback != nil { + return a.authenticateCallback(r) + } return nil, fmt.Errorf("invalid API key: %w", err) } diff --git a/pkg/security/provider_test.go b/pkg/security/provider_test.go index 9c36beb..d3d535a 100644 --- a/pkg/security/provider_test.go +++ b/pkg/security/provider_test.go @@ -38,6 +38,8 @@ func (m *mockSecurityProvider) Authenticate(r *http.Request) (*UserContext, erro return m.authUser, m.authError } +func (m *mockSecurityProvider) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {} + func (m *mockSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { return m.columnSecurity, nil } diff --git a/pkg/security/providers.go b/pkg/security/providers.go index b634dbb..385d16d 100644 --- a/pkg/security/providers.go +++ b/pkg/security/providers.go @@ -88,6 +88,9 @@ type DatabaseAuthenticator struct { // Passkey provider (optional) passkeyProvider PasskeyProvider + + // Optional fallback called when primary authentication fails + authenticateCallback func(r *http.Request) (*UserContext, error) } // DatabaseAuthenticatorOptions configures the database authenticator @@ -113,6 +116,10 @@ type DatabaseAuthenticatorOptions struct { // CookieOptions configures the session cookie written by LoginWithCookie. // Only used when EnableCookieSession is true. CookieOptions SessionCookieOptions + // AuthenticateCallback is a fallback called when the primary authentication (database + // session lookup) fails. If non-nil and the callback returns a non-nil UserContext, + // that result is used in place of the failure. + AuthenticateCallback func(r *http.Request) (*UserContext, error) } func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { @@ -134,14 +141,15 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames) return &DatabaseAuthenticator{ - db: db, - dbFactory: opts.DBFactory, - cache: cacheInstance, - cacheTTL: opts.CacheTTL, - sqlNames: sqlNames, - passkeyProvider: opts.PasskeyProvider, - enableCookieSession: opts.EnableCookieSession, - cookieOptions: opts.CookieOptions, + db: db, + dbFactory: opts.DBFactory, + cache: cacheInstance, + cacheTTL: opts.CacheTTL, + sqlNames: sqlNames, + passkeyProvider: opts.PasskeyProvider, + enableCookieSession: opts.EnableCookieSession, + cookieOptions: opts.CookieOptions, + authenticateCallback: opts.AuthenticateCallback, } } @@ -181,6 +189,10 @@ func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) er return err } +func (a *DatabaseAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) { + a.authenticateCallback = fn +} + func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { // Convert LoginRequest to JSON reqJSON, err := json.Marshal(req) @@ -345,6 +357,9 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err } if len(tokens) == 0 { + if a.authenticateCallback != nil { + return a.authenticateCallback(r) + } return nil, fmt.Errorf("session token required") } @@ -407,7 +422,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err return &userCtx, nil } - // All tokens failed + // All tokens failed — try callback before returning error + if a.authenticateCallback != nil { + return a.authenticateCallback(r) + } if lastErr != nil { return nil, lastErr } diff --git a/pkg/security/totp_integration_test.go b/pkg/security/totp_integration_test.go index 189e277..88eb3c0 100644 --- a/pkg/security/totp_integration_test.go +++ b/pkg/security/totp_integration_test.go @@ -59,6 +59,8 @@ func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext return m.users["testuser"], nil } +func (m *MockAuthenticator) SetAuthenticateCallback(_ func(r *http.Request) (*security.UserContext, error)) {} + func TestTwoFactorAuthenticator_Setup(t *testing.T) { baseAuth := NewMockAuthenticator() provider := security.NewMemoryTwoFactorProvider(nil)