package adapter

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/lib/pq"
)

// PostgresConfig holds PostgreSQL connection configuration.
//
// The pool fields are passed unchanged to the matching *sql.DB setters, so
// database/sql semantics apply (for example, MaxIdleConns <= 0 means no idle
// connections are retained).
type PostgresConfig struct {
	Host            string
	Port            int
	Database        string
	User            string
	Password        string
	SSLMode         string // "disable" is used when empty
	MaxOpenConns    int
	MaxIdleConns    int
	ConnMaxLifetime time.Duration
	ConnMaxIdleTime time.Duration
}

// PostgresAdapter implements DBAdapter for PostgreSQL.
//
// mu guards db and listeners, so the adapter may be used from multiple
// goroutines.
type PostgresAdapter struct {
	config PostgresConfig
	db     *sql.DB
	// listeners holds one dedicated pq.Listener (and therefore one extra
	// server connection) per channel passed to Listen.
	listeners map[string]*pq.Listener
	logger    Logger
	mu        sync.RWMutex
}

// NewPostgresAdapter creates a new PostgreSQL adapter. No connection is
// opened until Connect is called.
func NewPostgresAdapter(config PostgresConfig, logger Logger) *PostgresAdapter {
	return &PostgresAdapter{
		config: config,
		logger: logger,
	}
}

// Connect opens a connection pool, applies the pool settings from the config
// and verifies connectivity with a ping bounded by ctx.
//
// If the adapter was already connected, the previous pool is closed after the
// new one has been installed, so repeated calls do not leak pools.
func (p *PostgresAdapter) Connect(ctx context.Context) error {
	db, err := sql.Open("postgres", p.buildConnectionString())
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	// Configure connection pool
	db.SetMaxOpenConns(p.config.MaxOpenConns)
	db.SetMaxIdleConns(p.config.MaxIdleConns)
	db.SetConnMaxLifetime(p.config.ConnMaxLifetime)
	db.SetConnMaxIdleTime(p.config.ConnMaxIdleTime)

	// sql.Open does not dial; the ping is what actually tests the connection.
	if err := db.PingContext(ctx); err != nil {
		_ = db.Close() // the ping error is the one worth reporting
		return fmt.Errorf("failed to ping database: %w", err)
	}

	p.mu.Lock()
	previous := p.db
	p.db = db
	p.mu.Unlock()

	if previous != nil {
		if err := previous.Close(); err != nil {
			p.logger.Error("failed to close previous database pool", "error", err)
		}
	}

	p.logger.Info("PostgreSQL connection established", "host", p.config.Host, "database", p.config.Database)
	return nil
}

// Close closes every notification listener and then the connection pool.
// Listener close failures are logged; a pool close failure is returned.
// It is safe to call Close on an adapter that was never connected.
func (p *PostgresAdapter) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()

	for channel, listener := range p.listeners {
		if err := listener.Close(); err != nil {
			p.logger.Error("failed to close listener", "channel", channel, "error", err)
		}
	}
	p.listeners = nil

	if p.db != nil {
		if err := p.db.Close(); err != nil {
			return fmt.Errorf("failed to close database: %w", err)
		}
		p.db = nil
	}

	p.logger.Info("PostgreSQL connection closed")
	return nil
}

// currentDB returns the pool installed by Connect, or an error if there is
// none (Connect has not succeeded yet, or Close has been called).
func (p *PostgresAdapter) currentDB() (*sql.DB, error) {
	p.mu.RLock()
	db := p.db
	p.mu.RUnlock()

	if db == nil {
		return nil, fmt.Errorf("database connection not established")
	}
	return db, nil
}

// Ping checks if the database is reachable.
func (p *PostgresAdapter) Ping(ctx context.Context) error {
	db, err := p.currentDB()
	if err != nil {
		return err
	}
	return db.PingContext(ctx)
}

// Begin starts a new transaction with the default isolation level.
func (p *PostgresAdapter) Begin(ctx context.Context) (DBTransaction, error) {
	db, err := p.currentDB()
	if err != nil {
		return nil, err
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	return &postgresTransaction{tx: tx}, nil
}

// Exec executes a query without returning rows.
func (p *PostgresAdapter) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	db, err := p.currentDB()
	if err != nil {
		return nil, err
	}
	return db.ExecContext(ctx, query, args...)
}

// QueryRow executes a query that returns at most one row. If there is no
// connection, the error is deferred to the returned row's Scan.
func (p *PostgresAdapter) QueryRow(ctx context.Context, query string, args ...interface{}) DBRow {
	db, err := p.currentDB()
	if err != nil {
		return &postgresRow{err: err}
	}
	return &postgresRow{row: db.QueryRowContext(ctx, query, args...)}
}

// Query executes a query that returns rows. The caller must Close the
// returned rows.
func (p *PostgresAdapter) Query(ctx context.Context, query string, args ...interface{}) (DBRows, error) {
	db, err := p.currentDB()
	if err != nil {
		return nil, err
	}

	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return &postgresRows{rows: rows}, nil
}

// Listen opens a dedicated pq.Listener for channel and invokes handler from a
// background goroutine for every notification received on it.
//
// The listener's connection is closed, and the goroutine exits, when any of
// the following happens:
//   - ctx is cancelled
//   - Unlisten(channel) is called
//   - Close is called
//   - Listen is called again for the same channel, which replaces the
//     previous listener
func (p *PostgresAdapter) Listen(ctx context.Context, channel string, handler NotificationHandler) error {
	reportProblem := func(ev pq.ListenerEventType, err error) {
		if err != nil {
			p.logger.Error("listener problem", "event", ev, "error", err)
		}
	}

	const (
		minReconn = 10 * time.Second
		maxReconn = 1 * time.Minute
	)
	listener := pq.NewListener(p.buildConnectionString(), minReconn, maxReconn, reportProblem)

	if err := listener.Listen(channel); err != nil {
		// Don't leave a listener that will never be used running its
		// reconnect loop in the background.
		_ = listener.Close()
		return fmt.Errorf("failed to listen on channel %s: %w", channel, err)
	}

	p.mu.Lock()
	if p.listeners == nil {
		p.listeners = make(map[string]*pq.Listener)
	}
	previous := p.listeners[channel]
	p.listeners[channel] = listener
	p.mu.Unlock()

	if previous != nil {
		// previous is no longer in the map, so nothing else will close it.
		if err := previous.Close(); err != nil {
			p.logger.Error("failed to close replaced listener", "channel", channel, "error", err)
		}
	}

	p.logger.Info("listening on channel", "channel", channel)

	go p.dispatchNotifications(ctx, channel, listener, handler)
	return nil
}

// dispatchNotifications forwards notifications from listener to handler until
// the listener is closed or ctx is cancelled. After 90 seconds without any
// notification it pings the server so a silently dropped connection is
// noticed and reconnected.
func (p *PostgresAdapter) dispatchNotifications(ctx context.Context, channel string, listener *pq.Listener, handler NotificationHandler) {
	for {
		select {
		case n, ok := <-listener.Notify:
			if !ok {
				// Notify is closed once the listener is closed. Without this
				// check the loop would spin forever on zero-value receives.
				p.logger.Info("listener closed", "channel", channel)
				return
			}
			if n == nil {
				// pq sends nil after re-establishing a lost connection; there
				// is no payload to deliver.
				continue
			}
			handler(&Notification{
				Channel: n.Channel,
				Payload: n.Extra,
				PID:     n.BePid,
			})
		case <-ctx.Done():
			p.logger.Info("stopping listener", "channel", channel)
			p.releaseListener(channel, listener)
			return
		case <-time.After(90 * time.Second):
			go listener.Ping()
		}
	}
}

// releaseListener unregisters and closes listener if it is still the
// registered listener for channel. If it was already replaced or removed,
// whoever did so has already closed it, so it is left alone (avoiding a double
// Close error).
func (p *PostgresAdapter) releaseListener(channel string, listener *pq.Listener) {
	p.mu.Lock()
	owned := p.listeners[channel] == listener
	if owned {
		delete(p.listeners, channel)
	}
	p.mu.Unlock()

	if !owned {
		return
	}
	if err := listener.Close(); err != nil {
		p.logger.Error("failed to close listener", "channel", channel, "error", err)
	}
}

// Unlisten stops listening on channel by closing its dedicated listener, which
// also ends the goroutine started by Listen. It returns nil if the adapter is
// not listening on channel.
func (p *PostgresAdapter) Unlisten(ctx context.Context, channel string) error {
	p.mu.Lock()
	listener, ok := p.listeners[channel]
	delete(p.listeners, channel)
	p.mu.Unlock()

	if !ok {
		return nil
	}
	if err := listener.Close(); err != nil {
		return fmt.Errorf("failed to unlisten on channel %s: %w", channel, err)
	}
	return nil
}

// buildConnectionString builds a lib/pq key/value connection string.
//
// Every string value is single-quoted, with backslashes and quotes escaped.
// Without this, an empty value (for example, no password) or a value
// containing spaces, quotes or backslashes would corrupt the neighbouring
// keys.
func (p *PostgresAdapter) buildConnectionString() string {
	sslMode := p.config.SSLMode
	if sslMode == "" {
		sslMode = "disable"
	}

	quote := func(v string) string {
		// Escape backslashes first so the escapes added for quotes survive.
		v = strings.ReplaceAll(v, `\`, `\\`)
		v = strings.ReplaceAll(v, `'`, `\'`)
		return "'" + v + "'"
	}

	return fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quote(p.config.Host),
		p.config.Port,
		quote(p.config.User),
		quote(p.config.Password),
		quote(p.config.Database),
		quote(sslMode),
	)
}

// postgresTransaction implements DBTransaction on top of *sql.Tx.
type postgresTransaction struct {
	tx *sql.Tx
}

// Exec executes a statement inside the transaction.
func (t *postgresTransaction) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	return t.tx.ExecContext(ctx, query, args...)
}

// QueryRow executes a query returning at most one row inside the transaction.
func (t *postgresTransaction) QueryRow(ctx context.Context, query string, args ...interface{}) DBRow {
	return &postgresRow{row: t.tx.QueryRowContext(ctx, query, args...)}
}

// Query executes a query returning rows inside the transaction. The caller
// must Close the returned rows.
func (t *postgresTransaction) Query(ctx context.Context, query string, args ...interface{}) (DBRows, error) {
	rows, err := t.tx.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return &postgresRows{rows: rows}, nil
}

// Commit commits the transaction.
func (t *postgresTransaction) Commit() error {
	return t.tx.Commit()
}

// Rollback aborts the transaction.
func (t *postgresTransaction) Rollback() error {
	return t.tx.Rollback()
}

// postgresRow implements DBRow. When err is set (no connection was available),
// Scan reports it instead of touching row.
type postgresRow struct {
	row *sql.Row
	err error
}

// Scan copies the row's columns into dest, or returns the deferred error.
func (r *postgresRow) Scan(dest ...interface{}) error {
	if r.err != nil {
		return r.err
	}
	return r.row.Scan(dest...)
}

// postgresRows implements DBRows on top of *sql.Rows.
type postgresRows struct {
	rows *sql.Rows
}

// Next advances to the next row.
func (r *postgresRows) Next() bool {
	return r.rows.Next()
}

// Scan copies the current row's columns into dest.
func (r *postgresRows) Scan(dest ...interface{}) error {
	return r.rows.Scan(dest...)
}

// Close releases the underlying result set.
func (r *postgresRows) Close() error {
	return r.rows.Close()
}

// Err reports any error encountered during iteration.
func (r *postgresRows) Err() error {
	return r.rows.Err()
}