mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-05-31 03:03:44 +00:00
fix(db): improve database connection handling and reconnection logic
* Added a database factory function to allow reconnection when the database is closed. * Implemented mutex locks for safe concurrent access to the database connection. * Updated all database query methods to handle reconnection attempts on closed connections. * Enhanced error handling for database operations across multiple providers.
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
@@ -22,6 +23,9 @@ type DatabaseKeyStoreOptions struct {
|
||||
CacheTTL time.Duration
|
||||
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
|
||||
SQLNames *KeyStoreSQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
|
||||
@@ -34,10 +38,12 @@ type DatabaseKeyStoreOptions struct {
|
||||
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
|
||||
// (default 2 minutes) if the cache entry cannot be invalidated.
|
||||
type DatabaseKeyStore struct {
|
||||
db *sql.DB
|
||||
sqlNames *KeyStoreSQLNames
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *KeyStoreSQLNames
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
|
||||
@@ -55,13 +61,34 @@ func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseK
|
||||
}
|
||||
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
|
||||
return &DatabaseKeyStore{
|
||||
db: db,
|
||||
sqlNames: names,
|
||||
cache: c,
|
||||
cacheTTL: o.CacheTTL,
|
||||
db: db,
|
||||
dbFactory: o.DBFactory,
|
||||
sqlNames: names,
|
||||
cache: c,
|
||||
cacheTTL: o.CacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) getDB() *sql.DB {
|
||||
ks.dbMu.RLock()
|
||||
defer ks.dbMu.RUnlock()
|
||||
return ks.db
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) reconnectDB() error {
|
||||
if ks.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := ks.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ks.dbMu.Lock()
|
||||
ks.db = newDB
|
||||
ks.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
|
||||
// and returns the raw key once.
|
||||
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||
@@ -100,7 +127,7 @@ func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest)
|
||||
var keyJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
|
||||
if err = ks.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
return nil, fmt.Errorf("create key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
@@ -123,7 +150,7 @@ func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType
|
||||
var keysJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
|
||||
if err := ks.db.QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
||||
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
@@ -151,7 +178,7 @@ func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int
|
||||
var keyHash sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
|
||||
if err := ks.db.QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
||||
return fmt.Errorf("delete key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
@@ -184,9 +211,21 @@ func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyT
|
||||
var errorMsg sql.NullString
|
||||
var keyJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
||||
if err := ks.db.QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
||||
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
|
||||
}
|
||||
if err := runQuery(); err != nil {
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := ks.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
|
||||
|
||||
Reference in New Issue
Block a user