mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
Multi Token warning and handling
This commit is contained in:
parent
e1abd5ebc1
commit
a95c28a0bf
@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Production-Ready Authenticators
|
||||
@ -169,69 +170,98 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
// Extract session token from header or cookie
|
||||
sessionToken := r.Header.Get("Authorization")
|
||||
reference := "authenticate"
|
||||
var tokens []string
|
||||
|
||||
if sessionToken == "" {
|
||||
// Try cookie
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err == nil {
|
||||
sessionToken = cookie.Value
|
||||
tokens = []string{cookie.Value}
|
||||
reference = "cookie"
|
||||
}
|
||||
} else {
|
||||
// Remove "Bearer " prefix if present
|
||||
sessionToken = strings.TrimPrefix(sessionToken, "Bearer ")
|
||||
// Remove "Token " prefix if present
|
||||
sessionToken = strings.TrimPrefix(sessionToken, "Token ")
|
||||
// Parse Authorization header which may contain multiple comma-separated tokens
|
||||
// Format: "Token abc, Token def" or "Bearer abc" or just "abc"
|
||||
rawTokens := strings.Split(sessionToken, ",")
|
||||
for _, token := range rawTokens {
|
||||
token = strings.TrimSpace(token)
|
||||
// Remove "Bearer " prefix if present
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
// Remove "Token " prefix if present
|
||||
token = strings.TrimPrefix(token, "Token ")
|
||||
token = strings.TrimSpace(token)
|
||||
if token != "" {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sessionToken == "" {
|
||||
if len(tokens) == 0 {
|
||||
return nil, fmt.Errorf("session token required")
|
||||
}
|
||||
|
||||
// Build cache key
|
||||
cacheKey := fmt.Sprintf("auth:session:%s", sessionToken)
|
||||
|
||||
// Use cache.GetOrSet to get from cache or load from database
|
||||
var userCtx UserContext
|
||||
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (interface{}, error) {
|
||||
// This function is called only if cache miss
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
err := a.db.QueryRowContext(r.Context(), query, sessionToken, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired session")
|
||||
}
|
||||
|
||||
if !userJSON.Valid {
|
||||
return nil, fmt.Errorf("no user data in session")
|
||||
}
|
||||
|
||||
// Parse UserContext
|
||||
var user UserContext
|
||||
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Log warning if multiple tokens are provided
|
||||
if len(tokens) > 1 {
|
||||
logger.Warn("Multiple authentication tokens provided in Authorization header (%d tokens). This is unusual and may indicate a misconfigured client. Header: %s", len(tokens), sessionToken)
|
||||
}
|
||||
|
||||
// Update last activity timestamp asynchronously
|
||||
go a.updateSessionActivity(r.Context(), sessionToken, &userCtx)
|
||||
// Try each token until one succeeds
|
||||
var lastErr error
|
||||
for _, token := range tokens {
|
||||
// Build cache key
|
||||
cacheKey := fmt.Sprintf("auth:session:%s", token)
|
||||
|
||||
return &userCtx, nil
|
||||
// Use cache.GetOrSet to get from cache or load from database
|
||||
var userCtx UserContext
|
||||
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (any, error) {
|
||||
// This function is called only if cache miss
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired session")
|
||||
}
|
||||
|
||||
if !userJSON.Valid {
|
||||
return nil, fmt.Errorf("no user data in session")
|
||||
}
|
||||
|
||||
// Parse UserContext
|
||||
var user UserContext
|
||||
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue // Try next token
|
||||
}
|
||||
|
||||
// Authentication succeeded with this token
|
||||
// Update last activity timestamp asynchronously
|
||||
go a.updateSessionActivity(r.Context(), token, &userCtx)
|
||||
|
||||
return &userCtx, nil
|
||||
}
|
||||
|
||||
// All tokens failed
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, fmt.Errorf("authentication failed for all provided tokens")
|
||||
}
|
||||
|
||||
// ClearCache removes a specific token from the cache or clears all cache if token is empty
|
||||
|
||||
@ -545,6 +545,96 @@ func TestDatabaseAuthenticator(t *testing.T) {
|
||||
t.Fatal("expected error when token is missing")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticate with multiple comma-separated tokens", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Token invalid-token, Token valid-token-123")
|
||||
|
||||
// First token fails
|
||||
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(false, "Invalid token", nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("invalid-token", "authenticate").
|
||||
WillReturnRows(rows1)
|
||||
|
||||
// Second token succeeds
|
||||
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":3,"user_name":"multitoken","session_id":"valid-token-123"}`)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("valid-token-123", "authenticate").
|
||||
WillReturnRows(rows2)
|
||||
|
||||
userCtx, err := auth.Authenticate(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if userCtx.UserID != 3 {
|
||||
t.Errorf("expected UserID 3, got %d", userCtx.UserID)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticate with duplicate tokens", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A, Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A")
|
||||
|
||||
// First token succeeds
|
||||
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":4,"user_name":"duplicateuser","session_id":"968CA5AE-4F83-4D55-A3C6-51AE4410E03A"}`)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("968CA5AE-4F83-4D55-A3C6-51AE4410E03A", "authenticate").
|
||||
WillReturnRows(rows)
|
||||
|
||||
userCtx, err := auth.Authenticate(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if userCtx.UserID != 4 {
|
||||
t.Errorf("expected UserID 4, got %d", userCtx.UserID)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticate with all tokens failing", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Token bad-token-1, Token bad-token-2")
|
||||
|
||||
// First token fails
|
||||
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(false, "Invalid token", nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("bad-token-1", "authenticate").
|
||||
WillReturnRows(rows1)
|
||||
|
||||
// Second token also fails
|
||||
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(false, "Invalid token", nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("bad-token-2", "authenticate").
|
||||
WillReturnRows(rows2)
|
||||
|
||||
_, err := auth.Authenticate(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all tokens fail")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled expectations: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test DatabaseAuthenticator RefreshToken
|
||||
|
||||
Loading…
Reference in New Issue
Block a user