mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 17:36:23 +00:00
fix(db): improve database connection handling and reconnection logic
* Added a database factory function to allow reconnection when the database is closed. * Implemented mutex locks for safe concurrent access to the database connection. * Updated all database query methods to handle reconnection attempts on closed connections. * Enhanced error handling for database operations across multiple providers.
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
@@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
db *bun.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*bun.DB, error)
|
||||
driverName string
|
||||
}
|
||||
|
||||
@@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
return adapter
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
|
||||
b.dbFactory = factory
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunAdapter) getDB() *bun.DB {
|
||||
b.dbMu.RLock()
|
||||
defer b.dbMu.RUnlock()
|
||||
return b.db
|
||||
}
|
||||
|
||||
func (b *BunAdapter) reconnectDB() error {
|
||||
if b.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := b.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.dbMu.Lock()
|
||||
b.db = newDB
|
||||
b.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (b *BunAdapter) EnableQueryDebug() {
|
||||
b.db.AddQueryHook(&QueryDebugHook{})
|
||||
b.getDB().AddQueryHook(&QueryDebugHook{})
|
||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||
}
|
||||
|
||||
@@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() {
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.db.NewSelect(),
|
||||
query: b.getDB().NewSelect(),
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||
return &BunInsertQuery{query: b.db.NewInsert()}
|
||||
return &BunInsertQuery{query: b.getDB().NewInsert()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &BunUpdateQuery{query: b.db.NewUpdate()}
|
||||
return &BunUpdateQuery{query: b.getDB().NewUpdate()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
||||
return &BunDeleteQuery{query: b.getDB().NewDelete()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.db.ExecContext(ctx, query, args...)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
|
||||
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create adapter with transaction
|
||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
|
||||
return fn(adapter)
|
||||
@@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
}
|
||||
|
||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||
return b.db
|
||||
return b.getDB()
|
||||
}
|
||||
|
||||
func (b *BunAdapter) DriverName() string {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -17,6 +18,8 @@ import (
|
||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||
type PgSQLAdapter struct {
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
driverName string
|
||||
}
|
||||
|
||||
@@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
||||
return &PgSQLAdapter{db: db, driverName: name}
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging for development
|
||||
func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
logger.Info("PgSQL query debug mode - logging enabled via logger")
|
||||
@@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
|
||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
@@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
|
||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
@@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
|
||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
|
||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
db: p.db,
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
@@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
|
||||
}
|
||||
}()
|
||||
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
||||
result, err := p.db.ExecContext(ctx, query, args...)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Exec failed: %v", err)
|
||||
return nil, err
|
||||
@@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
||||
}
|
||||
}()
|
||||
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
||||
rows, err := p.db.QueryContext(ctx, query, args...)
|
||||
var rows *sql.Rows
|
||||
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Query failed: %v", err)
|
||||
return err
|
||||
@@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
}
|
||||
}()
|
||||
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,11 +4,17 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
||||
@@ -14,8 +15,10 @@ import (
|
||||
|
||||
// SQLiteProvider implements Provider for SQLite databases
|
||||
type SQLiteProvider struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
config ConnectionConfig
|
||||
}
|
||||
|
||||
// NewSQLiteProvider creates a new SQLite provider
|
||||
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||
|
||||
// Execute a simple query to verify the database is accessible
|
||||
var result int
|
||||
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
|
||||
run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
|
||||
err := run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNative returns the native *sql.DB connection
|
||||
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
||||
if p.db == nil {
|
||||
|
||||
@@ -74,7 +74,7 @@ func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
h.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithBasePath(basePath),
|
||||
server.WithStaticBasePath(basePath),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -695,7 +695,7 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
|
||||
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
|
||||
switch filter.Operator {
|
||||
case "eq", "=":
|
||||
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||
@@ -725,7 +725,8 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []in
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for _, preload := range preloads {
|
||||
for i := range preloads {
|
||||
preload := &preloads[i]
|
||||
if preload.Relation == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -81,7 +81,8 @@ func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyT
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []UserKey
|
||||
for _, k := range s.keys {
|
||||
for i := range s.keys {
|
||||
k := &s.keys[i]
|
||||
if k.UserID != userID || !k.IsActive {
|
||||
continue
|
||||
}
|
||||
@@ -91,7 +92,7 @@ func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyT
|
||||
if keyType != "" && k.KeyType != keyType {
|
||||
continue
|
||||
}
|
||||
result = append(result, k)
|
||||
result = append(result, *k)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
@@ -22,6 +23,9 @@ type DatabaseKeyStoreOptions struct {
|
||||
CacheTTL time.Duration
|
||||
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
|
||||
SQLNames *KeyStoreSQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
|
||||
@@ -34,10 +38,12 @@ type DatabaseKeyStoreOptions struct {
|
||||
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
|
||||
// (default 2 minutes) if the cache entry cannot be invalidated.
|
||||
type DatabaseKeyStore struct {
|
||||
db *sql.DB
|
||||
sqlNames *KeyStoreSQLNames
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *KeyStoreSQLNames
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
|
||||
@@ -55,13 +61,34 @@ func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseK
|
||||
}
|
||||
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
|
||||
return &DatabaseKeyStore{
|
||||
db: db,
|
||||
sqlNames: names,
|
||||
cache: c,
|
||||
cacheTTL: o.CacheTTL,
|
||||
db: db,
|
||||
dbFactory: o.DBFactory,
|
||||
sqlNames: names,
|
||||
cache: c,
|
||||
cacheTTL: o.CacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) getDB() *sql.DB {
|
||||
ks.dbMu.RLock()
|
||||
defer ks.dbMu.RUnlock()
|
||||
return ks.db
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) reconnectDB() error {
|
||||
if ks.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := ks.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ks.dbMu.Lock()
|
||||
ks.db = newDB
|
||||
ks.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
|
||||
// and returns the raw key once.
|
||||
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||
@@ -100,7 +127,7 @@ func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest)
|
||||
var keyJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
|
||||
if err = ks.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
return nil, fmt.Errorf("create key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
@@ -123,7 +150,7 @@ func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType
|
||||
var keysJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
|
||||
if err := ks.db.QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
||||
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
@@ -151,7 +178,7 @@ func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int
|
||||
var keyHash sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
|
||||
if err := ks.db.QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
||||
return fmt.Errorf("delete key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
@@ -184,9 +211,21 @@ func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyT
|
||||
var errorMsg sql.NullString
|
||||
var keyJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
||||
if err := ks.db.QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
||||
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
|
||||
}
|
||||
if err := runQuery(); err != nil {
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := ks.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
|
||||
|
||||
@@ -244,7 +244,7 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||
@@ -287,7 +287,7 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||
@@ -385,7 +385,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
@@ -451,7 +451,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var updateSuccess bool
|
||||
var updateErrMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
@@ -472,7 +472,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var userErrMsg *string
|
||||
var userData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
|
||||
@@ -7,18 +7,21 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabasePasskeyProvider struct {
|
||||
db *sql.DB
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||
@@ -33,6 +36,9 @@ type DatabasePasskeyProviderOptions struct {
|
||||
Timeout int64
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
SQLNames *SQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
@@ -44,15 +50,36 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
db: db,
|
||||
dbFactory: opts.DBFactory,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginRegistration creates registration options for a new passkey
|
||||
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
||||
// Generate challenge
|
||||
@@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
||||
var credentialID sql.NullInt64
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
}
|
||||
@@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
}
|
||||
@@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
}
|
||||
@@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
}
|
||||
@@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
}
|
||||
|
||||
@@ -63,10 +63,12 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
@@ -88,6 +90,9 @@ type DatabaseAuthenticatorOptions struct {
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
// Partial overrides are supported: only set the fields you want to change.
|
||||
SQLNames *SQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -110,6 +115,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
dbFactory: opts.DBFactory,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
sqlNames: sqlNames,
|
||||
@@ -117,6 +123,26 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
}
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) getDB() *sql.DB {
|
||||
a.dbMu.RLock()
|
||||
defer a.dbMu.RUnlock()
|
||||
return a.db
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) reconnectDB() error {
|
||||
if a.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := a.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.dbMu.Lock()
|
||||
a.db = newDB
|
||||
a.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Convert LoginRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -128,8 +154,16 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
runLoginQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -163,7 +197,7 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
@@ -196,7 +230,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -270,7 +304,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
@@ -346,7 +380,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
@@ -357,7 +391,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
}
|
||||
@@ -374,7 +408,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
}
|
||||
@@ -406,6 +440,8 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
type JWTAuthenticator struct {
|
||||
secretKey []byte
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
@@ -417,13 +453,47 @@ func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTA
|
||||
}
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
|
||||
a.dbFactory = factory
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) getDB() *sql.DB {
|
||||
a.dbMu.RLock()
|
||||
defer a.dbMu.RUnlock()
|
||||
return a.db
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) reconnectDB() error {
|
||||
if a.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := a.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.dbMu.Lock()
|
||||
a.db = newDB
|
||||
a.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON []byte
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
runLoginQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
}
|
||||
err := runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -476,7 +546,7 @@ func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -513,14 +583,41 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
// All database operations go through stored procedures
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseColumnSecurityProvider struct {
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||
var rules []ColumnSecurity
|
||||
|
||||
@@ -528,8 +625,16 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
var errorMsg sql.NullString
|
||||
var rulesJSON []byte
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||
}
|
||||
@@ -578,21 +683,55 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
// All database operations go through stored procedures
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseRowSecurityProvider struct {
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||
var template string
|
||||
var hasBlock bool
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
|
||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
|
||||
}
|
||||
@@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
|
||||
// Helper functions
|
||||
// ================
|
||||
|
||||
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
func parseRoles(rolesStr string) []string {
|
||||
if rolesStr == "" {
|
||||
return []string{}
|
||||
@@ -780,8 +924,16 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
runPasskeyQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runPasskeyQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runPasskeyQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user