package store

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	pgxvec "github.com/pgvector/pgvector-go/pgx"
	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect/pgdialect"
	"github.com/uptrace/bun/driver/pgdriver"

	"git.warky.dev/wdevs/amcs/internal/config"
)

// DB bundles the two connection pools used by this package: a pgx pool
// (with pgvector types registered on every connection) and a bun handle
// backed by a separate database/sql pool. Both are opened against the same
// URL and are released together by Close.
type DB struct {
	pool *pgxpool.Pool
	bun  *bun.DB
}

// New opens both connection pools described by cfg and verifies, via Ping,
// that the database is reachable through each of them before returning.
//
// The limits in cfg are applied to the pgx pool and, as closely as
// database/sql allows, to the bun pool. The two pools are independent, so
// up to 2*cfg.MaxConns server connections may be open at once.
//
// On any failure, everything opened so far is released and a wrapped error
// is returned.
func New(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) {
	poolConfig, err := pgxpool.ParseConfig(cfg.URL)
	if err != nil {
		return nil, fmt.Errorf("parse database config: %w", err)
	}
	poolConfig.MaxConns = cfg.MaxConns
	poolConfig.MinConns = cfg.MinConns
	poolConfig.MaxConnLifetime = cfg.MaxConnLifetime
	poolConfig.MaxConnIdleTime = cfg.MaxConnIdleTime
	// Register the pgvector codecs on every new pgx connection.
	poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
		return pgxvec.RegisterTypes(ctx, conn)
	}

	pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
	if err != nil {
		return nil, fmt.Errorf("create database pool: %w", err)
	}

	db := &DB{
		pool: pool,
		bun:  bun.NewDB(newBunSQLDB(cfg), pgdialect.New()),
	}

	if err := db.Ping(ctx); err != nil {
		db.Close()
		return nil, err
	}
	return db, nil
}

// newBunSQLDB opens the database/sql pool that backs the bun handle and
// applies the pool limits from cfg to it. sql.OpenDB returns no error;
// connectivity is checked afterwards by Ping.
func newBunSQLDB(cfg config.DatabaseConfig) *sql.DB {
	sqlDB := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(cfg.URL)))
	sqlDB.SetMaxOpenConns(int(cfg.MaxConns))
	// database/sql has no "minimum connections" setting, so MinConns is used
	// as the idle-connection cap. SetMaxIdleConns(n <= 0) retains NO idle
	// connections, which would force a new connection for every query; for a
	// non-positive MinConns the database/sql default is kept instead.
	if cfg.MinConns > 0 {
		sqlDB.SetMaxIdleConns(int(cfg.MinConns))
	}
	sqlDB.SetConnMaxLifetime(cfg.MaxConnLifetime)
	sqlDB.SetConnMaxIdleTime(cfg.MaxConnIdleTime)
	return sqlDB
}

// Close releases both pools. It is safe to call on a nil *DB and on a DB
// whose pools were never set.
func (db *DB) Close() {
	if db == nil {
		return
	}
	if db.bun != nil {
		// Close has no error return, so a failure while closing the
		// database/sql pool is deliberately discarded.
		_ = db.bun.Close()
	}
	if db.pool != nil {
		db.pool.Close()
	}
}

// Ping checks connectivity through the pgx pool and, when it is set, the
// bun handle. It returns the first failure, wrapped to identify which pool
// failed.
func (db *DB) Ping(ctx context.Context) error {
	if err := db.pool.Ping(ctx); err != nil {
		return fmt.Errorf("ping database: %w", err)
	}
	if db.bun != nil {
		if err := db.bun.PingContext(ctx); err != nil {
			return fmt.Errorf("ping bun database: %w", err)
		}
	}
	return nil
}

// Ready runs Ping bounded by a 2-second timeout derived from ctx, so a hung
// database cannot block the caller indefinitely.
func (db *DB) Ready(ctx context.Context) error {
	readyCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()
	return db.Ping(readyCtx)
}

// VerifyRequirements checks, in order, that the vector extension, the
// match_thoughts function, and the public.embeddings and public.stored_files
// tables exist. It stops at the first query error or missing object and
// returns an error naming it.
func (db *DB) VerifyRequirements(ctx context.Context) error {
	checks := []struct {
		what    string // subject named in the query-failure error
		query   string // must return exactly one boolean
		missing string // error text when the query returns false
	}{
		{
			what:    "vector extension",
			query:   `select exists(select 1 from pg_extension where extname = 'vector')`,
			missing: "vector extension is not installed",
		},
		{
			what:    "match_thoughts function",
			query:   `select exists(select 1 from pg_proc where proname = 'match_thoughts')`,
			missing: "match_thoughts function is missing",
		},
		{
			what:    "embeddings table",
			query:   `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`,
			missing: "embeddings table is missing — run migrations",
		},
		{
			what:    "stored_files table",
			query:   `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'stored_files')`,
			missing: "stored_files table is missing — run migrations",
		},
	}

	for _, c := range checks {
		var present bool
		if err := db.pool.QueryRow(ctx, c.query).Scan(&present); err != nil {
			return fmt.Errorf("verify %s: %w", c.what, err)
		}
		if !present {
			return errors.New(c.missing)
		}
	}
	return nil
}

// Bun returns the bun handle, or nil when db is nil.
func (db *DB) Bun() *bun.DB {
	if db == nil {
		return nil
	}
	return db.bun
}