fix(db): improve database connection handling and reconnection logic

* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
This commit is contained in:
Hein
2026-04-09 09:19:28 +02:00
parent a9bf08f58b
commit 79a3912f93
10 changed files with 449 additions and 91 deletions
+53 -14
View File
@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
@@ -22,6 +23,9 @@ type DatabaseKeyStoreOptions struct {
CacheTTL time.Duration
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
SQLNames *KeyStoreSQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
@@ -34,10 +38,12 @@ type DatabaseKeyStoreOptions struct {
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
// (default 2 minutes) if the cache entry cannot be invalidated.
type DatabaseKeyStore struct {
db *sql.DB
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
}
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
@@ -55,13 +61,34 @@ func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseK
}
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
return &DatabaseKeyStore{
db: db,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
db: db,
dbFactory: o.DBFactory,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
}
}
func (ks *DatabaseKeyStore) getDB() *sql.DB {
ks.dbMu.RLock()
defer ks.dbMu.RUnlock()
return ks.db
}
func (ks *DatabaseKeyStore) reconnectDB() error {
if ks.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := ks.dbFactory()
if err != nil {
return err
}
ks.dbMu.Lock()
ks.db = newDB
ks.dbMu.Unlock()
return nil
}
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
// and returns the raw key once.
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
@@ -100,7 +127,7 @@ func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest)
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
if err = ks.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("create key procedure failed: %w", err)
}
if !success {
@@ -123,7 +150,7 @@ func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType
var keysJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
if err := ks.db.QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
}
if !success {
@@ -151,7 +178,7 @@ func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int
var keyHash sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
if err := ks.db.QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
return fmt.Errorf("delete key procedure failed: %w", err)
}
if !success {
@@ -184,9 +211,21 @@ func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyT
var errorMsg sql.NullString
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
if err := ks.db.QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
}
if err := runQuery(); err != nil {
if isDBClosed(err) {
if reconnErr := ks.reconnectDB(); reconnErr == nil {
err = runQuery()
}
if err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
} else {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))