* Introduced EmbeddingModel method in Client and Provider interfaces * Updated InsertThought and SearchThoughts methods to handle embedding models * Created embeddings table and updated match_thoughts function for model filtering * Removed embedding column from thoughts table * Adjusted permissions for new embeddings table
392 lines
12 KiB
Go
392 lines
12 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/pgvector/pgvector-go"
|
|
|
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
|
)
|
|
|
|
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) {
|
|
metadata, err := json.Marshal(thought.Metadata)
|
|
if err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
|
|
}
|
|
|
|
tx, err := db.pool.Begin(ctx)
|
|
if err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
row := tx.QueryRow(ctx, `
|
|
insert into thoughts (content, metadata, project_id)
|
|
values ($1, $2::jsonb, $3)
|
|
returning id, created_at, updated_at
|
|
`, thought.Content, metadata, thought.ProjectID)
|
|
|
|
created := thought
|
|
created.Embedding = nil
|
|
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
|
|
}
|
|
|
|
if len(thought.Embedding) > 0 && embeddingModel != "" {
|
|
if _, err := tx.Exec(ctx, `
|
|
insert into embeddings (thought_id, model, dim, embedding)
|
|
values ($1, $2, $3, $4)
|
|
on conflict (thought_id, model) do update
|
|
set embedding = excluded.embedding,
|
|
dim = excluded.dim,
|
|
updated_at = now()
|
|
`, created.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding)); err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
|
|
}
|
|
|
|
return created, nil
|
|
}
|
|
|
|
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
|
filterJSON, err := json.Marshal(filter)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal search filter: %w", err)
|
|
}
|
|
|
|
rows, err := db.pool.Query(ctx, `
|
|
select id, content, metadata, similarity, created_at
|
|
from match_thoughts($1, $2, $3, $4::jsonb, $5)
|
|
`, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("search thoughts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
results := make([]thoughttypes.SearchResult, 0, limit)
|
|
for rows.Next() {
|
|
var result thoughttypes.SearchResult
|
|
var metadataBytes []byte
|
|
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan search result: %w", err)
|
|
}
|
|
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
|
return nil, fmt.Errorf("decode search metadata: %w", err)
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate search results: %w", err)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (db *DB) ListThoughts(ctx context.Context, filter thoughttypes.ListFilter) ([]thoughttypes.Thought, error) {
|
|
args := make([]any, 0, 6)
|
|
conditions := []string{}
|
|
|
|
if !filter.IncludeArchived {
|
|
conditions = append(conditions, "archived_at is null")
|
|
}
|
|
if value := strings.TrimSpace(filter.Type); value != "" {
|
|
args = append(args, value)
|
|
conditions = append(conditions, fmt.Sprintf("metadata->>'type' = $%d", len(args)))
|
|
}
|
|
if value := strings.TrimSpace(filter.Topic); value != "" {
|
|
args = append(args, value)
|
|
conditions = append(conditions, fmt.Sprintf("metadata->'topics' ? $%d", len(args)))
|
|
}
|
|
if value := strings.TrimSpace(filter.Person); value != "" {
|
|
args = append(args, value)
|
|
conditions = append(conditions, fmt.Sprintf("metadata->'people' ? $%d", len(args)))
|
|
}
|
|
if filter.Days > 0 {
|
|
args = append(args, time.Now().Add(-time.Duration(filter.Days)*24*time.Hour))
|
|
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)))
|
|
}
|
|
if filter.ProjectID != nil {
|
|
args = append(args, *filter.ProjectID)
|
|
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
|
}
|
|
|
|
query := `
|
|
select id, content, metadata, project_id, archived_at, created_at, updated_at
|
|
from thoughts
|
|
`
|
|
if len(conditions) > 0 {
|
|
query += " where " + strings.Join(conditions, " and ")
|
|
}
|
|
|
|
args = append(args, filter.Limit)
|
|
query += fmt.Sprintf(" order by created_at desc limit $%d", len(args))
|
|
|
|
rows, err := db.pool.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list thoughts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
thoughts := make([]thoughttypes.Thought, 0, filter.Limit)
|
|
for rows.Next() {
|
|
var thought thoughttypes.Thought
|
|
var metadataBytes []byte
|
|
if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan listed thought: %w", err)
|
|
}
|
|
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
|
return nil, fmt.Errorf("decode listed metadata: %w", err)
|
|
}
|
|
thoughts = append(thoughts, thought)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate listed thoughts: %w", err)
|
|
}
|
|
|
|
return thoughts, nil
|
|
}
|
|
|
|
func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
|
|
var total int
|
|
if err := db.pool.QueryRow(ctx, `select count(*) from thoughts where archived_at is null`).Scan(&total); err != nil {
|
|
return thoughttypes.ThoughtStats{}, fmt.Errorf("count thoughts: %w", err)
|
|
}
|
|
|
|
rows, err := db.pool.Query(ctx, `select metadata from thoughts where archived_at is null`)
|
|
if err != nil {
|
|
return thoughttypes.ThoughtStats{}, fmt.Errorf("query stats metadata: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
stats := thoughttypes.ThoughtStats{
|
|
TotalCount: total,
|
|
TypeCounts: map[string]int{},
|
|
}
|
|
topics := map[string]int{}
|
|
people := map[string]int{}
|
|
|
|
for rows.Next() {
|
|
var metadataBytes []byte
|
|
if err := rows.Scan(&metadataBytes); err != nil {
|
|
return thoughttypes.ThoughtStats{}, fmt.Errorf("scan stats metadata: %w", err)
|
|
}
|
|
|
|
var metadata thoughttypes.ThoughtMetadata
|
|
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
|
return thoughttypes.ThoughtStats{}, fmt.Errorf("decode stats metadata: %w", err)
|
|
}
|
|
|
|
stats.TypeCounts[metadata.Type]++
|
|
for _, topic := range metadata.Topics {
|
|
topics[topic]++
|
|
}
|
|
for _, person := range metadata.People {
|
|
people[person]++
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return thoughttypes.ThoughtStats{}, fmt.Errorf("iterate stats metadata: %w", err)
|
|
}
|
|
|
|
stats.TopTopics = topCounts(topics, 10)
|
|
stats.TopPeople = topCounts(people, 10)
|
|
return stats, nil
|
|
}
|
|
|
|
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
|
|
row := db.pool.QueryRow(ctx, `
|
|
select id, content, metadata, project_id, archived_at, created_at, updated_at
|
|
from thoughts
|
|
where id = $1
|
|
`, id)
|
|
|
|
var thought thoughttypes.Thought
|
|
var metadataBytes []byte
|
|
if err := row.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return thoughttypes.Thought{}, err
|
|
}
|
|
return thoughttypes.Thought{}, fmt.Errorf("get thought: %w", err)
|
|
}
|
|
|
|
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
|
|
}
|
|
|
|
return thought, nil
|
|
}
|
|
|
|
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
|
metadataBytes, err := json.Marshal(metadata)
|
|
if err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
|
|
}
|
|
|
|
tx, err := db.pool.Begin(ctx)
|
|
if err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
tag, err := tx.Exec(ctx, `
|
|
update thoughts
|
|
set content = $2,
|
|
metadata = $3::jsonb,
|
|
project_id = $4,
|
|
updated_at = now()
|
|
where id = $1
|
|
`, id, content, metadataBytes, projectID)
|
|
if err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return thoughttypes.Thought{}, pgx.ErrNoRows
|
|
}
|
|
|
|
if len(embedding) > 0 && embeddingModel != "" {
|
|
if _, err := tx.Exec(ctx, `
|
|
insert into embeddings (thought_id, model, dim, embedding)
|
|
values ($1, $2, $3, $4)
|
|
on conflict (thought_id, model) do update
|
|
set embedding = excluded.embedding,
|
|
dim = excluded.dim,
|
|
updated_at = now()
|
|
`, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err)
|
|
}
|
|
|
|
return db.GetThought(ctx, id)
|
|
}
|
|
|
|
func (db *DB) DeleteThought(ctx context.Context, id uuid.UUID) error {
|
|
tag, err := db.pool.Exec(ctx, `delete from thoughts where id = $1`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("delete thought: %w", err)
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return pgx.ErrNoRows
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) ArchiveThought(ctx context.Context, id uuid.UUID) error {
|
|
tag, err := db.pool.Exec(ctx, `update thoughts set archived_at = now(), updated_at = now() where id = $1`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("archive thought: %w", err)
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return pgx.ErrNoRows
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit int, days int) ([]thoughttypes.Thought, error) {
|
|
filter := thoughttypes.ListFilter{
|
|
Limit: limit,
|
|
ProjectID: projectID,
|
|
Days: days,
|
|
IncludeArchived: false,
|
|
}
|
|
return db.ListThoughts(ctx, filter)
|
|
}
|
|
|
|
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
|
args := []any{pgvector.NewVector(embedding), threshold, embeddingModel}
|
|
conditions := []string{
|
|
"t.archived_at is null",
|
|
"1 - (e.embedding <=> $1) > $2",
|
|
"e.model = $3",
|
|
}
|
|
if projectID != nil {
|
|
args = append(args, *projectID)
|
|
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
|
}
|
|
if excludeID != nil {
|
|
args = append(args, *excludeID)
|
|
conditions = append(conditions, fmt.Sprintf("t.id <> $%d", len(args)))
|
|
}
|
|
args = append(args, limit)
|
|
|
|
query := `
|
|
select t.id, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at
|
|
from thoughts t
|
|
join embeddings e on e.thought_id = t.id
|
|
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
|
|
order by e.embedding <=> $1
|
|
limit $%d`, len(args))
|
|
|
|
rows, err := db.pool.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("search similar thoughts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
results := make([]thoughttypes.SearchResult, 0, limit)
|
|
for rows.Next() {
|
|
var result thoughttypes.SearchResult
|
|
var metadataBytes []byte
|
|
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan similar thought: %w", err)
|
|
}
|
|
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
|
return nil, fmt.Errorf("decode similar thought metadata: %w", err)
|
|
}
|
|
results = append(results, result)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate similar thoughts: %w", err)
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
|
type kv struct {
|
|
key string
|
|
count int
|
|
}
|
|
|
|
pairs := make([]kv, 0, len(in))
|
|
for key, count := range in {
|
|
if strings.TrimSpace(key) == "" {
|
|
continue
|
|
}
|
|
pairs = append(pairs, kv{key: key, count: count})
|
|
}
|
|
|
|
sort.Slice(pairs, func(i, j int) bool {
|
|
if pairs[i].count == pairs[j].count {
|
|
return pairs[i].key < pairs[j].key
|
|
}
|
|
return pairs[i].count > pairs[j].count
|
|
})
|
|
|
|
if limit > 0 && len(pairs) > limit {
|
|
pairs = pairs[:limit]
|
|
}
|
|
|
|
out := make([]thoughttypes.KeyCount, 0, len(pairs))
|
|
for _, pair := range pairs {
|
|
out = append(out, thoughttypes.KeyCount{Key: pair.key, Count: pair.count})
|
|
}
|
|
return out
|
|
}
|