feat(backfill): implement backfill tool for generating missing embeddings
This commit is contained in:
@@ -358,6 +358,140 @@ func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, em
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (db *DB) HasEmbeddingsForModel(ctx context.Context, model string, projectID *uuid.UUID) (bool, error) {
|
||||
args := []any{model}
|
||||
conditions := []string{
|
||||
"e.model = $1",
|
||||
"t.archived_at is null",
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||
}
|
||||
|
||||
query := `select exists(select 1 from embeddings e join thoughts t on t.guid = e.thought_id where ` +
|
||||
strings.Join(conditions, " and ") + `)`
|
||||
|
||||
var exists bool
|
||||
if err := db.pool.QueryRow(ctx, query, args...).Scan(&exists); err != nil {
|
||||
return false, fmt.Errorf("check embeddings for model: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (db *DB) ListThoughtsMissingEmbedding(ctx context.Context, model string, limit int, projectID *uuid.UUID, includeArchived bool, olderThanDays int) ([]thoughttypes.Thought, error) {
|
||||
args := []any{model}
|
||||
conditions := []string{"e.id is null"}
|
||||
|
||||
if !includeArchived {
|
||||
conditions = append(conditions, "t.archived_at is null")
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||
}
|
||||
if olderThanDays > 0 {
|
||||
args = append(args, time.Now().Add(-time.Duration(olderThanDays)*24*time.Hour))
|
||||
conditions = append(conditions, fmt.Sprintf("t.created_at < $%d", len(args)))
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
query := `
|
||||
select t.guid, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at
|
||||
from thoughts t
|
||||
left join embeddings e on e.thought_id = t.guid and e.model = $1
|
||||
where ` + strings.Join(conditions, " and ") + `
|
||||
order by t.created_at asc
|
||||
limit $` + fmt.Sprintf("%d", len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list thoughts missing embedding: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
thoughts := make([]thoughttypes.Thought, 0, limit)
|
||||
for rows.Next() {
|
||||
var thought thoughttypes.Thought
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan missing-embedding thought: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode missing-embedding metadata: %w", err)
|
||||
}
|
||||
thoughts = append(thoughts, thought)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate missing-embedding thoughts: %w", err)
|
||||
}
|
||||
return thoughts, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpsertEmbedding(ctx context.Context, thoughtID uuid.UUID, model string, embedding []float32) error {
|
||||
_, err := db.pool.Exec(ctx, `
|
||||
insert into embeddings (thought_id, model, dim, embedding)
|
||||
values ($1, $2, $3, $4)
|
||||
on conflict (thought_id, model) do update
|
||||
set embedding = excluded.embedding,
|
||||
dim = excluded.dim,
|
||||
updated_at = now()
|
||||
`, thoughtID, model, len(embedding), pgvector.NewVector(embedding))
|
||||
if err != nil {
|
||||
return fmt.Errorf("upsert embedding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) SearchThoughtsText(ctx context.Context, query string, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||
args := []any{query}
|
||||
conditions := []string{
|
||||
"t.archived_at is null",
|
||||
"to_tsvector('simple', t.content) @@ websearch_to_tsquery('simple', $1)",
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||
}
|
||||
if excludeID != nil {
|
||||
args = append(args, *excludeID)
|
||||
conditions = append(conditions, fmt.Sprintf("t.guid <> $%d", len(args)))
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
q := `
|
||||
select t.guid, t.content, t.metadata,
|
||||
ts_rank_cd(to_tsvector('simple', t.content), websearch_to_tsquery('simple', $1)) as similarity,
|
||||
t.created_at
|
||||
from thoughts t
|
||||
where ` + strings.Join(conditions, " and ") + `
|
||||
order by similarity desc
|
||||
limit $` + fmt.Sprintf("%d", len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("text search thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make([]thoughttypes.SearchResult, 0, limit)
|
||||
for rows.Next() {
|
||||
var result thoughttypes.SearchResult
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan text search result: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode text search metadata: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate text search results: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
||||
type kv struct {
|
||||
key string
|
||||
|
||||
Reference in New Issue
Block a user