feat(db): add reconnect logic for database adapters

* Implement reconnect functionality in GormAdapter and other database adapters.
* Introduce a DBFactory to handle reconnections.
* Update health check logic to skip reconnects for transient failures.
* Add tests for reconnect behavior in DatabaseAuthenticator.
This commit is contained in:
Hein
2026-04-10 11:18:39 +02:00
parent 2afee9d238
commit 16a960d973
8 changed files with 728 additions and 57 deletions

View File

@@ -143,6 +143,22 @@ func (a *DatabaseAuthenticator) reconnectDB() error {
return nil
}
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
db := a.getDB()
if db == nil {
return fmt.Errorf("database connection is nil")
}
err := run(db)
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = run(a.getDB())
}
}
return err
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req)
@@ -154,16 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
var errorMsg sql.NullString
var dataJSON sql.NullString
runLoginQuery := func() error {
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
@@ -196,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
var errorMsg sql.NullString
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return nil, fmt.Errorf("register query failed: %w", err)
}
@@ -229,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
var errorMsg sql.NullString
var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
@@ -303,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var errorMsg sql.NullString
var userJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
})
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
@@ -379,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
var errorMsg sql.NullString
var updatedUserJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
})
}
// RefreshToken implements Refreshable interface
@@ -390,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var errorMsg sql.NullString
var userJSON sql.NullString
// Get current session to pass to refresh
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
})
if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err)
}
@@ -407,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var newErrorMsg sql.NullString
var newUserJSON sql.NullString
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
})
if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err)
}