mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-16 21:03:51 +00:00
fix(db): improve database connection handling and reconnection logic
* Added a database factory function to allow reconnection when the database is closed. * Implemented mutex locks for safe concurrent access to the database connection. * Updated all database query methods to handle reconnection attempts on closed connections. * Enhanced error handling for database operations across multiple providers.
This commit is contained in:
@@ -7,18 +7,21 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabasePasskeyProvider struct {
|
||||
db *sql.DB
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||
@@ -33,6 +36,9 @@ type DatabasePasskeyProviderOptions struct {
|
||||
Timeout int64
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
SQLNames *SQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
@@ -44,15 +50,36 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
db: db,
|
||||
dbFactory: opts.DBFactory,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginRegistration creates registration options for a new passkey
|
||||
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
||||
// Generate challenge
|
||||
@@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
||||
var credentialID sql.NullInt64
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
}
|
||||
@@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
}
|
||||
@@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
}
|
||||
@@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
}
|
||||
@@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user