package store

import (
	"context"
	"errors"
	"fmt"
	"regexp"
	"strconv"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	pgxvec "github.com/pgvector/pgvector-go/pgx"

	"git.warky.dev/wdevs/amcs/internal/config"
)

// readyTimeout bounds how long Ready waits for the database to answer a ping.
const readyTimeout = 2 * time.Second

// vectorTypePattern captures the dimension count from a pgvector column type
// string such as "vector(1536)". It is intentionally unanchored so that a
// schema-qualified rendering like "extensions.vector(1536)" also matches.
// Compiled once at package scope rather than on every VerifyRequirements call.
var vectorTypePattern = regexp.MustCompile(`vector\((\d+)\)`)

// errNotInitialized is returned (wrapped) by methods invoked on a nil DB or a
// DB whose pool was never created, instead of panicking on a nil pool.
var errNotInitialized = errors.New("database is not initialized")

// DB wraps the pgx connection pool used by the store package.
type DB struct {
	pool *pgxpool.Pool
}

// New builds a connection pool from cfg, registers the pgvector types on each
// new connection, and verifies connectivity with a ping before returning.
// If the initial ping fails the pool is closed so nothing is leaked.
func New(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) {
	poolConfig, err := pgxpool.ParseConfig(cfg.URL)
	if err != nil {
		return nil, fmt.Errorf("parse database config: %w", err)
	}

	poolConfig.MaxConns = cfg.MaxConns
	poolConfig.MinConns = cfg.MinConns
	poolConfig.MaxConnLifetime = cfg.MaxConnLifetime
	poolConfig.MaxConnIdleTime = cfg.MaxConnIdleTime
	// AfterConnect runs for every connection the pool opens, so each one gets
	// the pgvector type registrations.
	poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
		return pgxvec.RegisterTypes(ctx, conn)
	}

	pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
	if err != nil {
		return nil, fmt.Errorf("create database pool: %w", err)
	}

	db := &DB{pool: pool}
	if err := db.Ping(ctx); err != nil {
		pool.Close()
		return nil, err
	}
	return db, nil
}

// Close releases the underlying pool. It is a no-op on a nil DB or a DB
// without a pool, so it is safe to call unconditionally (e.g. via defer).
func (db *DB) Close() {
	if db == nil || db.pool == nil {
		return
	}
	db.pool.Close()
}

// Ping checks that the database is reachable, wrapping any failure with
// "ping database".
func (db *DB) Ping(ctx context.Context) error {
	if db == nil || db.pool == nil {
		return fmt.Errorf("ping database: %w", errNotInitialized)
	}
	if err := db.pool.Ping(ctx); err != nil {
		return fmt.Errorf("ping database: %w", err)
	}
	return nil
}

// Ready is a readiness probe: it pings the database, giving up after
// readyTimeout (or earlier if ctx is cancelled first).
func (db *DB) Ready(ctx context.Context) error {
	readyCtx, cancel := context.WithTimeout(ctx, readyTimeout)
	defer cancel()
	return db.Ping(readyCtx)
}

// VerifyRequirements checks that the connected database provides what this
// package depends on:
//   - the "vector" extension is installed,
//   - a function named match_thoughts exists (in any schema),
//   - public.thoughts.embedding is a vector column with exactly `dimensions`
//     dimensions.
//
// It returns the first failed check as an error, or nil if all pass.
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
	if db == nil || db.pool == nil {
		return fmt.Errorf("verify requirements: %w", errNotInitialized)
	}
	if err := db.verifyVectorExtension(ctx); err != nil {
		return err
	}
	if err := db.verifyMatchThoughts(ctx); err != nil {
		return err
	}
	return db.verifyEmbeddingDimensions(ctx, dimensions)
}

// verifyVectorExtension reports an error unless the "vector" extension is
// listed in pg_extension.
func (db *DB) verifyVectorExtension(ctx context.Context) error {
	var hasVector bool
	if err := db.pool.QueryRow(ctx,
		`select exists(select 1 from pg_extension where extname = 'vector')`,
	).Scan(&hasVector); err != nil {
		return fmt.Errorf("verify vector extension: %w", err)
	}
	if !hasVector {
		return fmt.Errorf("vector extension is not installed")
	}
	return nil
}

// verifyMatchThoughts reports an error unless a function named match_thoughts
// exists in pg_proc. The lookup is by name only (not schema-qualified).
func (db *DB) verifyMatchThoughts(ctx context.Context) error {
	var hasMatchThoughts bool
	if err := db.pool.QueryRow(ctx,
		`select exists(select 1 from pg_proc where proname = 'match_thoughts')`,
	).Scan(&hasMatchThoughts); err != nil {
		return fmt.Errorf("verify match_thoughts function: %w", err)
	}
	if !hasMatchThoughts {
		return fmt.Errorf("match_thoughts function is missing")
	}
	return nil
}

// verifyEmbeddingDimensions reads the declared type of public.thoughts.embedding
// and checks that its vector dimension equals dimensions.
func (db *DB) verifyEmbeddingDimensions(ctx context.Context, dimensions int) error {
	var embeddingType string
	err := db.pool.QueryRow(ctx, `
		select format_type(a.atttypid, a.atttypmod)
		from pg_attribute a
		join pg_class c on c.oid = a.attrelid
		join pg_namespace n on n.oid = c.relnamespace
		where n.nspname = 'public'
		  and c.relname = 'thoughts'
		  and a.attname = 'embedding'
		  and not a.attisdropped
	`).Scan(&embeddingType)
	if errors.Is(err, pgx.ErrNoRows) {
		// No matching attribute row means the table or column is absent; say so
		// explicitly instead of surfacing a bare "no rows in result set".
		return fmt.Errorf("verify thoughts.embedding type: column public.thoughts.embedding is missing: %w", err)
	}
	if err != nil {
		return fmt.Errorf("verify thoughts.embedding type: %w", err)
	}

	actualDimensions, err := parseVectorDimensions(embeddingType)
	if err != nil {
		return err
	}
	if actualDimensions != dimensions {
		return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
	}
	return nil
}

// parseVectorDimensions extracts N from a type string containing "vector(N)".
// It returns an error if the string has no such component or N does not fit
// in an int.
func parseVectorDimensions(embeddingType string) (int, error) {
	matches := vectorTypePattern.FindStringSubmatch(embeddingType)
	if len(matches) != 2 {
		return 0, fmt.Errorf("unexpected embedding type %q", embeddingType)
	}
	n, err := strconv.Atoi(matches[1])
	if err != nil {
		return 0, fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
	}
	return n, nil
}