fix(db): improve database connection handling and reconnection logic

* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
This commit is contained in:
Hein
2026-04-09 09:19:28 +02:00
parent a9bf08f58b
commit 79a3912f93
10 changed files with 449 additions and 91 deletions
+3 -2
View File
@@ -81,7 +81,8 @@ func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyT
defer s.mu.RUnlock()
var result []UserKey
for _, k := range s.keys {
for i := range s.keys {
k := &s.keys[i]
if k.UserID != userID || !k.IsActive {
continue
}
@@ -91,7 +92,7 @@ func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyT
if keyType != "" && k.KeyType != keyType {
continue
}
result = append(result, k)
result = append(result, *k)
}
return result, nil
}
+53 -14
View File
@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
@@ -22,6 +23,9 @@ type DatabaseKeyStoreOptions struct {
CacheTTL time.Duration
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
SQLNames *KeyStoreSQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
@@ -34,10 +38,12 @@ type DatabaseKeyStoreOptions struct {
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
// (default 2 minutes) if the cache entry cannot be invalidated.
type DatabaseKeyStore struct {
db *sql.DB
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
}
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
@@ -55,13 +61,34 @@ func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseK
}
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
return &DatabaseKeyStore{
db: db,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
db: db,
dbFactory: o.DBFactory,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
}
}
func (ks *DatabaseKeyStore) getDB() *sql.DB {
ks.dbMu.RLock()
defer ks.dbMu.RUnlock()
return ks.db
}
func (ks *DatabaseKeyStore) reconnectDB() error {
if ks.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := ks.dbFactory()
if err != nil {
return err
}
ks.dbMu.Lock()
ks.db = newDB
ks.dbMu.Unlock()
return nil
}
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
// and returns the raw key once.
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
@@ -100,7 +127,7 @@ func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest)
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
if err = ks.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("create key procedure failed: %w", err)
}
if !success {
@@ -123,7 +150,7 @@ func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType
var keysJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
if err := ks.db.QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
}
if !success {
@@ -151,7 +178,7 @@ func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int
var keyHash sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
if err := ks.db.QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
return fmt.Errorf("delete key procedure failed: %w", err)
}
if !success {
@@ -184,9 +211,21 @@ func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyT
var errorMsg sql.NullString
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
if err := ks.db.QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
}
if err := runQuery(); err != nil {
if isDBClosed(err) {
if reconnErr := ks.reconnectDB(); reconnErr == nil {
err = runQuery()
}
if err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
} else {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
+5 -5
View File
@@ -244,7 +244,7 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
var errMsg *string
var userID *int
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_user_id
FROM %s($1::jsonb)
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
@@ -287,7 +287,7 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
var success bool
var errMsg *string
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1::jsonb)
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
@@ -385,7 +385,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var errMsg *string
var sessionData []byte
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
@@ -451,7 +451,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var updateSuccess bool
var updateErrMsg *string
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1::jsonb)
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
@@ -472,7 +472,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var userErrMsg *string
var userData []byte
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
+55 -20
View File
@@ -7,18 +7,21 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"sync"
"time"
)
// DatabasePasskeyProvider implements PasskeyProvider using database storage
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabasePasskeyProvider struct {
db *sql.DB
rpID string // Relying Party ID (domain)
rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000)
sqlNames *SQLNames
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
rpID string // Relying Party ID (domain)
rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000)
sqlNames *SQLNames
}
// DatabasePasskeyProviderOptions configures the passkey provider
@@ -33,6 +36,9 @@ type DatabasePasskeyProviderOptions struct {
Timeout int64
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
@@ -44,15 +50,36 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabasePasskeyProvider{
db: db,
rpID: opts.RPID,
rpName: opts.RPName,
rpOrigin: opts.RPOrigin,
timeout: opts.Timeout,
sqlNames: sqlNames,
db: db,
dbFactory: opts.DBFactory,
rpID: opts.RPID,
rpName: opts.RPName,
rpOrigin: opts.RPOrigin,
timeout: opts.Timeout,
sqlNames: sqlNames,
}
}
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabasePasskeyProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// BeginRegistration creates registration options for a new passkey
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
// Generate challenge
@@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
var credentialID sql.NullInt64
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err)
}
@@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
var credentialsJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
@@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var errorMsg sql.NullString
var credentialJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err)
}
@@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var cloneWarning sql.NullBool
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err)
}
@@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
var credentialsJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
}
@@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
var errorMsg sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to delete credential: %w", err)
}
@@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
var errorMsg sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to update credential name: %w", err)
}
+178 -26
View File
@@ -63,10 +63,12 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// Also supports multiple OAuth2 providers configured with WithOAuth2()
// Also supports passkey authentication configured with WithPasskey()
type DatabaseAuthenticator struct {
db *sql.DB
cache *cache.Cache
cacheTTL time.Duration
sqlNames *SQLNames
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
cache *cache.Cache
cacheTTL time.Duration
sqlNames *SQLNames
// OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider
@@ -88,6 +90,9 @@ type DatabaseAuthenticatorOptions struct {
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
// Partial overrides are supported: only set the fields you want to change.
SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -110,6 +115,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
return &DatabaseAuthenticator{
db: db,
dbFactory: opts.DBFactory,
cache: cacheInstance,
cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
@@ -117,6 +123,26 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
}
}
func (a *DatabaseAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *DatabaseAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req)
@@ -128,8 +154,16 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
var errorMsg sql.NullString
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
runLoginQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
@@ -163,7 +197,7 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return nil, fmt.Errorf("register query failed: %w", err)
}
@@ -196,7 +230,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
@@ -270,7 +304,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var userJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
@@ -346,7 +380,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
var updatedUserJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
}
// RefreshToken implements Refreshable interface
@@ -357,7 +391,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var userJSON sql.NullString
// Get current session to pass to refresh
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err)
}
@@ -374,7 +408,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var newUserJSON sql.NullString
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err)
}
@@ -406,6 +440,8 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
type JWTAuthenticator struct {
secretKey []byte
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
}
@@ -417,13 +453,47 @@ func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTA
}
}
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
a.dbFactory = factory
return a
}
func (a *JWTAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *JWTAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
var success bool
var errorMsg sql.NullString
var userJSON []byte
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
runLoginQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
}
err := runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
@@ -476,7 +546,7 @@ func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error
var errorMsg sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
@@ -513,14 +583,41 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// All database operations go through stored procedures
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseColumnSecurityProvider struct {
db *sql.DB
sqlNames *SQLNames
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
}
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity
@@ -528,8 +625,16 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
var errorMsg sql.NullString
var rulesJSON []byte
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err)
}
@@ -578,21 +683,55 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
// All database operations go through stored procedures
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseRowSecurityProvider struct {
db *sql.DB
sqlNames *SQLNames
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames
}
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string
var hasBlock bool
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil {
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
}
@@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
// Helper functions
// ================
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
func parseRoles(rolesStr string) []string {
if rolesStr == "" {
return []string{}
@@ -780,8 +924,16 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
var errorMsg sql.NullString
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
runPasskeyQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runPasskeyQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runPasskeyQuery()
}
}
if err != nil {
return nil, fmt.Errorf("passkey login query failed: %w", err)
}