105 lines
2.7 KiB
Go
105 lines
2.7 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
pgxvec "github.com/pgvector/pgvector-go/pgx"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/config"
|
|
)
|
|
|
|
type DB struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
func New(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) {
|
|
poolConfig, err := pgxpool.ParseConfig(cfg.URL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse database config: %w", err)
|
|
}
|
|
|
|
poolConfig.MaxConns = cfg.MaxConns
|
|
poolConfig.MinConns = cfg.MinConns
|
|
poolConfig.MaxConnLifetime = cfg.MaxConnLifetime
|
|
poolConfig.MaxConnIdleTime = cfg.MaxConnIdleTime
|
|
poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
|
return pgxvec.RegisterTypes(ctx, conn)
|
|
}
|
|
|
|
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create database pool: %w", err)
|
|
}
|
|
|
|
db := &DB{pool: pool}
|
|
if err := db.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func (db *DB) Close() {
|
|
if db == nil || db.pool == nil {
|
|
return
|
|
}
|
|
|
|
db.pool.Close()
|
|
}
|
|
|
|
func (db *DB) Ping(ctx context.Context) error {
|
|
if err := db.pool.Ping(ctx); err != nil {
|
|
return fmt.Errorf("ping database: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) Ready(ctx context.Context) error {
|
|
readyCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
|
|
return db.Ping(readyCtx)
|
|
}
|
|
|
|
func (db *DB) VerifyRequirements(ctx context.Context) error {
|
|
var hasVector bool
|
|
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
|
|
return fmt.Errorf("verify vector extension: %w", err)
|
|
}
|
|
if !hasVector {
|
|
return fmt.Errorf("vector extension is not installed")
|
|
}
|
|
|
|
var hasMatchThoughts bool
|
|
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_proc where proname = 'match_thoughts')`).Scan(&hasMatchThoughts); err != nil {
|
|
return fmt.Errorf("verify match_thoughts function: %w", err)
|
|
}
|
|
if !hasMatchThoughts {
|
|
return fmt.Errorf("match_thoughts function is missing")
|
|
}
|
|
|
|
var hasEmbeddings bool
|
|
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`).Scan(&hasEmbeddings); err != nil {
|
|
return fmt.Errorf("verify embeddings table: %w", err)
|
|
}
|
|
if !hasEmbeddings {
|
|
return fmt.Errorf("embeddings table is missing — run migrations")
|
|
}
|
|
|
|
var hasStoredFiles bool
|
|
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'stored_files')`).Scan(&hasStoredFiles); err != nil {
|
|
return fmt.Errorf("verify stored_files table: %w", err)
|
|
}
|
|
if !hasStoredFiles {
|
|
return fmt.Errorf("stored_files table is missing — run migrations")
|
|
}
|
|
|
|
return nil
|
|
}
|