feat(embeddings): add embedding model support and related changes

* Introduced EmbeddingModel method in Client and Provider interfaces
* Updated InsertThought and SearchThoughts methods to handle embedding models
* Created embeddings table and updated match_thoughts function for model filtering
* Removed embedding column from thoughts table
* Adjusted permissions for new embeddings table
This commit is contained in:
Hein
2026-03-25 16:25:41 +02:00
parent c8ca272b03
commit cebef3a07c
19 changed files with 259 additions and 88 deletions

View File

@@ -3,7 +3,6 @@ package store
import (
"context"
"fmt"
"regexp"
"time"
"github.com/jackc/pgx/v5"
@@ -68,7 +67,7 @@ func (db *DB) Ready(ctx context.Context) error {
return db.Ping(readyCtx)
}
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
func (db *DB) VerifyRequirements(ctx context.Context) error {
var hasVector bool
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
return fmt.Errorf("verify vector extension: %w", err)
@@ -85,33 +84,12 @@ func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
return fmt.Errorf("match_thoughts function is missing")
}
var embeddingType string
err := db.pool.QueryRow(ctx, `
select format_type(a.atttypid, a.atttypmod)
from pg_attribute a
join pg_class c on c.oid = a.attrelid
join pg_namespace n on n.oid = c.relnamespace
where n.nspname = 'public'
and c.relname = 'thoughts'
and a.attname = 'embedding'
and not a.attisdropped
`).Scan(&embeddingType)
if err != nil {
return fmt.Errorf("verify thoughts.embedding type: %w", err)
var hasEmbeddings bool
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`).Scan(&hasEmbeddings); err != nil {
return fmt.Errorf("verify embeddings table: %w", err)
}
re := regexp.MustCompile(`vector\((\d+)\)`)
matches := re.FindStringSubmatch(embeddingType)
if len(matches) != 2 {
return fmt.Errorf("unexpected embedding type %q", embeddingType)
}
var actualDimensions int
if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil {
return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
}
if actualDimensions != dimensions {
return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
if !hasEmbeddings {
return fmt.Errorf("embeddings table is missing — run migrations")
}
return nil

View File

@@ -15,27 +15,51 @@ import (
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// InsertThought persists a thought and, optionally, its embedding.
//
// The thought row (content, metadata, project_id) and the embedding row are
// written inside a single transaction, so a failure while storing the
// embedding also discards the newly inserted thought.
//
// Parameters:
//   - ctx: passed to every database call made here.
//   - thought: source of Content, Metadata, ProjectID and Embedding.
//   - embeddingModel: model name stored alongside the vector in the
//     embeddings table.
//
// The embedding is written only when thought.Embedding is non-empty AND
// embeddingModel is non-empty; if either is missing the embedding step is
// skipped without an error. The write is an upsert keyed on
// (thought_id, model) that refreshes embedding, dim and updated_at.
//
// Returns a copy of the input thought with ID, CreatedAt and UpdatedAt taken
// from the inserted row and with Embedding set to nil. On any error the zero
// Thought is returned together with a wrapped error naming the failing step.
//
// NOTE(review): the deferred Rollback also runs after a successful Commit and
// its result is ignored; assuming tx is a pgx transaction, that call is
// documented as safe on an already-closed transaction -- confirm.
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) {
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) {
// Metadata is serialized to JSON here and cast to jsonb by the insert statement.
metadata, err := json.Marshal(thought.Metadata)
if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
}
row := db.pool.QueryRow(ctx, `
insert into thoughts (content, embedding, metadata, project_id)
values ($1, $2, $3::jsonb, $4)
tx, err := db.pool.Begin(ctx)
if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
}
defer tx.Rollback(ctx)
row := tx.QueryRow(ctx, `
insert into thoughts (content, metadata, project_id)
values ($1, $2::jsonb, $3)
returning id, created_at, updated_at
`, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID)
`, thought.Content, metadata, thought.ProjectID)
// Start from the caller's value, drop the vector from the returned copy, and
// fill in the columns generated by the database.
created := thought
created.Embedding = nil
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
}
// The vector is read from the original argument because created.Embedding
// was cleared above. dim is recorded as the length of the vector.
if len(thought.Embedding) > 0 && embeddingModel != "" {
if _, err := tx.Exec(ctx, `
insert into embeddings (thought_id, model, dim, embedding)
values ($1, $2, $3, $4)
on conflict (thought_id, model) do update
set embedding = excluded.embedding,
dim = excluded.dim,
updated_at = now()
`, created.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding)); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err)
}
}
// Both writes become visible together only once the commit succeeds.
if err := tx.Commit(ctx); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
}
return created, nil
}
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
filterJSON, err := json.Marshal(filter)
if err != nil {
return nil, fmt.Errorf("marshal search filter: %w", err)
@@ -43,8 +67,8 @@ func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold
rows, err := db.pool.Query(ctx, `
select id, content, metadata, similarity, created_at
from match_thoughts($1, $2, $3, $4::jsonb)
`, pgvector.NewVector(embedding), threshold, limit, filterJSON)
from match_thoughts($1, $2, $3, $4::jsonb, $5)
`, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel)
if err != nil {
return nil, fmt.Errorf("search thoughts: %w", err)
}
@@ -185,15 +209,14 @@ func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
row := db.pool.QueryRow(ctx, `
select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at
select id, content, metadata, project_id, archived_at, created_at, updated_at
from thoughts
where id = $1
`, id)
var thought thoughttypes.Thought
var embedding pgvector.Vector
var metadataBytes []byte
if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
if err := row.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
if err == pgx.ErrNoRows {
return thoughttypes.Thought{}, err
}
@@ -203,26 +226,30 @@ func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Though
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
}
thought.Embedding = embedding.Slice()
return thought, nil
}
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
metadataBytes, err := json.Marshal(metadata)
if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
}
tag, err := db.pool.Exec(ctx, `
tx, err := db.pool.Begin(ctx)
if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
}
defer tx.Rollback(ctx)
tag, err := tx.Exec(ctx, `
update thoughts
set content = $2,
embedding = $3,
metadata = $4::jsonb,
project_id = $5,
set content = $2,
metadata = $3::jsonb,
project_id = $4,
updated_at = now()
where id = $1
`, id, content, pgvector.NewVector(embedding), metadataBytes, projectID)
`, id, content, metadataBytes, projectID)
if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
}
@@ -230,6 +257,23 @@ func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, e
return thoughttypes.Thought{}, pgx.ErrNoRows
}
if len(embedding) > 0 && embeddingModel != "" {
if _, err := tx.Exec(ctx, `
insert into embeddings (thought_id, model, dim, embedding)
values ($1, $2, $3, $4)
on conflict (thought_id, model) do update
set embedding = excluded.embedding,
dim = excluded.dim,
updated_at = now()
`, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err)
}
}
if err := tx.Commit(ctx); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err)
}
return db.GetThought(ctx, id)
}
@@ -265,27 +309,29 @@ func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit in
return db.ListThoughts(ctx, filter)
}
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
args := []any{pgvector.NewVector(embedding), threshold}
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
args := []any{pgvector.NewVector(embedding), threshold, embeddingModel}
conditions := []string{
"archived_at is null",
"1 - (embedding <=> $1) > $2",
"t.archived_at is null",
"1 - (e.embedding <=> $1) > $2",
"e.model = $3",
}
if projectID != nil {
args = append(args, *projectID)
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
}
if excludeID != nil {
args = append(args, *excludeID)
conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args)))
conditions = append(conditions, fmt.Sprintf("t.id <> $%d", len(args)))
}
args = append(args, limit)
query := `
select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at
from thoughts
select t.id, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at
from thoughts t
join embeddings e on e.thought_id = t.id
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
order by embedding <=> $1
order by e.embedding <=> $1
limit $%d`, len(args))
rows, err := db.pool.Query(ctx, query, args...)