Files
amcs/internal/store/thoughts.go
Hein 8d0a91a961 feat(llm): add LLM integration instructions and handler
* Serve LLM instructions at `/llm`
* Include markdown content for memory instructions
* Update README with LLM integration details
* Add tests for LLM instructions handler
* Modify database migrations to use GUIDs for thoughts and projects
2026-03-25 18:02:42 +02:00

392 lines
12 KiB
Go

package store
import (
"context"
"encoding/json"
"fmt"
"sort"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/pgvector/pgvector-go"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// InsertThought stores thought's content, metadata and project and, when both
// thought.Embedding is non-empty and embeddingModel is set, upserts that
// embedding for the model. Both writes happen in a single transaction.
//
// The returned copy carries the database-assigned ID, CreatedAt and UpdatedAt
// and has its Embedding field cleared.
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) {
	metadataJSON, err := json.Marshal(thought.Metadata)
	if err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
	}

	tx, err := db.pool.Begin(ctx)
	if err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
	}
	// The rollback only matters on the early-return paths below; after a
	// successful Commit there is nothing left to undo, so its error is ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	saved := thought
	saved.Embedding = nil

	err = tx.QueryRow(ctx, `
insert into thoughts (content, metadata, project_id)
values ($1, $2::jsonb, $3)
returning guid, created_at, updated_at
`, thought.Content, metadataJSON, thought.ProjectID).Scan(&saved.ID, &saved.CreatedAt, &saved.UpdatedAt)
	if err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
	}

	// An embedding is only persisted when there is both a vector and a model to key it by.
	storeEmbedding := len(thought.Embedding) > 0 && embeddingModel != ""
	if storeEmbedding {
		_, err = tx.Exec(ctx, `
insert into embeddings (thought_id, model, dim, embedding)
values ($1, $2, $3, $4)
on conflict (thought_id, model) do update
set embedding = excluded.embedding,
dim = excluded.dim,
updated_at = now()
`, saved.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding))
		if err != nil {
			return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err)
		}
	}

	if err = tx.Commit(ctx); err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
	}
	return saved, nil
}
// SearchThoughts runs a vector similarity search through the match_thoughts
// SQL function (defined outside this file). It passes, in order, the query
// embedding, threshold, limit, filter (as jsonb) and embeddingModel, and
// decodes each returned row into a SearchResult.
//
// The returned slice is never nil; it is empty when nothing matches.
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
	// A nil map would marshal to JSON null rather than an object. Send {} instead
	// so match_thoughts always receives a JSON object.
	// NOTE(review): assumes {} means "no filter" inside match_thoughts; confirm
	// against the migration.
	if filter == nil {
		filter = map[string]any{}
	}
	filterJSON, err := json.Marshal(filter)
	if err != nil {
		return nil, fmt.Errorf("marshal search filter: %w", err)
	}
	rows, err := db.pool.Query(ctx, `
select id, content, metadata, similarity, created_at
from match_thoughts($1, $2, $3, $4::jsonb, $5)
`, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel)
	if err != nil {
		return nil, fmt.Errorf("search thoughts: %w", err)
	}
	defer rows.Close()

	// Size the buffer from limit, but clamp it:
	// - a negative capacity makes make() panic;
	// - a very large limit should not reserve memory before any rows arrive.
	capHint := limit
	if capHint < 0 {
		capHint = 0
	} else if capHint > 256 {
		capHint = 256
	}
	results := make([]thoughttypes.SearchResult, 0, capHint)
	for rows.Next() {
		var result thoughttypes.SearchResult
		var metadataBytes []byte
		if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
			return nil, fmt.Errorf("scan search result: %w", err)
		}
		if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
			return nil, fmt.Errorf("decode search metadata: %w", err)
		}
		results = append(results, result)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate search results: %w", err)
	}
	return results, nil
}
// ListThoughts returns thoughts matching filter, newest first.
//
// Each applicable filter field adds an AND condition:
//   - Type (trimmed, non-empty) must equal metadata->>'type'.
//   - Topic and Person (trimmed, non-empty) must be present in
//     metadata->'topics' / metadata->'people' (jsonb ? operator).
//   - Days > 0 keeps thoughts created within the last Days calendar days.
//   - A non-nil ProjectID keeps thoughts of that project.
//
// Archived thoughts are excluded unless IncludeArchived is set. filter.Limit
// is passed unchanged to SQL LIMIT. The returned slice is never nil.
func (db *DB) ListThoughts(ctx context.Context, filter thoughttypes.ListFilter) ([]thoughttypes.Thought, error) {
	args := make([]any, 0, 6)
	var conditions []string
	// where appends value as the next positional argument and records a
	// condition rendered from format, whose single %d receives that argument's
	// placeholder number.
	where := func(format string, value any) {
		args = append(args, value)
		conditions = append(conditions, fmt.Sprintf(format, len(args)))
	}

	if !filter.IncludeArchived {
		conditions = append(conditions, "archived_at is null")
	}
	if value := strings.TrimSpace(filter.Type); value != "" {
		where("metadata->>'type' = $%d", value)
	}
	if value := strings.TrimSpace(filter.Topic); value != "" {
		where("metadata->'topics' ? $%d", value)
	}
	if value := strings.TrimSpace(filter.Person); value != "" {
		where("metadata->'people' ? $%d", value)
	}
	if filter.Days > 0 {
		// AddDate steps back whole calendar days. time.Duration(Days)*24*time.Hour
		// overflows int64 nanoseconds once Days exceeds ~106,751 (~292 years);
		// AddDate does not.
		where("created_at >= $%d", time.Now().AddDate(0, 0, -filter.Days))
	}
	if filter.ProjectID != nil {
		where("project_id = $%d", *filter.ProjectID)
	}

	query := `
select guid, content, metadata, project_id, archived_at, created_at, updated_at
from thoughts
`
	if len(conditions) > 0 {
		query += " where " + strings.Join(conditions, " and ")
	}
	args = append(args, filter.Limit)
	query += fmt.Sprintf(" order by created_at desc limit $%d", len(args))

	rows, err := db.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("list thoughts: %w", err)
	}
	defer rows.Close()

	// Size the buffer from Limit, but clamp it:
	// - a negative capacity makes make() panic;
	// - a very large Limit should not reserve memory before any rows arrive.
	capHint := filter.Limit
	if capHint < 0 {
		capHint = 0
	} else if capHint > 256 {
		capHint = 256
	}
	thoughts := make([]thoughttypes.Thought, 0, capHint)
	for rows.Next() {
		var thought thoughttypes.Thought
		var metadataBytes []byte
		if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
			return nil, fmt.Errorf("scan listed thought: %w", err)
		}
		if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
			return nil, fmt.Errorf("decode listed metadata: %w", err)
		}
		thoughts = append(thoughts, thought)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate listed thoughts: %w", err)
	}
	return thoughts, nil
}
// Stats summarises all non-archived thoughts. It returns:
//   - the total count;
//   - a count per metadata type (thoughts with an empty type are counted under
//     the "" key);
//   - the ten most frequent topics and people, as ranked by topCounts.
//
// Every figure is derived from one query, so TotalCount always equals the sum
// of TypeCounts. A separate count(*) query could disagree with the metadata
// scan while thoughts are being written concurrently.
func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
	rows, err := db.pool.Query(ctx, `select metadata from thoughts where archived_at is null`)
	if err != nil {
		return thoughttypes.ThoughtStats{}, fmt.Errorf("query stats metadata: %w", err)
	}
	defer rows.Close()

	stats := thoughttypes.ThoughtStats{
		TypeCounts: map[string]int{},
	}
	topics := map[string]int{}
	people := map[string]int{}
	for rows.Next() {
		var metadataBytes []byte
		if err := rows.Scan(&metadataBytes); err != nil {
			return thoughttypes.ThoughtStats{}, fmt.Errorf("scan stats metadata: %w", err)
		}
		var metadata thoughttypes.ThoughtMetadata
		if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
			return thoughttypes.ThoughtStats{}, fmt.Errorf("decode stats metadata: %w", err)
		}
		stats.TotalCount++
		stats.TypeCounts[metadata.Type]++
		for _, topic := range metadata.Topics {
			topics[topic]++
		}
		for _, person := range metadata.People {
			people[person]++
		}
	}
	if err := rows.Err(); err != nil {
		return thoughttypes.ThoughtStats{}, fmt.Errorf("iterate stats metadata: %w", err)
	}
	stats.TopTopics = topCounts(topics, 10)
	stats.TopPeople = topCounts(people, 10)
	return stats, nil
}
// GetThought loads the thought whose guid is id, including archived ones.
//
// When no row matches it returns pgx.ErrNoRows unwrapped, so callers can
// compare against the sentinel directly. Other failures are wrapped with
// context.
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
	var (
		found       thoughttypes.Thought
		rawMetadata []byte
	)
	scanErr := db.pool.QueryRow(ctx, `
select guid, content, metadata, project_id, archived_at, created_at, updated_at
from thoughts
where guid = $1
`, id).Scan(&found.ID, &found.Content, &rawMetadata, &found.ProjectID, &found.ArchivedAt, &found.CreatedAt, &found.UpdatedAt)

	switch {
	case scanErr == pgx.ErrNoRows:
		return thoughttypes.Thought{}, scanErr
	case scanErr != nil:
		return thoughttypes.Thought{}, fmt.Errorf("get thought: %w", scanErr)
	}

	if err := json.Unmarshal(rawMetadata, &found.Metadata); err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
	}
	return found, nil
}
// UpdateThought replaces the content, metadata and project of the thought whose
// guid is id. When both embedding and embeddingModel are supplied, it also
// upserts that model's embedding. Everything runs in one transaction.
//
// It returns pgx.ErrNoRows (unwrapped) when no thought has that guid.
//
// The returned thought is the row exactly as this transaction's UPDATE wrote
// it (via RETURNING). Re-reading it after commit would cost another round
// trip, and could report ErrNoRows if a concurrent delete landed in between,
// even though the update succeeded.
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
	metadataBytes, err := json.Marshal(metadata)
	if err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
	}
	tx, err := db.pool.Begin(ctx)
	if err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
	}
	// The rollback only matters on the early-return paths below; after a
	// successful Commit there is nothing left to undo.
	defer tx.Rollback(ctx)

	var updated thoughttypes.Thought
	var updatedMetadata []byte
	err = tx.QueryRow(ctx, `
update thoughts
set content = $2,
metadata = $3::jsonb,
project_id = $4,
updated_at = now()
where guid = $1
returning guid, content, metadata, project_id, archived_at, created_at, updated_at
`, id, content, metadataBytes, projectID).Scan(&updated.ID, &updated.Content, &updatedMetadata, &updated.ProjectID, &updated.ArchivedAt, &updated.CreatedAt, &updated.UpdatedAt)
	if err == pgx.ErrNoRows {
		// No thought with this guid: report the bare sentinel so callers can
		// compare against it directly.
		return thoughttypes.Thought{}, pgx.ErrNoRows
	}
	if err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
	}
	if err := json.Unmarshal(updatedMetadata, &updated.Metadata); err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
	}

	if len(embedding) > 0 && embeddingModel != "" {
		if _, err := tx.Exec(ctx, `
insert into embeddings (thought_id, model, dim, embedding)
values ($1, $2, $3, $4)
on conflict (thought_id, model) do update
set embedding = excluded.embedding,
dim = excluded.dim,
updated_at = now()
`, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil {
			return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err)
		}
	}
	if err := tx.Commit(ctx); err != nil {
		return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err)
	}
	return updated, nil
}
// DeleteThought permanently removes the thought whose guid is id.
// It returns pgx.ErrNoRows (unwrapped) when no row was deleted.
func (db *DB) DeleteThought(ctx context.Context, id uuid.UUID) error {
	const deleteSQL = `delete from thoughts where guid = $1`

	tag, err := db.pool.Exec(ctx, deleteSQL, id)
	switch {
	case err != nil:
		return fmt.Errorf("delete thought: %w", err)
	case tag.RowsAffected() == 0:
		return pgx.ErrNoRows
	default:
		return nil
	}
}
// ArchiveThought marks the thought whose guid is id as archived by setting both
// archived_at and updated_at to now(). An already-archived thought gets its
// archived_at refreshed, because the statement does not check the current
// value. It returns pgx.ErrNoRows (unwrapped) when no row matched.
func (db *DB) ArchiveThought(ctx context.Context, id uuid.UUID) error {
	const archiveSQL = `update thoughts set archived_at = now(), updated_at = now() where guid = $1`

	tag, err := db.pool.Exec(ctx, archiveSQL, id)
	switch {
	case err != nil:
		return fmt.Errorf("archive thought: %w", err)
	case tag.RowsAffected() == 0:
		return pgx.ErrNoRows
	default:
		return nil
	}
}
// RecentThoughts is a convenience wrapper over ListThoughts. It returns up to
// limit non-archived thoughts, newest first, optionally restricted to one
// project.
//
// days is forwarded as ListFilter.Days, where only a positive value adds an
// age cutoff.
func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit int, days int) ([]thoughttypes.Thought, error) {
	// IncludeArchived is left at its zero value (false), so archived thoughts stay hidden.
	return db.ListThoughts(ctx, thoughttypes.ListFilter{
		ProjectID: projectID,
		Limit:     limit,
		Days:      days,
	})
}
// SearchSimilarThoughts finds non-archived thoughts whose stored embedding for
// embeddingModel is close to the given embedding.
//
// Similarity is computed as 1 - (e.embedding <=> query), where <=> is
// pgvector's cosine-distance operator. Only rows with similarity strictly
// greater than threshold are kept. Results are ordered most-similar first and
// capped at limit rows.
//
// A non-nil projectID restricts results to that project. A non-nil excludeID
// drops that thought, for example the one whose neighbours are being looked
// up. The returned slice is never nil.
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
	args := []any{pgvector.NewVector(embedding), threshold, embeddingModel}
	conditions := []string{
		"t.archived_at is null",
		"1 - (e.embedding <=> $1) > $2",
		"e.model = $3",
	}
	if projectID != nil {
		args = append(args, *projectID)
		conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
	}
	if excludeID != nil {
		args = append(args, *excludeID)
		conditions = append(conditions, fmt.Sprintf("t.guid <> $%d", len(args)))
	}
	args = append(args, limit)
	query := `
select t.guid, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at
from thoughts t
join embeddings e on e.thought_id = t.guid
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
order by e.embedding <=> $1
limit $%d`, len(args))
	rows, err := db.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("search similar thoughts: %w", err)
	}
	defer rows.Close()

	// Size the buffer from limit, but clamp it:
	// - a negative capacity makes make() panic;
	// - a very large limit should not reserve memory before any rows arrive.
	capHint := limit
	if capHint < 0 {
		capHint = 0
	} else if capHint > 256 {
		capHint = 256
	}
	results := make([]thoughttypes.SearchResult, 0, capHint)
	for rows.Next() {
		var result thoughttypes.SearchResult
		var metadataBytes []byte
		if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
			return nil, fmt.Errorf("scan similar thought: %w", err)
		}
		if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
			return nil, fmt.Errorf("decode similar thought metadata: %w", err)
		}
		results = append(results, result)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate similar thoughts: %w", err)
	}
	return results, nil
}
// topCounts turns a key→count tally into a ranked slice. Entries are ordered
// by descending count, with ties broken by ascending key. Keys that are empty
// or whitespace-only are dropped.
//
// When limit is positive, at most limit entries are returned; otherwise every
// entry is. The result is never nil.
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
	ranked := make([]thoughttypes.KeyCount, 0, len(in))
	for key, count := range in {
		if strings.TrimSpace(key) != "" {
			ranked = append(ranked, thoughttypes.KeyCount{Key: key, Count: count})
		}
	}

	// Keys are unique map keys, so this ordering is total and the output is
	// deterministic despite random map iteration order.
	sort.Slice(ranked, func(a, b int) bool {
		if ranked[a].Count != ranked[b].Count {
			return ranked[a].Count > ranked[b].Count
		}
		return ranked[a].Key < ranked[b].Key
	})

	if limit > 0 && limit < len(ranked) {
		ranked = ranked[:limit]
	}
	return ranked
}