feat(backfill): implement backfill tool for generating missing embeddings

This commit is contained in:
2026-03-26 22:45:28 +02:00
parent 1dde7f233d
commit f4ef0e9163
19 changed files with 575 additions and 37 deletions

View File

@@ -358,6 +358,140 @@ func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, em
return results, nil
}
func (db *DB) HasEmbeddingsForModel(ctx context.Context, model string, projectID *uuid.UUID) (bool, error) {
args := []any{model}
conditions := []string{
"e.model = $1",
"t.archived_at is null",
}
if projectID != nil {
args = append(args, *projectID)
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
}
query := `select exists(select 1 from embeddings e join thoughts t on t.guid = e.thought_id where ` +
strings.Join(conditions, " and ") + `)`
var exists bool
if err := db.pool.QueryRow(ctx, query, args...).Scan(&exists); err != nil {
return false, fmt.Errorf("check embeddings for model: %w", err)
}
return exists, nil
}
func (db *DB) ListThoughtsMissingEmbedding(ctx context.Context, model string, limit int, projectID *uuid.UUID, includeArchived bool, olderThanDays int) ([]thoughttypes.Thought, error) {
args := []any{model}
conditions := []string{"e.id is null"}
if !includeArchived {
conditions = append(conditions, "t.archived_at is null")
}
if projectID != nil {
args = append(args, *projectID)
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
}
if olderThanDays > 0 {
args = append(args, time.Now().Add(-time.Duration(olderThanDays)*24*time.Hour))
conditions = append(conditions, fmt.Sprintf("t.created_at < $%d", len(args)))
}
args = append(args, limit)
query := `
select t.guid, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at
from thoughts t
left join embeddings e on e.thought_id = t.guid and e.model = $1
where ` + strings.Join(conditions, " and ") + `
order by t.created_at asc
limit $` + fmt.Sprintf("%d", len(args))
rows, err := db.pool.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("list thoughts missing embedding: %w", err)
}
defer rows.Close()
thoughts := make([]thoughttypes.Thought, 0, limit)
for rows.Next() {
var thought thoughttypes.Thought
var metadataBytes []byte
if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan missing-embedding thought: %w", err)
}
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
return nil, fmt.Errorf("decode missing-embedding metadata: %w", err)
}
thoughts = append(thoughts, thought)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate missing-embedding thoughts: %w", err)
}
return thoughts, nil
}
func (db *DB) UpsertEmbedding(ctx context.Context, thoughtID uuid.UUID, model string, embedding []float32) error {
_, err := db.pool.Exec(ctx, `
insert into embeddings (thought_id, model, dim, embedding)
values ($1, $2, $3, $4)
on conflict (thought_id, model) do update
set embedding = excluded.embedding,
dim = excluded.dim,
updated_at = now()
`, thoughtID, model, len(embedding), pgvector.NewVector(embedding))
if err != nil {
return fmt.Errorf("upsert embedding: %w", err)
}
return nil
}
func (db *DB) SearchThoughtsText(ctx context.Context, query string, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
args := []any{query}
conditions := []string{
"t.archived_at is null",
"to_tsvector('simple', t.content) @@ websearch_to_tsquery('simple', $1)",
}
if projectID != nil {
args = append(args, *projectID)
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
}
if excludeID != nil {
args = append(args, *excludeID)
conditions = append(conditions, fmt.Sprintf("t.guid <> $%d", len(args)))
}
args = append(args, limit)
q := `
select t.guid, t.content, t.metadata,
ts_rank_cd(to_tsvector('simple', t.content), websearch_to_tsquery('simple', $1)) as similarity,
t.created_at
from thoughts t
where ` + strings.Join(conditions, " and ") + `
order by similarity desc
limit $` + fmt.Sprintf("%d", len(args))
rows, err := db.pool.Query(ctx, q, args...)
if err != nil {
return nil, fmt.Errorf("text search thoughts: %w", err)
}
defer rows.Close()
results := make([]thoughttypes.SearchResult, 0, limit)
for rows.Next() {
var result thoughttypes.SearchResult
var metadataBytes []byte
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
return nil, fmt.Errorf("scan text search result: %w", err)
}
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
return nil, fmt.Errorf("decode text search metadata: %w", err)
}
results = append(results, result)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate text search results: %w", err)
}
return results, nil
}
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
type kv struct {
key string