feat(embeddings): add embedding model support and related changes
* Introduced EmbeddingModel method in Client and Provider interfaces * Updated InsertThought and SearchThoughts methods to handle embedding models * Created embeddings table and updated match_thoughts function for model filtering * Removed embedding column from thoughts table * Adjusted permissions for new embeddings table
This commit is contained in:
@@ -3,7 +3,6 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
@@ -68,7 +67,7 @@ func (db *DB) Ready(ctx context.Context) error {
|
||||
return db.Ping(readyCtx)
|
||||
}
|
||||
|
||||
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
||||
func (db *DB) VerifyRequirements(ctx context.Context) error {
|
||||
var hasVector bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
|
||||
return fmt.Errorf("verify vector extension: %w", err)
|
||||
@@ -85,33 +84,12 @@ func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
||||
return fmt.Errorf("match_thoughts function is missing")
|
||||
}
|
||||
|
||||
var embeddingType string
|
||||
err := db.pool.QueryRow(ctx, `
|
||||
select format_type(a.atttypid, a.atttypmod)
|
||||
from pg_attribute a
|
||||
join pg_class c on c.oid = a.attrelid
|
||||
join pg_namespace n on n.oid = c.relnamespace
|
||||
where n.nspname = 'public'
|
||||
and c.relname = 'thoughts'
|
||||
and a.attname = 'embedding'
|
||||
and not a.attisdropped
|
||||
`).Scan(&embeddingType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify thoughts.embedding type: %w", err)
|
||||
var hasEmbeddings bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`).Scan(&hasEmbeddings); err != nil {
|
||||
return fmt.Errorf("verify embeddings table: %w", err)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`vector\((\d+)\)`)
|
||||
matches := re.FindStringSubmatch(embeddingType)
|
||||
if len(matches) != 2 {
|
||||
return fmt.Errorf("unexpected embedding type %q", embeddingType)
|
||||
}
|
||||
|
||||
var actualDimensions int
|
||||
if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil {
|
||||
return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
|
||||
}
|
||||
if actualDimensions != dimensions {
|
||||
return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
|
||||
if !hasEmbeddings {
|
||||
return fmt.Errorf("embeddings table is missing — run migrations")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -15,27 +15,51 @@ import (
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) {
|
||||
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) {
|
||||
metadata, err := json.Marshal(thought.Metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
insert into thoughts (content, embedding, metadata, project_id)
|
||||
values ($1, $2, $3::jsonb, $4)
|
||||
tx, err := db.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
row := tx.QueryRow(ctx, `
|
||||
insert into thoughts (content, metadata, project_id)
|
||||
values ($1, $2::jsonb, $3)
|
||||
returning id, created_at, updated_at
|
||||
`, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID)
|
||||
`, thought.Content, metadata, thought.ProjectID)
|
||||
|
||||
created := thought
|
||||
created.Embedding = nil
|
||||
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
|
||||
}
|
||||
|
||||
if len(thought.Embedding) > 0 && embeddingModel != "" {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
insert into embeddings (thought_id, model, dim, embedding)
|
||||
values ($1, $2, $3, $4)
|
||||
on conflict (thought_id, model) do update
|
||||
set embedding = excluded.embedding,
|
||||
dim = excluded.dim,
|
||||
updated_at = now()
|
||||
`, created.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding)); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
|
||||
}
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
||||
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
||||
filterJSON, err := json.Marshal(filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal search filter: %w", err)
|
||||
@@ -43,8 +67,8 @@ func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold
|
||||
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select id, content, metadata, similarity, created_at
|
||||
from match_thoughts($1, $2, $3, $4::jsonb)
|
||||
`, pgvector.NewVector(embedding), threshold, limit, filterJSON)
|
||||
from match_thoughts($1, $2, $3, $4::jsonb, $5)
|
||||
`, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search thoughts: %w", err)
|
||||
}
|
||||
@@ -185,15 +209,14 @@ func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
|
||||
|
||||
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at
|
||||
select id, content, metadata, project_id, archived_at, created_at, updated_at
|
||||
from thoughts
|
||||
where id = $1
|
||||
`, id)
|
||||
|
||||
var thought thoughttypes.Thought
|
||||
var embedding pgvector.Vector
|
||||
var metadataBytes []byte
|
||||
if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
if err := row.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return thoughttypes.Thought{}, err
|
||||
}
|
||||
@@ -203,26 +226,30 @@ func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Though
|
||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
|
||||
}
|
||||
thought.Embedding = embedding.Slice()
|
||||
|
||||
return thought, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
||||
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
|
||||
}
|
||||
|
||||
tag, err := db.pool.Exec(ctx, `
|
||||
tx, err := db.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
tag, err := tx.Exec(ctx, `
|
||||
update thoughts
|
||||
set content = $2,
|
||||
embedding = $3,
|
||||
metadata = $4::jsonb,
|
||||
project_id = $5,
|
||||
set content = $2,
|
||||
metadata = $3::jsonb,
|
||||
project_id = $4,
|
||||
updated_at = now()
|
||||
where id = $1
|
||||
`, id, content, pgvector.NewVector(embedding), metadataBytes, projectID)
|
||||
`, id, content, metadataBytes, projectID)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
|
||||
}
|
||||
@@ -230,6 +257,23 @@ func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, e
|
||||
return thoughttypes.Thought{}, pgx.ErrNoRows
|
||||
}
|
||||
|
||||
if len(embedding) > 0 && embeddingModel != "" {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
insert into embeddings (thought_id, model, dim, embedding)
|
||||
values ($1, $2, $3, $4)
|
||||
on conflict (thought_id, model) do update
|
||||
set embedding = excluded.embedding,
|
||||
dim = excluded.dim,
|
||||
updated_at = now()
|
||||
`, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err)
|
||||
}
|
||||
|
||||
return db.GetThought(ctx, id)
|
||||
}
|
||||
|
||||
@@ -265,27 +309,29 @@ func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit in
|
||||
return db.ListThoughts(ctx, filter)
|
||||
}
|
||||
|
||||
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||
args := []any{pgvector.NewVector(embedding), threshold}
|
||||
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||
args := []any{pgvector.NewVector(embedding), threshold, embeddingModel}
|
||||
conditions := []string{
|
||||
"archived_at is null",
|
||||
"1 - (embedding <=> $1) > $2",
|
||||
"t.archived_at is null",
|
||||
"1 - (e.embedding <=> $1) > $2",
|
||||
"e.model = $3",
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
||||
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||
}
|
||||
if excludeID != nil {
|
||||
args = append(args, *excludeID)
|
||||
conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args)))
|
||||
conditions = append(conditions, fmt.Sprintf("t.id <> $%d", len(args)))
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
query := `
|
||||
select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at
|
||||
from thoughts
|
||||
select t.id, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at
|
||||
from thoughts t
|
||||
join embeddings e on e.thought_id = t.id
|
||||
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
|
||||
order by embedding <=> $1
|
||||
order by e.embedding <=> $1
|
||||
limit $%d`, len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, query, args...)
|
||||
|
||||
Reference in New Issue
Block a user