Multi Token warning and handling

This commit is contained in:
Hein 2025-12-10 08:44:37 +02:00
parent e1abd5ebc1
commit a95c28a0bf
2 changed files with 168 additions and 48 deletions

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/cache" "github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/logger"
) )
// Production-Ready Authenticators // Production-Ready Authenticators
@ -169,69 +170,98 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
// Extract session token from header or cookie // Extract session token from header or cookie
sessionToken := r.Header.Get("Authorization") sessionToken := r.Header.Get("Authorization")
reference := "authenticate" reference := "authenticate"
var tokens []string
if sessionToken == "" { if sessionToken == "" {
// Try cookie // Try cookie
cookie, err := r.Cookie("session_token") cookie, err := r.Cookie("session_token")
if err == nil { if err == nil {
sessionToken = cookie.Value tokens = []string{cookie.Value}
reference = "cookie" reference = "cookie"
} }
} else { } else {
// Remove "Bearer " prefix if present // Parse Authorization header which may contain multiple comma-separated tokens
sessionToken = strings.TrimPrefix(sessionToken, "Bearer ") // Format: "Token abc, Token def" or "Bearer abc" or just "abc"
// Remove "Token " prefix if present rawTokens := strings.Split(sessionToken, ",")
sessionToken = strings.TrimPrefix(sessionToken, "Token ") for _, token := range rawTokens {
token = strings.TrimSpace(token)
// Remove "Bearer " prefix if present
token = strings.TrimPrefix(token, "Bearer ")
// Remove "Token " prefix if present
token = strings.TrimPrefix(token, "Token ")
token = strings.TrimSpace(token)
if token != "" {
tokens = append(tokens, token)
}
}
} }
if sessionToken == "" { if len(tokens) == 0 {
return nil, fmt.Errorf("session token required") return nil, fmt.Errorf("session token required")
} }
// Build cache key // Log warning if multiple tokens are provided
cacheKey := fmt.Sprintf("auth:session:%s", sessionToken) if len(tokens) > 1 {
logger.Warn("Multiple authentication tokens provided in Authorization header (%d tokens). This is unusual and may indicate a misconfigured client. Header: %s", len(tokens), sessionToken)
// Use cache.GetOrSet to get from cache or load from database
var userCtx UserContext
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (interface{}, error) {
// This function is called only if cache miss
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, sessionToken, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid or expired session")
}
if !userJSON.Valid {
return nil, fmt.Errorf("no user data in session")
}
// Parse UserContext
var user UserContext
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
return &user, nil
})
if err != nil {
return nil, err
} }
// Update last activity timestamp asynchronously // Try each token until one succeeds
go a.updateSessionActivity(r.Context(), sessionToken, &userCtx) var lastErr error
for _, token := range tokens {
// Build cache key
cacheKey := fmt.Sprintf("auth:session:%s", token)
return &userCtx, nil // Use cache.GetOrSet to get from cache or load from database
var userCtx UserContext
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (any, error) {
// This function is called only if cache miss
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid or expired session")
}
if !userJSON.Valid {
return nil, fmt.Errorf("no user data in session")
}
// Parse UserContext
var user UserContext
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
return &user, nil
})
if err != nil {
lastErr = err
continue // Try next token
}
// Authentication succeeded with this token
// Update last activity timestamp asynchronously
go a.updateSessionActivity(r.Context(), token, &userCtx)
return &userCtx, nil
}
// All tokens failed
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("authentication failed for all provided tokens")
} }
// ClearCache removes a specific token from the cache or clears all cache if token is empty // ClearCache removes a specific token from the cache or clears all cache if token is empty

View File

@ -545,6 +545,96 @@ func TestDatabaseAuthenticator(t *testing.T) {
t.Fatal("expected error when token is missing") t.Fatal("expected error when token is missing")
} }
}) })
t.Run("authenticate with multiple comma-separated tokens", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token invalid-token, Token valid-token-123")
// First token fails
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("invalid-token", "authenticate").
WillReturnRows(rows1)
// Second token succeeds
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":3,"user_name":"multitoken","session_id":"valid-token-123"}`)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("valid-token-123", "authenticate").
WillReturnRows(rows2)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if userCtx.UserID != 3 {
t.Errorf("expected UserID 3, got %d", userCtx.UserID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("authenticate with duplicate tokens", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A, Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A")
// First token succeeds
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":4,"user_name":"duplicateuser","session_id":"968CA5AE-4F83-4D55-A3C6-51AE4410E03A"}`)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("968CA5AE-4F83-4D55-A3C6-51AE4410E03A", "authenticate").
WillReturnRows(rows)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if userCtx.UserID != 4 {
t.Errorf("expected UserID 4, got %d", userCtx.UserID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("authenticate with all tokens failing", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token bad-token-1, Token bad-token-2")
// First token fails
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("bad-token-1", "authenticate").
WillReturnRows(rows1)
// Second token also fails
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("bad-token-2", "authenticate").
WillReturnRows(rows2)
_, err := auth.Authenticate(req)
if err == nil {
t.Fatal("expected error when all tokens fail")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
} }
// Test DatabaseAuthenticator RefreshToken // Test DatabaseAuthenticator RefreshToken