feat(auth): add authenticate callback for fallback logic

* Implement SetAuthenticateCallback in authenticators
* Update Authenticate methods to use callback on failure
This commit is contained in:
Hein
2026-05-21 11:27:51 +02:00
parent e5984f5205
commit 0308644075
9 changed files with 85 additions and 12 deletions

View File

@@ -9,7 +9,8 @@ import (
// ChainAuthenticator tries each authenticator in order, returning the first success.
// Login and Logout are delegated to the primary authenticator.
type ChainAuthenticator struct {
authenticators []Authenticator
authenticators []Authenticator
authenticateCallback func(r *http.Request) (*UserContext, error)
}
// NewChainAuthenticator creates a ChainAuthenticator from the given authenticators.
@@ -29,13 +30,28 @@ func (c *ChainAuthenticator) Authenticate(r *http.Request) (*UserContext, error)
lastErr = err
}
}
if c.authenticateCallback != nil {
return c.authenticateCallback(r)
}
return nil, fmt.Errorf("all authenticators failed; last error: %w", lastErr)
}
func (c *ChainAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) {
c.authenticateCallback = fn
}
func (c *ChainAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
return c.authenticators[0].Login(ctx, req)
}
func (c *ChainAuthenticator) LoginWithCookie(ctx context.Context, req LoginRequest, w http.ResponseWriter) (*LoginResponse, error) {
return c.authenticators[0].LoginWithCookie(ctx, req, w)
}
func (c *ChainAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
return c.authenticators[0].Logout(ctx, req)
}
func (c *ChainAuthenticator) LogoutWithCookie(ctx context.Context, req LogoutRequest, w http.ResponseWriter) error {
return c.authenticators[0].LogoutWithCookie(ctx, req, w)
}

View File

@@ -37,6 +37,8 @@ func (s *stubAuthenticator) LogoutWithCookie(ctx context.Context, req LogoutRequ
return s.Logout(ctx, req)
}
func (s *stubAuthenticator) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {}
func TestChainAuthenticator_Authenticate(t *testing.T) {
successCtx := &UserContext{UserID: 42, UserName: "alice"}
failStub := &stubAuthenticator{err: fmt.Errorf("no token")}

View File

@@ -63,6 +63,11 @@ func (c *CompositeSecurityProvider) Authenticate(r *http.Request) (*UserContext,
return c.auth.Authenticate(r)
}
// SetAuthenticateCallback delegates to the authenticator
func (c *CompositeSecurityProvider) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) {
c.auth.SetAuthenticateCallback(fn)
}
// GetColumnSecurity delegates to the column security provider
func (c *CompositeSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
return c.colSec.GetColumnSecurity(ctx, userID, schema, table)

View File

@@ -39,6 +39,8 @@ func (m *mockAuth) Authenticate(r *http.Request) (*UserContext, error) {
return m.authUser, m.authErr
}
func (m *mockAuth) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {}
// Optional interface implementations
func (m *mockAuth) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
if !m.supportsRefresh {

View File

@@ -99,6 +99,10 @@ type Authenticator interface {
// Authenticate extracts and validates user from HTTP request
// Returns UserContext or error if authentication fails
Authenticate(r *http.Request) (*UserContext, error)
// SetAuthenticateCallback registers a fallback called when primary authentication fails.
// If the callback returns a non-nil UserContext, that result is used instead of the error.
SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error))
}
// Registrable allows providers to support user registration

View File

@@ -17,8 +17,9 @@ import (
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
type KeyStoreAuthenticator struct {
keyStore KeyStore
keyType KeyType // empty = accept any type
keyStore KeyStore
keyType KeyType // empty = accept any type
authenticateCallback func(r *http.Request) (*UserContext, error)
}
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
@@ -32,21 +33,42 @@ func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*Login
return nil, fmt.Errorf("keystore authenticator does not support login")
}
// LoginWithCookie is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) LoginWithCookie(_ context.Context, _ LoginRequest, _ http.ResponseWriter) (*LoginResponse, error) {
return nil, fmt.Errorf("keystore authenticator does not support login")
}
// Logout is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
return nil
}
// LogoutWithCookie is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) LogoutWithCookie(_ context.Context, _ LogoutRequest, _ http.ResponseWriter) error {
return nil
}
// SetAuthenticateCallback registers a fallback called when key authentication fails.
func (a *KeyStoreAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) {
a.authenticateCallback = fn
}
// Authenticate extracts an API key from the request and validates it against the KeyStore.
// Returns a UserContext built from the matching UserKey on success.
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
rawKey := extractAPIKey(r)
if rawKey == "" {
if a.authenticateCallback != nil {
return a.authenticateCallback(r)
}
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
}
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
if err != nil {
if a.authenticateCallback != nil {
return a.authenticateCallback(r)
}
return nil, fmt.Errorf("invalid API key: %w", err)
}

View File

@@ -38,6 +38,8 @@ func (m *mockSecurityProvider) Authenticate(r *http.Request) (*UserContext, erro
return m.authUser, m.authError
}
func (m *mockSecurityProvider) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {}
func (m *mockSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
return m.columnSecurity, nil
}

View File

@@ -88,6 +88,9 @@ type DatabaseAuthenticator struct {
// Passkey provider (optional)
passkeyProvider PasskeyProvider
// Optional fallback called when primary authentication fails
authenticateCallback func(r *http.Request) (*UserContext, error)
}
// DatabaseAuthenticatorOptions configures the database authenticator
@@ -113,6 +116,10 @@ type DatabaseAuthenticatorOptions struct {
// CookieOptions configures the session cookie written by LoginWithCookie.
// Only used when EnableCookieSession is true.
CookieOptions SessionCookieOptions
// AuthenticateCallback is a fallback called when the primary authentication (database
// session lookup) fails. If non-nil and the callback returns a non-nil UserContext,
// that result is used in place of the failure.
AuthenticateCallback func(r *http.Request) (*UserContext, error)
}
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -134,14 +141,15 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabaseAuthenticator{
db: db,
dbFactory: opts.DBFactory,
cache: cacheInstance,
cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider,
enableCookieSession: opts.EnableCookieSession,
cookieOptions: opts.CookieOptions,
db: db,
dbFactory: opts.DBFactory,
cache: cacheInstance,
cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider,
enableCookieSession: opts.EnableCookieSession,
cookieOptions: opts.CookieOptions,
authenticateCallback: opts.AuthenticateCallback,
}
}
@@ -181,6 +189,10 @@ func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) er
return err
}
func (a *DatabaseAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) {
a.authenticateCallback = fn
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req)
@@ -345,6 +357,9 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
}
if len(tokens) == 0 {
if a.authenticateCallback != nil {
return a.authenticateCallback(r)
}
return nil, fmt.Errorf("session token required")
}
@@ -407,7 +422,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
return &userCtx, nil
}
// All tokens failed
// All tokens failed — try callback before returning error
if a.authenticateCallback != nil {
return a.authenticateCallback(r)
}
if lastErr != nil {
return nil, lastErr
}

View File

@@ -59,6 +59,8 @@ func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext
return m.users["testuser"], nil
}
func (m *MockAuthenticator) SetAuthenticateCallback(_ func(r *http.Request) (*security.UserContext, error)) {}
func TestTwoFactorAuthenticator_Setup(t *testing.T) {
baseAuth := NewMockAuthenticator()
provider := security.NewMemoryTwoFactorProvider(nil)