mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-12 10:53:52 +00:00
feat(db): add reconnect logic for database adapters
* Implement reconnect functionality in GormAdapter and other database adapters. * Introduce a DBFactory to handle reconnections. * Update health check logic to skip reconnects for transient failures. * Add tests for reconnect behavior in DatabaseAuthenticator.
This commit is contained in:
@@ -143,6 +143,22 @@ func (a *DatabaseAuthenticator) reconnectDB() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
|
||||
db := a.getDB()
|
||||
if db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
err := run(db)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = run(a.getDB())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Convert LoginRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -154,16 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
runLoginQuery := func() error {
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -196,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
@@ -229,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -303,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
@@ -379,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
@@ -390,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
}
|
||||
@@ -407,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package security
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) {
|
||||
newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) {
|
||||
t.Helper()
|
||||
|
||||
primaryDB, primaryMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create primary mock db: %v", err)
|
||||
}
|
||||
|
||||
reconnectDB, reconnectMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
primaryDB.Close()
|
||||
t.Fatalf("failed to create reconnect mock db: %v", err)
|
||||
}
|
||||
|
||||
cacheProvider := cache.NewMemoryProvider(&cache.Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
})
|
||||
|
||||
auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{
|
||||
Cache: cache.NewCache(cacheProvider),
|
||||
DBFactory: func() (*sql.DB, error) {
|
||||
return reconnectDB, nil
|
||||
},
|
||||
})
|
||||
|
||||
cleanup := func() {
|
||||
_ = primaryDB.Close()
|
||||
_ = reconnectDB.Close()
|
||||
}
|
||||
|
||||
return auth, primaryMock, reconnectMock, cleanup
|
||||
}
|
||||
|
||||
t.Run("Authenticate reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer reconnect-auth-token")
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
userCtx, err := auth.Authenticate(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected authenticate to reconnect, got %v", err)
|
||||
}
|
||||
if userCtx.UserID != 7 {
|
||||
t.Fatalf("expected user ID 7, got %d", userCtx.UserID)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Register reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := RegisterRequest{
|
||||
Username: "reconnect-register",
|
||||
Password: "password123",
|
||||
Email: "reconnect@example.com",
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
resp, err := auth.Register(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected register to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "reconnected-register-token" {
|
||||
t.Fatalf("expected refreshed token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Logout reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, nil)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
if err := auth.Logout(context.Background(), req); err != nil {
|
||||
t.Fatalf("expected logout to reconnect, got %v", err)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RefreshToken reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
refreshToken := "refresh-reconnect-token"
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnRows(sessionRows)
|
||||
|
||||
refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`).
|
||||
WithArgs(refreshToken, sqlmock.AnyArg()).
|
||||
WillReturnRows(refreshRows)
|
||||
|
||||
resp, err := auth.RefreshToken(context.Background(), refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("expected refresh token to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "refreshed-token" {
|
||||
t.Fatalf("expected refreshed-token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
auth.updateSessionActivity(context.Background(), "activity-token", userCtx)
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test JWTAuthenticator
|
||||
func TestJWTAuthenticator(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
|
||||
Reference in New Issue
Block a user