Compare commits

...

3 Commits

Author SHA1 Message Date
Hein
c90c2984ac feat(security): add cookie session support to DatabaseAuthenticator
* Introduce enableCookieSession option for session management
* Implement LoginWithCookie and LogoutWithCookie methods
* Update Authenticate method to support session token from cookie
2026-05-21 09:14:50 +02:00
Hein
1ab4ae33e7 feat(security): implement ChainAuthenticator for sequential authentication 2026-05-21 08:35:39 +02:00
Hein
905457964c fix(restheadspec): remove redundant column selection in query 2026-05-21 08:34:09 +02:00
4 changed files with 211 additions and 11 deletions

View File

@@ -1367,7 +1367,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
// First, read the existing record from the database // First, read the existing record from the database
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil { if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return fmt.Errorf("record not found with ID: %v", targetID) return fmt.Errorf("record not found with ID: %v", targetID)

41
pkg/security/chain.go Normal file
View File

@@ -0,0 +1,41 @@
package security
import (
"context"
"fmt"
"net/http"
)
// ChainAuthenticator tries each authenticator in order, returning the first success.
// Login and Logout are delegated to the primary authenticator.
type ChainAuthenticator struct {
authenticators []Authenticator
}
// NewChainAuthenticator creates a ChainAuthenticator from the given authenticators.
// At least one authenticator is required; the first is treated as primary for Login/Logout.
func NewChainAuthenticator(primary Authenticator, rest ...Authenticator) *ChainAuthenticator {
return &ChainAuthenticator{
authenticators: append([]Authenticator{primary}, rest...),
}
}
func (c *ChainAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
var lastErr error
for _, a := range c.authenticators {
if uc, err := a.Authenticate(r); err == nil {
return uc, nil
} else {
lastErr = err
}
}
return nil, fmt.Errorf("all authenticators failed; last error: %w", lastErr)
}
func (c *ChainAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
return c.authenticators[0].Login(ctx, req)
}
func (c *ChainAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
return c.authenticators[0].Logout(ctx, req)
}

117
pkg/security/chain_test.go Normal file
View File

@@ -0,0 +1,117 @@
package security
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
// stubAuthenticator is a configurable Authenticator for testing.
type stubAuthenticator struct {
userCtx *UserContext
err error
}
func (s *stubAuthenticator) Authenticate(_ *http.Request) (*UserContext, error) {
return s.userCtx, s.err
}
func (s *stubAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
if s.err != nil {
return nil, s.err
}
return &LoginResponse{Token: "tok"}, nil
}
func (s *stubAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
return s.err
}
func TestChainAuthenticator_Authenticate(t *testing.T) {
successCtx := &UserContext{UserID: 42, UserName: "alice"}
failStub := &stubAuthenticator{err: fmt.Errorf("no token")}
okStub := &stubAuthenticator{userCtx: successCtx}
t.Run("primary succeeds", func(t *testing.T) {
chain := NewChainAuthenticator(okStub, failStub)
req := httptest.NewRequest("GET", "/", nil)
uc, err := chain.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if uc.UserID != 42 {
t.Errorf("expected UserID 42, got %d", uc.UserID)
}
})
t.Run("primary fails, secondary succeeds", func(t *testing.T) {
chain := NewChainAuthenticator(failStub, okStub)
req := httptest.NewRequest("GET", "/", nil)
uc, err := chain.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if uc.UserID != 42 {
t.Errorf("expected UserID 42, got %d", uc.UserID)
}
})
t.Run("all fail", func(t *testing.T) {
chain := NewChainAuthenticator(failStub, failStub)
req := httptest.NewRequest("GET", "/", nil)
_, err := chain.Authenticate(req)
if err == nil {
t.Fatal("expected error when all authenticators fail")
}
})
t.Run("three in chain, first two fail", func(t *testing.T) {
chain := NewChainAuthenticator(failStub, failStub, okStub)
req := httptest.NewRequest("GET", "/", nil)
uc, err := chain.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if uc.UserName != "alice" {
t.Errorf("expected UserName alice, got %s", uc.UserName)
}
})
}
func TestChainAuthenticator_LoginLogout(t *testing.T) {
primary := &stubAuthenticator{userCtx: &UserContext{UserID: 1}}
secondary := &stubAuthenticator{userCtx: &UserContext{UserID: 2}}
chain := NewChainAuthenticator(primary, secondary)
ctx := context.Background()
t.Run("login delegates to primary", func(t *testing.T) {
resp, err := chain.Login(ctx, LoginRequest{Username: "u", Password: "p"})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if resp.Token != "tok" {
t.Errorf("expected token from primary, got %s", resp.Token)
}
})
t.Run("logout delegates to primary", func(t *testing.T) {
if err := chain.Logout(ctx, LogoutRequest{}); err != nil {
t.Fatalf("expected no error, got %v", err)
}
})
t.Run("login error from primary is returned", func(t *testing.T) {
failPrimary := &stubAuthenticator{err: fmt.Errorf("db down")}
chain2 := NewChainAuthenticator(failPrimary, secondary)
_, err := chain2.Login(ctx, LoginRequest{})
if err == nil {
t.Fatal("expected error from primary login failure")
}
})
}

View File

@@ -70,6 +70,10 @@ type DatabaseAuthenticator struct {
cacheTTL time.Duration cacheTTL time.Duration
sqlNames *SQLNames sqlNames *SQLNames
// Cookie session support (optional, gated by enableCookieSession)
enableCookieSession bool
cookieOptions SessionCookieOptions
// OAuth2 providers registry (multiple providers supported) // OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider oauth2Providers map[string]*OAuth2Provider
oauth2ProvidersMutex sync.RWMutex oauth2ProvidersMutex sync.RWMutex
@@ -93,6 +97,14 @@ type DatabaseAuthenticatorOptions struct {
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed. // DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled. // If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error) DBFactory func() (*sql.DB, error)
// EnableCookieSession enables cookie-based session management.
// When true, Authenticate reads the session token from the cookie named by
// CookieOptions.Name (default "session_token") in addition to the Authorization header,
// and LoginWithCookie / LogoutWithCookie automatically set / clear the cookie.
EnableCookieSession bool
// CookieOptions configures the session cookie written by LoginWithCookie.
// Only used when EnableCookieSession is true.
CookieOptions SessionCookieOptions
} }
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -114,12 +126,14 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames) sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabaseAuthenticator{ return &DatabaseAuthenticator{
db: db, db: db,
dbFactory: opts.DBFactory, dbFactory: opts.DBFactory,
cache: cacheInstance, cache: cacheInstance,
cacheTTL: opts.CacheTTL, cacheTTL: opts.CacheTTL,
sqlNames: sqlNames, sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider, passkeyProvider: opts.PasskeyProvider,
enableCookieSession: opts.EnableCookieSession,
cookieOptions: opts.CookieOptions,
} }
} }
@@ -265,6 +279,33 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
return nil return nil
} }
// LoginWithCookie performs a login and, when EnableCookieSession is true, writes the
// session cookie to w using the configured CookieOptions. The LoginResponse is returned
// regardless of whether cookie sessions are enabled.
func (a *DatabaseAuthenticator) LoginWithCookie(ctx context.Context, req LoginRequest, w http.ResponseWriter) (*LoginResponse, error) {
resp, err := a.Login(ctx, req)
if err != nil {
return nil, err
}
if a.enableCookieSession {
SetSessionCookie(w, resp, a.cookieOptions)
}
return resp, nil
}
// LogoutWithCookie performs a logout and, when EnableCookieSession is true, clears the
// session cookie on w. The logout itself is performed regardless of the cookie flag.
func (a *DatabaseAuthenticator) LogoutWithCookie(ctx context.Context, req LogoutRequest, w http.ResponseWriter) error {
err := a.Logout(ctx, req)
if err != nil {
return err
}
if a.enableCookieSession {
ClearSessionCookie(w, a.cookieOptions)
}
return nil
}
func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// Extract session token from header or cookie // Extract session token from header or cookie
sessionToken := r.Header.Get("Authorization") sessionToken := r.Header.Get("Authorization")
@@ -272,10 +313,11 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var tokens []string var tokens []string
if sessionToken == "" { if sessionToken == "" {
// Try cookie if a.enableCookieSession {
if token := GetSessionCookie(r); token != "" { if token := GetSessionCookie(r, a.cookieOptions); token != "" {
tokens = []string{token} tokens = []string{token}
reference = "cookie" reference = "cookie"
}
} }
} else { } else {
// Parse Authorization header which may contain multiple comma-separated tokens // Parse Authorization header which may contain multiple comma-separated tokens