312 lines
6.9 KiB
Go
312 lines
6.9 KiB
Go
package adapter
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// PostgresConfig describes how a PostgresAdapter reaches its PostgreSQL
// server and how the resulting connection pool is sized.
type PostgresConfig struct {
	// Connection target and credentials.
	Host     string
	Port     int
	Database string
	User     string
	Password string

	// SSLMode is sent as the sslmode connection parameter; an empty value
	// means "disable".
	SSLMode string

	// Pool tuning, forwarded by Connect to the matching sql.DB setters.
	MaxOpenConns    int
	MaxIdleConns    int
	ConnMaxLifetime time.Duration
	ConnMaxIdleTime time.Duration
}
|
|
|
|
// PostgresAdapter implements DBAdapter for PostgreSQL
|
|
type PostgresAdapter struct {
|
|
config PostgresConfig
|
|
db *sql.DB
|
|
listener *pq.Listener
|
|
logger Logger
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewPostgresAdapter creates a new PostgreSQL adapter
|
|
func NewPostgresAdapter(config PostgresConfig, logger Logger) *PostgresAdapter {
|
|
return &PostgresAdapter{
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Connect establishes a connection to PostgreSQL
|
|
func (p *PostgresAdapter) Connect(ctx context.Context) error {
|
|
connStr := p.buildConnectionString()
|
|
|
|
db, err := sql.Open("postgres", connStr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
// Configure connection pool
|
|
db.SetMaxOpenConns(p.config.MaxOpenConns)
|
|
db.SetMaxIdleConns(p.config.MaxIdleConns)
|
|
db.SetConnMaxLifetime(p.config.ConnMaxLifetime)
|
|
db.SetConnMaxIdleTime(p.config.ConnMaxIdleTime)
|
|
|
|
// Test connection
|
|
if err := db.PingContext(ctx); err != nil {
|
|
db.Close()
|
|
return fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
p.mu.Lock()
|
|
p.db = db
|
|
p.mu.Unlock()
|
|
|
|
p.logger.Info("PostgreSQL connection established", "host", p.config.Host, "database", p.config.Database)
|
|
return nil
|
|
}
|
|
|
|
// Close closes the database connection
|
|
func (p *PostgresAdapter) Close() error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
if p.listener != nil {
|
|
if err := p.listener.Close(); err != nil {
|
|
p.logger.Error("failed to close listener", "error", err)
|
|
}
|
|
p.listener = nil
|
|
}
|
|
|
|
if p.db != nil {
|
|
if err := p.db.Close(); err != nil {
|
|
return fmt.Errorf("failed to close database: %w", err)
|
|
}
|
|
p.db = nil
|
|
}
|
|
|
|
p.logger.Info("PostgreSQL connection closed")
|
|
return nil
|
|
}
|
|
|
|
// Ping checks if the database is reachable
|
|
func (p *PostgresAdapter) Ping(ctx context.Context) error {
|
|
p.mu.RLock()
|
|
db := p.db
|
|
p.mu.RUnlock()
|
|
|
|
if db == nil {
|
|
return fmt.Errorf("database connection not established")
|
|
}
|
|
|
|
return db.PingContext(ctx)
|
|
}
|
|
|
|
// Begin starts a new transaction
|
|
func (p *PostgresAdapter) Begin(ctx context.Context) (DBTransaction, error) {
|
|
p.mu.RLock()
|
|
db := p.db
|
|
p.mu.RUnlock()
|
|
|
|
if db == nil {
|
|
return nil, fmt.Errorf("database connection not established")
|
|
}
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
return &postgresTransaction{tx: tx}, nil
|
|
}
|
|
|
|
// Exec executes a query without returning rows
|
|
func (p *PostgresAdapter) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
p.mu.RLock()
|
|
db := p.db
|
|
p.mu.RUnlock()
|
|
|
|
if db == nil {
|
|
return nil, fmt.Errorf("database connection not established")
|
|
}
|
|
|
|
return db.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
// QueryRow executes a query that returns at most one row
|
|
func (p *PostgresAdapter) QueryRow(ctx context.Context, query string, args ...interface{}) DBRow {
|
|
p.mu.RLock()
|
|
db := p.db
|
|
p.mu.RUnlock()
|
|
|
|
if db == nil {
|
|
return &postgresRow{err: fmt.Errorf("database connection not established")}
|
|
}
|
|
|
|
return &postgresRow{row: db.QueryRowContext(ctx, query, args...)}
|
|
}
|
|
|
|
// Query executes a query that returns rows
|
|
func (p *PostgresAdapter) Query(ctx context.Context, query string, args ...interface{}) (DBRows, error) {
|
|
p.mu.RLock()
|
|
db := p.db
|
|
p.mu.RUnlock()
|
|
|
|
if db == nil {
|
|
return nil, fmt.Errorf("database connection not established")
|
|
}
|
|
|
|
rows, err := db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &postgresRows{rows: rows}, nil
|
|
}
|
|
|
|
// Listen starts listening on a PostgreSQL notification channel
|
|
func (p *PostgresAdapter) Listen(ctx context.Context, channel string, handler NotificationHandler) error {
|
|
connStr := p.buildConnectionString()
|
|
|
|
reportProblem := func(ev pq.ListenerEventType, err error) {
|
|
if err != nil {
|
|
p.logger.Error("listener problem", "event", ev, "error", err)
|
|
}
|
|
}
|
|
|
|
minReconn := 10 * time.Second
|
|
maxReconn := 1 * time.Minute
|
|
|
|
p.mu.Lock()
|
|
p.listener = pq.NewListener(connStr, minReconn, maxReconn, reportProblem)
|
|
listener := p.listener
|
|
p.mu.Unlock()
|
|
|
|
if err := listener.Listen(channel); err != nil {
|
|
return fmt.Errorf("failed to listen on channel %s: %w", channel, err)
|
|
}
|
|
|
|
p.logger.Info("listening on channel", "channel", channel)
|
|
|
|
// Start notification handler in goroutine
|
|
go func() {
|
|
for {
|
|
select {
|
|
case n := <-listener.Notify:
|
|
if n != nil {
|
|
handler(&Notification{
|
|
Channel: n.Channel,
|
|
Payload: n.Extra,
|
|
PID: n.BePid,
|
|
})
|
|
}
|
|
case <-ctx.Done():
|
|
p.logger.Info("stopping listener", "channel", channel)
|
|
return
|
|
case <-time.After(90 * time.Second):
|
|
go listener.Ping()
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Unlisten stops listening on a channel
|
|
func (p *PostgresAdapter) Unlisten(ctx context.Context, channel string) error {
|
|
p.mu.RLock()
|
|
listener := p.listener
|
|
p.mu.RUnlock()
|
|
|
|
if listener == nil {
|
|
return nil
|
|
}
|
|
|
|
return listener.Unlisten(channel)
|
|
}
|
|
|
|
// buildConnectionString builds a PostgreSQL connection string
|
|
func (p *PostgresAdapter) buildConnectionString() string {
|
|
sslMode := p.config.SSLMode
|
|
if sslMode == "" {
|
|
sslMode = "disable"
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
|
p.config.Host,
|
|
p.config.Port,
|
|
p.config.User,
|
|
p.config.Password,
|
|
p.config.Database,
|
|
sslMode,
|
|
)
|
|
}
|
|
|
|
// postgresTransaction implements DBTransaction
|
|
type postgresTransaction struct {
|
|
tx *sql.Tx
|
|
}
|
|
|
|
func (t *postgresTransaction) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
return t.tx.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
func (t *postgresTransaction) QueryRow(ctx context.Context, query string, args ...interface{}) DBRow {
|
|
return &postgresRow{row: t.tx.QueryRowContext(ctx, query, args...)}
|
|
}
|
|
|
|
func (t *postgresTransaction) Query(ctx context.Context, query string, args ...interface{}) (DBRows, error) {
|
|
rows, err := t.tx.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &postgresRows{rows: rows}, nil
|
|
}
|
|
|
|
func (t *postgresTransaction) Commit() error {
|
|
return t.tx.Commit()
|
|
}
|
|
|
|
func (t *postgresTransaction) Rollback() error {
|
|
return t.tx.Rollback()
|
|
}
|
|
|
|
// postgresRow implements DBRow. It either wraps a *sql.Row or carries an
// error recorded when the row was created (QueryRow does this when no
// connection pool is set); in the latter case Scan reports that error.
type postgresRow struct {
	row *sql.Row
	err error
}

// Scan copies the row's columns into dest, or returns the recorded error.
func (r *postgresRow) Scan(dest ...interface{}) error {
	if r.err == nil {
		return r.row.Scan(dest...)
	}
	return r.err
}
|
|
|
|
// postgresRows implements DBRows; every method delegates directly to the
// wrapped *sql.Rows.
type postgresRows struct {
	rows *sql.Rows
}

// Next prepares the next row for Scan and reports whether one is available.
func (r *postgresRows) Next() bool { return r.rows.Next() }

// Scan copies the current row's columns into dest.
func (r *postgresRows) Scan(dest ...interface{}) error { return r.rows.Scan(dest...) }

// Close releases the underlying result set.
func (r *postgresRows) Close() error { return r.rows.Close() }

// Err reports any error encountered during iteration.
func (r *postgresRows) Err() error { return r.rows.Err() }
|