From 79a3912f9317115fe1c46a24dfaa2ba5a1298963 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 9 Apr 2026 09:19:28 +0200 Subject: [PATCH] fix(db): improve database connection handling and reconnection logic * Added a database factory function to allow reconnection when the database is closed. * Implemented mutex locks for safe concurrent access to the database connection. * Updated database query methods to use a lock-guarded connection accessor; selected methods (e.g. Exec, Query, HealthCheck, Login, ValidateKey) additionally retry once after reconnecting when the connection is closed. * Enhanced error handling for database operations across multiple providers. --- pkg/common/adapters/database/bun.go | 62 ++++++-- pkg/common/adapters/database/pgsql.go | 63 +++++++- pkg/dbmanager/providers/provider.go | 6 + pkg/dbmanager/providers/sqlite.go | 41 +++++- pkg/resolvemcp/handler.go | 7 +- pkg/security/keystore_config.go | 5 +- pkg/security/keystore_database.go | 67 +++++++-- pkg/security/oauth2_methods.go | 10 +- pkg/security/passkey_provider.go | 75 +++++++--- pkg/security/providers.go | 204 ++++++++++++++++++++++---- 10 files changed, 449 insertions(+), 91 deletions(-) diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index d36337b..458db47 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" "github.com/uptrace/bun" @@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error { // This demonstrates how the abstraction works with different ORMs type BunAdapter struct { db *bun.DB + dbMu sync.RWMutex + dbFactory func() (*bun.DB, error) driverName string } @@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter { return adapter } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. 
+func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter { + b.dbFactory = factory + return b +} + +func (b *BunAdapter) getDB() *bun.DB { + b.dbMu.RLock() + defer b.dbMu.RUnlock() + return b.db +} + +func (b *BunAdapter) reconnectDB() error { + if b.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := b.dbFactory() + if err != nil { + return err + } + b.dbMu.Lock() + b.db = newDB + b.dbMu.Unlock() + return nil +} + // EnableQueryDebug enables query debugging which logs all SQL queries including preloads // This is useful for debugging preload queries that may be failing func (b *BunAdapter) EnableQueryDebug() { - b.db.AddQueryHook(&QueryDebugHook{}) + b.getDB().AddQueryHook(&QueryDebugHook{}) logger.Info("Bun query debug mode enabled - all SQL queries will be logged") } @@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() { func (b *BunAdapter) NewSelect() common.SelectQuery { return &BunSelectQuery{ - query: b.db.NewSelect(), + query: b.getDB().NewSelect(), db: b.db, driverName: b.driverName, } } func (b *BunAdapter) NewInsert() common.InsertQuery { - return &BunInsertQuery{query: b.db.NewInsert()} + return &BunInsertQuery{query: b.getDB().NewInsert()} } func (b *BunAdapter) NewUpdate() common.UpdateQuery { - return &BunUpdateQuery{query: b.db.NewUpdate()} + return &BunUpdateQuery{query: b.getDB().NewUpdate()} } func (b *BunAdapter) NewDelete() common.DeleteQuery { - return &BunDeleteQuery{query: b.db.NewDelete()} + return &BunDeleteQuery{query: b.getDB().NewDelete()} } func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { @@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{} err = logger.HandlePanic("BunAdapter.Exec", r) } }() - result, err := b.db.ExecContext(ctx, query, args...) 
+ var result sql.Result + run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e } + err = run() + if isDBClosed(err) { + if reconnErr := b.reconnectDB(); reconnErr == nil { + err = run() + } + } return &BunResult{result: result}, err } @@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, err = logger.HandlePanic("BunAdapter.Query", r) } }() - return b.db.NewRaw(query, args...).Scan(ctx, dest) + err = b.getDB().NewRaw(query, args...).Scan(ctx, dest) + if isDBClosed(err) { + if reconnErr := b.reconnectDB(); reconnErr == nil { + err = b.getDB().NewRaw(query, args...).Scan(ctx, dest) + } + } + return err } func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) { - tx, err := b.db.BeginTx(ctx, &sql.TxOptions{}) + tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{}) if err != nil { return nil, err } @@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa err = logger.HandlePanic("BunAdapter.RunInTransaction", r) } }() - return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { // Create adapter with transaction adapter := &BunTxAdapter{tx: tx, driverName: b.driverName} return fn(adapter) @@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa } func (b *BunAdapter) GetUnderlyingDB() interface{} { - return b.db + return b.getDB() } func (b *BunAdapter) DriverName() string { diff --git a/pkg/common/adapters/database/pgsql.go b/pkg/common/adapters/database/pgsql.go index 87f3631..336e2af 100644 --- a/pkg/common/adapters/database/pgsql.go +++ b/pkg/common/adapters/database/pgsql.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" @@ -17,6 +18,8 @@ 
import ( // This provides a lightweight PostgreSQL adapter without ORM overhead type PgSQLAdapter struct { db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) driverName string } @@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter { return &PgSQLAdapter{db: db, driverName: name} } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. +func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter { + p.dbFactory = factory + return p +} + +func (p *PgSQLAdapter) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *PgSQLAdapter) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + +func isDBClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "sql: database is closed") +} + // EnableQueryDebug enables query debugging for development func (p *PgSQLAdapter) EnableQueryDebug() { logger.Info("PgSQL query debug mode - logging enabled via logger") @@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() { func (p *PgSQLAdapter) NewSelect() common.SelectQuery { return &PgSQLSelectQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, columns: []string{"*"}, args: make([]interface{}, 0), @@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery { func (p *PgSQLAdapter) NewInsert() common.InsertQuery { return &PgSQLInsertQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, values: make(map[string]interface{}), } @@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery { func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { return &PgSQLUpdateQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, sets: make(map[string]interface{}), args: make([]interface{}, 0), 
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { func (p *PgSQLAdapter) NewDelete() common.DeleteQuery { return &PgSQLDeleteQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, args: make([]interface{}, 0), whereClauses: make([]string, 0), @@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface } }() logger.Debug("PgSQL Exec: %s [args: %v]", query, args) - result, err := p.db.ExecContext(ctx, query, args...) + var result sql.Result + run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e } + err = run() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = run() + } + } if err != nil { logger.Error("PgSQL Exec failed: %v", err) return nil, err @@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string } }() logger.Debug("PgSQL Query: %s [args: %v]", query, args) - rows, err := p.db.QueryContext(ctx, query, args...) 
+ var rows *sql.Rows + run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e } + err = run() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = run() + } + } if err != nil { logger.Error("PgSQL Query failed: %v", err) return err @@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string } func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) { - tx, err := p.db.BeginTx(ctx, nil) + tx, err := p.getDB().BeginTx(ctx, nil) if err != nil { return nil, err } @@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data } }() - tx, err := p.db.BeginTx(ctx, nil) + tx, err := p.getDB().BeginTx(ctx, nil) if err != nil { return err } diff --git a/pkg/dbmanager/providers/provider.go b/pkg/dbmanager/providers/provider.go index 65dbc6f..a541f2b 100644 --- a/pkg/dbmanager/providers/provider.go +++ b/pkg/dbmanager/providers/provider.go @@ -4,11 +4,17 @@ import ( "context" "database/sql" "errors" + "strings" "time" "go.mongodb.org/mongo-driver/mongo" ) +// isDBClosed reports whether err indicates the *sql.DB has been closed. 
+func isDBClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "sql: database is closed") +} + // Common errors var ( // ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database diff --git a/pkg/dbmanager/providers/sqlite.go b/pkg/dbmanager/providers/sqlite.go index 6f70970..4306b8d 100644 --- a/pkg/dbmanager/providers/sqlite.go +++ b/pkg/dbmanager/providers/sqlite.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "sync" "time" _ "github.com/glebarez/sqlite" // Pure Go SQLite driver @@ -14,8 +15,10 @@ import ( // SQLiteProvider implements Provider for SQLite databases type SQLiteProvider struct { - db *sql.DB - config ConnectionConfig + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + config ConnectionConfig } // NewSQLiteProvider creates a new SQLite provider @@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error { // Execute a simple query to verify the database is accessible var result int - err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result) + run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) } + err := run() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = run() + } + } if err != nil { return fmt.Errorf("health check failed: %w", err) } @@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error { return nil } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. 
+func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider { + p.dbFactory = factory + return p +} + +func (p *SQLiteProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *SQLiteProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + // GetNative returns the native *sql.DB connection func (p *SQLiteProvider) GetNative() (*sql.DB, error) { if p.db == nil { diff --git a/pkg/resolvemcp/handler.go b/pkg/resolvemcp/handler.go index cf52414..73b270b 100644 --- a/pkg/resolvemcp/handler.go +++ b/pkg/resolvemcp/handler.go @@ -74,7 +74,7 @@ func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer { return server.NewSSEServer( h.mcpServer, server.WithBaseURL(baseURL), - server.WithBasePath(basePath), + server.WithStaticBasePath(basePath), ) } @@ -695,7 +695,7 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi return query.Where("("+strings.Join(conditions, " OR ")+")", args...) 
} -func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) { +func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) { switch filter.Operator { case "eq", "=": return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value} @@ -725,7 +725,8 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []in } func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { - for _, preload := range preloads { + for i := range preloads { + preload := &preloads[i] if preload.Relation == "" { continue } diff --git a/pkg/security/keystore_config.go b/pkg/security/keystore_config.go index b42ccd4..353cf14 100644 --- a/pkg/security/keystore_config.go +++ b/pkg/security/keystore_config.go @@ -81,7 +81,8 @@ func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyT defer s.mu.RUnlock() var result []UserKey - for _, k := range s.keys { + for i := range s.keys { + k := &s.keys[i] if k.UserID != userID || !k.IsActive { continue } @@ -91,7 +92,7 @@ func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyT if keyType != "" && k.KeyType != keyType { continue } - result = append(result, k) + result = append(result, *k) } return result, nil } diff --git a/pkg/security/keystore_database.go b/pkg/security/keystore_database.go index 2960708..75e7eb1 100644 --- a/pkg/security/keystore_database.go +++ b/pkg/security/keystore_database.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "sync" "time" "github.com/bitechdev/ResolveSpec/pkg/cache" @@ -22,6 +23,9 @@ type DatabaseKeyStoreOptions struct { CacheTTL time.Duration // SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames(). SQLNames *KeyStoreSQLNames + // DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed. 
+ // If nil, reconnection is disabled. + DBFactory func() (*sql.DB, error) } // DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures. @@ -34,10 +38,12 @@ type DatabaseKeyStoreOptions struct { // cache TTL, a deleted key may continue to authenticate for up to CacheTTL // (default 2 minutes) if the cache entry cannot be invalidated. type DatabaseKeyStore struct { - db *sql.DB - sqlNames *KeyStoreSQLNames - cache *cache.Cache - cacheTTL time.Duration + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + sqlNames *KeyStoreSQLNames + cache *cache.Cache + cacheTTL time.Duration } // NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration. @@ -55,13 +61,34 @@ func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseK } names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames) return &DatabaseKeyStore{ - db: db, - sqlNames: names, - cache: c, - cacheTTL: o.CacheTTL, + db: db, + dbFactory: o.DBFactory, + sqlNames: names, + cache: c, + cacheTTL: o.CacheTTL, } } +func (ks *DatabaseKeyStore) getDB() *sql.DB { + ks.dbMu.RLock() + defer ks.dbMu.RUnlock() + return ks.db +} + +func (ks *DatabaseKeyStore) reconnectDB() error { + if ks.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := ks.dbFactory() + if err != nil { + return err + } + ks.dbMu.Lock() + ks.db = newDB + ks.dbMu.Unlock() + return nil +} + // CreateKey generates a raw key, stores its SHA-256 hash via the create procedure, // and returns the raw key once. 
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) { @@ -100,7 +127,7 @@ func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) var keyJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey) - if err = ks.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil { + if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil { return nil, fmt.Errorf("create key procedure failed: %w", err) } if !success { @@ -123,7 +150,7 @@ func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType var keysJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys) - if err := ks.db.QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil { + if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil { return nil, fmt.Errorf("get user keys procedure failed: %w", err) } if !success { @@ -151,7 +178,7 @@ func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int var keyHash sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey) - if err := ks.db.QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil { + if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil { return fmt.Errorf("delete key procedure failed: %w", err) } if !success { @@ -184,9 +211,21 @@ func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyT var errorMsg sql.NullString var keyJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, 
ks.sqlNames.ValidateKey) - if err := ks.db.QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON); err != nil { - return nil, fmt.Errorf("validate key procedure failed: %w", err) + runQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey) + return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON) + } + if err := runQuery(); err != nil { + if isDBClosed(err) { + if reconnErr := ks.reconnectDB(); reconnErr == nil { + err = runQuery() + } + if err != nil { + return nil, fmt.Errorf("validate key procedure failed: %w", err) + } + } else { + return nil, fmt.Errorf("validate key procedure failed: %w", err) + } } if !success { return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key")) diff --git a/pkg/security/oauth2_methods.go b/pkg/security/oauth2_methods.go index 0b87b3b..79a58de 100644 --- a/pkg/security/oauth2_methods.go +++ b/pkg/security/oauth2_methods.go @@ -244,7 +244,7 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC var errMsg *string var userID *int - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error, p_user_id FROM %s($1::jsonb) `, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID) @@ -287,7 +287,7 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session var success bool var errMsg *string - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error FROM %s($1::jsonb) `, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg) @@ -385,7 +385,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var errMsg *string var sessionData []byte - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, 
fmt.Sprintf(` SELECT p_success, p_error, p_data::text FROM %s($1) `, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData) @@ -451,7 +451,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var updateSuccess bool var updateErrMsg *string - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error FROM %s($1::jsonb) `, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg) @@ -472,7 +472,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var userErrMsg *string var userData []byte - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error, p_data::text FROM %s($1) `, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData) diff --git a/pkg/security/passkey_provider.go b/pkg/security/passkey_provider.go index 7abeab0..fb153c1 100644 --- a/pkg/security/passkey_provider.go +++ b/pkg/security/passkey_provider.go @@ -7,18 +7,21 @@ import ( "encoding/base64" "encoding/json" "fmt" + "sync" "time" ) // DatabasePasskeyProvider implements PasskeyProvider using database storage // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabasePasskeyProvider struct { - db *sql.DB - rpID string // Relying Party ID (domain) - rpName string // Relying Party display name - rpOrigin string // Expected origin for WebAuthn - timeout int64 // Timeout in milliseconds (default: 60000) - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + rpID string // Relying Party ID (domain) + rpName string // Relying Party display name + rpOrigin string // Expected origin for WebAuthn + timeout int64 // Timeout in milliseconds (default: 60000) + sqlNames *SQLNames } // DatabasePasskeyProviderOptions configures the passkey provider @@ -33,6 
+36,9 @@ type DatabasePasskeyProviderOptions struct { Timeout int64 // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). SQLNames *SQLNames + // DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed. + // If nil, reconnection is disabled. + DBFactory func() (*sql.DB, error) } // NewDatabasePasskeyProvider creates a new database-backed passkey provider @@ -44,15 +50,36 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions) sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames) return &DatabasePasskeyProvider{ - db: db, - rpID: opts.RPID, - rpName: opts.RPName, - rpOrigin: opts.RPOrigin, - timeout: opts.Timeout, - sqlNames: sqlNames, + db: db, + dbFactory: opts.DBFactory, + rpID: opts.RPID, + rpName: opts.RPName, + rpOrigin: opts.RPOrigin, + timeout: opts.Timeout, + sqlNames: sqlNames, } } +func (p *DatabasePasskeyProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *DatabasePasskeyProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + // BeginRegistration creates registration options for a new passkey func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) { // Generate challenge @@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user var credentialID sql.NullInt64 query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential) - err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) + err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) 
if err != nil { return nil, fmt.Errorf("failed to store credential: %w", err) } @@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern var credentialsJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername) - err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) + err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) } @@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re var errorMsg sql.NullString var credentialJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential) - err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) + runQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential) + return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) + } + err := runQuery() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = runQuery() + } + } if err != nil { return 0, fmt.Errorf("failed to get credential: %w", err) } @@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re var cloneWarning sql.NullBool updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter) - err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) + err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) if err != 
nil { return 0, fmt.Errorf("failed to update counter: %w", err) } @@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int var credentialsJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials) - err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) + err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) } @@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i var errorMsg sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential) - err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) + err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("failed to delete credential: %w", err) } @@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user var errorMsg sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName) - err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) + err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("failed to update credential name: %w", err) } diff --git a/pkg/security/providers.go b/pkg/security/providers.go index e23bd82..af57172 100644 --- a/pkg/security/providers.go +++ b/pkg/security/providers.go @@ -63,10 +63,12 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error // Also supports multiple OAuth2 providers configured with WithOAuth2() // Also supports passkey authentication configured with WithPasskey() type DatabaseAuthenticator 
struct { - db *sql.DB - cache *cache.Cache - cacheTTL time.Duration - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + cache *cache.Cache + cacheTTL time.Duration + sqlNames *SQLNames // OAuth2 providers registry (multiple providers supported) oauth2Providers map[string]*OAuth2Provider @@ -88,6 +90,9 @@ type DatabaseAuthenticatorOptions struct { // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). // Partial overrides are supported: only set the fields you want to change. SQLNames *SQLNames + // DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed. + // If nil, reconnection is disabled. + DBFactory func() (*sql.DB, error) } func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { @@ -110,6 +115,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO return &DatabaseAuthenticator{ db: db, + dbFactory: opts.DBFactory, cache: cacheInstance, cacheTTL: opts.CacheTTL, sqlNames: sqlNames, @@ -117,6 +123,26 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO } } +func (a *DatabaseAuthenticator) getDB() *sql.DB { + a.dbMu.RLock() + defer a.dbMu.RUnlock() + return a.db +} + +func (a *DatabaseAuthenticator) reconnectDB() error { + if a.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := a.dbFactory() + if err != nil { + return err + } + a.dbMu.Lock() + a.db = newDB + a.dbMu.Unlock() + return nil +} + func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { // Convert LoginRequest to JSON reqJSON, err := json.Marshal(req) @@ -128,8 +154,16 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L var errorMsg sql.NullString var dataJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login) - err = 
a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + runLoginQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login) + return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + } + err = runLoginQuery() + if isDBClosed(err) { + if reconnErr := a.reconnectDB(); reconnErr == nil { + err = runLoginQuery() + } + } if err != nil { return nil, fmt.Errorf("login query failed: %w", err) } @@ -163,7 +197,7 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques var dataJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register) - err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) if err != nil { return nil, fmt.Errorf("register query failed: %w", err) } @@ -196,7 +230,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e var dataJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout) - err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) if err != nil { return fmt.Errorf("logout query failed: %w", err) } @@ -270,7 +304,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err var userJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) - err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) + err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) if err != nil { return 
nil, fmt.Errorf("session query failed: %w", err) } @@ -346,7 +380,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi var updatedUserJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate) - _ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) + _ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) } // RefreshToken implements Refreshable interface @@ -357,7 +391,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s var userJSON sql.NullString // Get current session to pass to refresh query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) - err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) + err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) if err != nil { return nil, fmt.Errorf("refresh token query failed: %w", err) } @@ -374,7 +408,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s var newUserJSON sql.NullString refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken) - err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) + err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) if err != nil { return nil, fmt.Errorf("refresh token generation failed: %w", err) } @@ -406,6 +440,8 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s type JWTAuthenticator struct { secretKey []byte db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) sqlNames *SQLNames } @@ -417,13 +453,47 @@ func 
NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTA } } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. +func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator { + a.dbFactory = factory + return a +} + +func (a *JWTAuthenticator) getDB() *sql.DB { + a.dbMu.RLock() + defer a.dbMu.RUnlock() + return a.db +} + +func (a *JWTAuthenticator) reconnectDB() error { + if a.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := a.dbFactory() + if err != nil { + return err + } + a.dbMu.Lock() + a.db = newDB + a.dbMu.Unlock() + return nil +} + func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { var success bool var errorMsg sql.NullString var userJSON []byte - query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin) - err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) + runLoginQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin) + return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) + } + err := runLoginQuery() + if isDBClosed(err) { + if reconnErr := a.reconnectDB(); reconnErr == nil { + err = runLoginQuery() + } + } if err != nil { return nil, fmt.Errorf("login query failed: %w", err) } @@ -476,7 +546,7 @@ func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error var errorMsg sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout) - err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) + err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("logout query failed: %w", err) } @@ -513,14 
+583,41 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { // All database operations go through stored procedures // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabaseColumnSecurityProvider struct { - db *sql.DB - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + sqlNames *SQLNames } func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider { return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} } +func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider { + p.dbFactory = factory + return p +} + +func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *DatabaseColumnSecurityProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { var rules []ColumnSecurity @@ -528,8 +625,16 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, var errorMsg sql.NullString var rulesJSON []byte - query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity) - err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) + runQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity) + return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) + } + err := runQuery() + if isDBClosed(err) { + if 
reconnErr := p.reconnectDB(); reconnErr == nil { + err = runQuery() + } + } if err != nil { return nil, fmt.Errorf("failed to load column security: %w", err) } @@ -578,21 +683,55 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, // All database operations go through stored procedures // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabaseRowSecurityProvider struct { - db *sql.DB - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + sqlNames *SQLNames } func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider { return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} } +func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider { + p.dbFactory = factory + return p +} + +func (p *DatabaseRowSecurityProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *DatabaseRowSecurityProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { var template string var hasBlock bool - query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity) - - err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) + runQuery := func() error { + query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity) + return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) + } + err := runQuery() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + 
err = runQuery() + } + } if err != nil { return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err) } @@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i // Helper functions // ================ +// isDBClosed reports whether err indicates the *sql.DB has been closed. +func isDBClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "sql: database is closed") +} + func parseRoles(rolesStr string) []string { if rolesStr == "" { return []string{} @@ -780,8 +924,16 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke var errorMsg sql.NullString var dataJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin) - err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + runPasskeyQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin) + return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + } + err = runPasskeyQuery() + if isDBClosed(err) { + if reconnErr := a.reconnectDB(); reconnErr == nil { + err = runPasskeyQuery() + } + } if err != nil { return nil, fmt.Errorf("passkey login query failed: %w", err) }