feat(tools): implement CRUD operations for thoughts and projects
* Add tools for creating, retrieving, updating, and deleting thoughts. * Implement project management tools for creating and listing projects. * Introduce linking functionality between thoughts. * Add search and recall capabilities for thoughts based on semantic queries. * Implement statistics and summarization tools for thought analysis. * Create database migrations for thoughts, projects, and links. * Add helper functions for UUID parsing and project resolution.
This commit is contained in:
118
internal/store/db.go
Normal file
118
internal/store/db.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
pgxvec "github.com/pgvector/pgvector-go/pgx"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func New(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) {
|
||||
poolConfig, err := pgxpool.ParseConfig(cfg.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse database config: %w", err)
|
||||
}
|
||||
|
||||
poolConfig.MaxConns = cfg.MaxConns
|
||||
poolConfig.MinConns = cfg.MinConns
|
||||
poolConfig.MaxConnLifetime = cfg.MaxConnLifetime
|
||||
poolConfig.MaxConnIdleTime = cfg.MaxConnIdleTime
|
||||
poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
||||
return pgxvec.RegisterTypes(ctx, conn)
|
||||
}
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create database pool: %w", err)
|
||||
}
|
||||
|
||||
db := &DB{pool: pool}
|
||||
if err := db.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (db *DB) Close() {
|
||||
if db == nil || db.pool == nil {
|
||||
return
|
||||
}
|
||||
|
||||
db.pool.Close()
|
||||
}
|
||||
|
||||
func (db *DB) Ping(ctx context.Context) error {
|
||||
if err := db.pool.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("ping database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Ready(ctx context.Context) error {
|
||||
readyCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return db.Ping(readyCtx)
|
||||
}
|
||||
|
||||
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
||||
var hasVector bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
|
||||
return fmt.Errorf("verify vector extension: %w", err)
|
||||
}
|
||||
if !hasVector {
|
||||
return fmt.Errorf("vector extension is not installed")
|
||||
}
|
||||
|
||||
var hasMatchThoughts bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_proc where proname = 'match_thoughts')`).Scan(&hasMatchThoughts); err != nil {
|
||||
return fmt.Errorf("verify match_thoughts function: %w", err)
|
||||
}
|
||||
if !hasMatchThoughts {
|
||||
return fmt.Errorf("match_thoughts function is missing")
|
||||
}
|
||||
|
||||
var embeddingType string
|
||||
err := db.pool.QueryRow(ctx, `
|
||||
select format_type(a.atttypid, a.atttypmod)
|
||||
from pg_attribute a
|
||||
join pg_class c on c.oid = a.attrelid
|
||||
join pg_namespace n on n.oid = c.relnamespace
|
||||
where n.nspname = 'public'
|
||||
and c.relname = 'thoughts'
|
||||
and a.attname = 'embedding'
|
||||
and not a.attisdropped
|
||||
`).Scan(&embeddingType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify thoughts.embedding type: %w", err)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`vector\((\d+)\)`)
|
||||
matches := re.FindStringSubmatch(embeddingType)
|
||||
if len(matches) != 2 {
|
||||
return fmt.Errorf("unexpected embedding type %q", embeddingType)
|
||||
}
|
||||
|
||||
var actualDimensions int
|
||||
if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil {
|
||||
return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
|
||||
}
|
||||
if actualDimensions != dimensions {
|
||||
return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
69
internal/store/links.go
Normal file
69
internal/store/links.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) InsertLink(ctx context.Context, link thoughttypes.ThoughtLink) error {
|
||||
_, err := db.pool.Exec(ctx, `
|
||||
insert into thought_links (from_id, to_id, relation)
|
||||
values ($1, $2, $3)
|
||||
`, link.FromID, link.ToID, link.Relation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert link: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) LinkedThoughts(ctx context.Context, thoughtID uuid.UUID) ([]thoughttypes.LinkedThought, error) {
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select t.id, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at, l.relation, 'outgoing' as direction, l.created_at
|
||||
from thought_links l
|
||||
join thoughts t on t.id = l.to_id
|
||||
where l.from_id = $1
|
||||
union all
|
||||
select t.id, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at, l.relation, 'incoming' as direction, l.created_at
|
||||
from thought_links l
|
||||
join thoughts t on t.id = l.from_id
|
||||
where l.to_id = $1
|
||||
order by created_at desc
|
||||
`, thoughtID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query linked thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
links := make([]thoughttypes.LinkedThought, 0)
|
||||
for rows.Next() {
|
||||
var linked thoughttypes.LinkedThought
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(
|
||||
&linked.Thought.ID,
|
||||
&linked.Thought.Content,
|
||||
&metadataBytes,
|
||||
&linked.Thought.ProjectID,
|
||||
&linked.Thought.ArchivedAt,
|
||||
&linked.Thought.CreatedAt,
|
||||
&linked.Thought.UpdatedAt,
|
||||
&linked.Relation,
|
||||
&linked.Direction,
|
||||
&linked.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan linked thought: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &linked.Thought.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode linked thought metadata: %w", err)
|
||||
}
|
||||
links = append(links, linked)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate linked thoughts: %w", err)
|
||||
}
|
||||
return links, nil
|
||||
}
|
||||
90
internal/store/projects.go
Normal file
90
internal/store/projects.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) CreateProject(ctx context.Context, name, description string) (thoughttypes.Project, error) {
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
insert into projects (name, description)
|
||||
values ($1, $2)
|
||||
returning id, name, description, created_at, last_active_at
|
||||
`, name, description)
|
||||
|
||||
var project thoughttypes.Project
|
||||
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
||||
return thoughttypes.Project{}, fmt.Errorf("create project: %w", err)
|
||||
}
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) {
|
||||
var row pgx.Row
|
||||
if parsedID, err := uuid.Parse(strings.TrimSpace(nameOrID)); err == nil {
|
||||
row = db.pool.QueryRow(ctx, `
|
||||
select id, name, description, created_at, last_active_at
|
||||
from projects
|
||||
where id = $1
|
||||
`, parsedID)
|
||||
} else {
|
||||
row = db.pool.QueryRow(ctx, `
|
||||
select id, name, description, created_at, last_active_at
|
||||
from projects
|
||||
where name = $1
|
||||
`, strings.TrimSpace(nameOrID))
|
||||
}
|
||||
|
||||
var project thoughttypes.Project
|
||||
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return thoughttypes.Project{}, err
|
||||
}
|
||||
return thoughttypes.Project{}, fmt.Errorf("get project: %w", err)
|
||||
}
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (db *DB) ListProjects(ctx context.Context) ([]thoughttypes.ProjectSummary, error) {
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select p.id, p.name, p.description, p.created_at, p.last_active_at, count(t.id) as thought_count
|
||||
from projects p
|
||||
left join thoughts t on t.project_id = p.id and t.archived_at is null
|
||||
group by p.id
|
||||
order by p.last_active_at desc, p.created_at desc
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list projects: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
projects := make([]thoughttypes.ProjectSummary, 0)
|
||||
for rows.Next() {
|
||||
var project thoughttypes.ProjectSummary
|
||||
if err := rows.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt, &project.ThoughtCount); err != nil {
|
||||
return nil, fmt.Errorf("scan project summary: %w", err)
|
||||
}
|
||||
projects = append(projects, project)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate projects: %w", err)
|
||||
}
|
||||
return projects, nil
|
||||
}
|
||||
|
||||
func (db *DB) TouchProject(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := db.pool.Exec(ctx, `update projects set last_active_at = now() where id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("touch project: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
345
internal/store/thoughts.go
Normal file
345
internal/store/thoughts.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) {
|
||||
metadata, err := json.Marshal(thought.Metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
insert into thoughts (content, embedding, metadata, project_id)
|
||||
values ($1, $2, $3::jsonb, $4)
|
||||
returning id, created_at, updated_at
|
||||
`, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID)
|
||||
|
||||
created := thought
|
||||
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
|
||||
}
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
||||
filterJSON, err := json.Marshal(filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal search filter: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select id, content, metadata, similarity, created_at
|
||||
from match_thoughts($1, $2, $3, $4::jsonb)
|
||||
`, pgvector.NewVector(embedding), threshold, limit, filterJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make([]thoughttypes.SearchResult, 0, limit)
|
||||
for rows.Next() {
|
||||
var result thoughttypes.SearchResult
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan search result: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode search metadata: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate search results: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (db *DB) ListThoughts(ctx context.Context, filter thoughttypes.ListFilter) ([]thoughttypes.Thought, error) {
|
||||
args := make([]any, 0, 6)
|
||||
conditions := []string{}
|
||||
|
||||
if !filter.IncludeArchived {
|
||||
conditions = append(conditions, "archived_at is null")
|
||||
}
|
||||
if value := strings.TrimSpace(filter.Type); value != "" {
|
||||
args = append(args, value)
|
||||
conditions = append(conditions, fmt.Sprintf("metadata->>'type' = $%d", len(args)))
|
||||
}
|
||||
if value := strings.TrimSpace(filter.Topic); value != "" {
|
||||
args = append(args, value)
|
||||
conditions = append(conditions, fmt.Sprintf("metadata->'topics' ? $%d", len(args)))
|
||||
}
|
||||
if value := strings.TrimSpace(filter.Person); value != "" {
|
||||
args = append(args, value)
|
||||
conditions = append(conditions, fmt.Sprintf("metadata->'people' ? $%d", len(args)))
|
||||
}
|
||||
if filter.Days > 0 {
|
||||
args = append(args, time.Now().Add(-time.Duration(filter.Days)*24*time.Hour))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)))
|
||||
}
|
||||
if filter.ProjectID != nil {
|
||||
args = append(args, *filter.ProjectID)
|
||||
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
||||
}
|
||||
|
||||
query := `
|
||||
select id, content, metadata, project_id, archived_at, created_at, updated_at
|
||||
from thoughts
|
||||
`
|
||||
if len(conditions) > 0 {
|
||||
query += " where " + strings.Join(conditions, " and ")
|
||||
}
|
||||
|
||||
args = append(args, filter.Limit)
|
||||
query += fmt.Sprintf(" order by created_at desc limit $%d", len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
thoughts := make([]thoughttypes.Thought, 0, filter.Limit)
|
||||
for rows.Next() {
|
||||
var thought thoughttypes.Thought
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan listed thought: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode listed metadata: %w", err)
|
||||
}
|
||||
thoughts = append(thoughts, thought)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate listed thoughts: %w", err)
|
||||
}
|
||||
|
||||
return thoughts, nil
|
||||
}
|
||||
|
||||
func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
|
||||
var total int
|
||||
if err := db.pool.QueryRow(ctx, `select count(*) from thoughts where archived_at is null`).Scan(&total); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("count thoughts: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.pool.Query(ctx, `select metadata from thoughts where archived_at is null`)
|
||||
if err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("query stats metadata: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
stats := thoughttypes.ThoughtStats{
|
||||
TotalCount: total,
|
||||
TypeCounts: map[string]int{},
|
||||
}
|
||||
topics := map[string]int{}
|
||||
people := map[string]int{}
|
||||
|
||||
for rows.Next() {
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&metadataBytes); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("scan stats metadata: %w", err)
|
||||
}
|
||||
|
||||
var metadata thoughttypes.ThoughtMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("decode stats metadata: %w", err)
|
||||
}
|
||||
|
||||
stats.TypeCounts[metadata.Type]++
|
||||
for _, topic := range metadata.Topics {
|
||||
topics[topic]++
|
||||
}
|
||||
for _, person := range metadata.People {
|
||||
people[person]++
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("iterate stats metadata: %w", err)
|
||||
}
|
||||
|
||||
stats.TopTopics = topCounts(topics, 10)
|
||||
stats.TopPeople = topCounts(people, 10)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at
|
||||
from thoughts
|
||||
where id = $1
|
||||
`, id)
|
||||
|
||||
var thought thoughttypes.Thought
|
||||
var embedding pgvector.Vector
|
||||
var metadataBytes []byte
|
||||
if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return thoughttypes.Thought{}, err
|
||||
}
|
||||
return thoughttypes.Thought{}, fmt.Errorf("get thought: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
|
||||
}
|
||||
thought.Embedding = embedding.Slice()
|
||||
|
||||
return thought, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
|
||||
}
|
||||
|
||||
tag, err := db.pool.Exec(ctx, `
|
||||
update thoughts
|
||||
set content = $2,
|
||||
embedding = $3,
|
||||
metadata = $4::jsonb,
|
||||
project_id = $5,
|
||||
updated_at = now()
|
||||
where id = $1
|
||||
`, id, content, pgvector.NewVector(embedding), metadataBytes, projectID)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return thoughttypes.Thought{}, pgx.ErrNoRows
|
||||
}
|
||||
|
||||
return db.GetThought(ctx, id)
|
||||
}
|
||||
|
||||
func (db *DB) DeleteThought(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := db.pool.Exec(ctx, `delete from thoughts where id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete thought: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) ArchiveThought(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := db.pool.Exec(ctx, `update thoughts set archived_at = now(), updated_at = now() where id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("archive thought: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit int, days int) ([]thoughttypes.Thought, error) {
|
||||
filter := thoughttypes.ListFilter{
|
||||
Limit: limit,
|
||||
ProjectID: projectID,
|
||||
Days: days,
|
||||
IncludeArchived: false,
|
||||
}
|
||||
return db.ListThoughts(ctx, filter)
|
||||
}
|
||||
|
||||
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||
args := []any{pgvector.NewVector(embedding), threshold}
|
||||
conditions := []string{
|
||||
"archived_at is null",
|
||||
"1 - (embedding <=> $1) > $2",
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
||||
}
|
||||
if excludeID != nil {
|
||||
args = append(args, *excludeID)
|
||||
conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args)))
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
query := `
|
||||
select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at
|
||||
from thoughts
|
||||
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
|
||||
order by embedding <=> $1
|
||||
limit $%d`, len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search similar thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make([]thoughttypes.SearchResult, 0, limit)
|
||||
for rows.Next() {
|
||||
var result thoughttypes.SearchResult
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan similar thought: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode similar thought metadata: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate similar thoughts: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
||||
type kv struct {
|
||||
key string
|
||||
count int
|
||||
}
|
||||
|
||||
pairs := make([]kv, 0, len(in))
|
||||
for key, count := range in {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
pairs = append(pairs, kv{key: key, count: count})
|
||||
}
|
||||
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
if pairs[i].count == pairs[j].count {
|
||||
return pairs[i].key < pairs[j].key
|
||||
}
|
||||
return pairs[i].count > pairs[j].count
|
||||
})
|
||||
|
||||
if limit > 0 && len(pairs) > limit {
|
||||
pairs = pairs[:limit]
|
||||
}
|
||||
|
||||
out := make([]thoughttypes.KeyCount, 0, len(pairs))
|
||||
for _, pair := range pairs {
|
||||
out = append(out, thoughttypes.KeyCount{Key: pair.key, Count: pair.count})
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user