From cebef3a07ccd3e959dcaa4a325a005d312cd1573 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 25 Mar 2026 16:25:41 +0200 Subject: [PATCH] feat(embeddings): add embedding model support and related changes * Introduced EmbeddingModel method in Client and Provider interfaces * Updated InsertThought and SearchThoughts methods to handle embedding models * Created embeddings table and updated match_thoughts function for model filtering * Removed embedding column from thoughts table * Adjusted permissions for new embeddings table --- internal/ai/compat/client.go | 4 + internal/ai/provider.go | 1 + internal/app/app.go | 2 +- internal/mcpserver/schema.go | 47 +++++++++++ internal/mcpserver/schema_test.go | 42 ++++++++++ internal/mcpserver/server.go | 34 ++++---- internal/store/db.go | 34 ++------ internal/store/thoughts.go | 102 ++++++++++++++++------- internal/tools/capture.go | 2 +- internal/tools/context.go | 2 +- internal/tools/links.go | 2 +- internal/tools/recall.go | 2 +- internal/tools/search.go | 5 +- internal/tools/summarize.go | 2 +- internal/tools/update.go | 4 +- migrations/006_rls_and_grants.sql | 5 -- migrations/007_embeddings_table.sql | 16 ++++ migrations/008_update_match_thoughts.sql | 34 ++++++++ migrations/009_rls_and_grants.sql | 7 ++ 19 files changed, 259 insertions(+), 88 deletions(-) create mode 100644 internal/mcpserver/schema.go create mode 100644 internal/mcpserver/schema_test.go delete mode 100644 migrations/006_rls_and_grants.sql create mode 100644 migrations/007_embeddings_table.sql create mode 100644 migrations/008_update_match_thoughts.sql create mode 100644 migrations/009_rls_and_grants.sql diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go index d03ee38..d24b011 100644 --- a/internal/ai/compat/client.go +++ b/internal/ai/compat/client.go @@ -208,6 +208,10 @@ func (c *Client) Name() string { return c.name } +func (c *Client) EmbeddingModel() string { + return c.embeddingModel +} + func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error { body, err := json.Marshal(requestBody) if err != nil { diff --git a/internal/ai/provider.go b/internal/ai/provider.go index ac35676..e547757 100644 --- a/internal/ai/provider.go +++ b/internal/ai/provider.go @@ -11,4 +11,5 @@ type Provider interface { ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) Name() string + EmbeddingModel() string } diff --git a/internal/app/app.go b/internal/app/app.go index 99473bc..55317a4 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -40,7 +40,7 @@ func Run(ctx context.Context, configPath string) error { } defer db.Close() - if err := db.VerifyRequirements(ctx, cfg.AI.Embeddings.Dimensions); err != nil { + if err := db.VerifyRequirements(ctx); err != nil { return err } diff --git a/internal/mcpserver/schema.go b/internal/mcpserver/schema.go new file mode 100644 index 0000000..9b19f5b --- /dev/null +++ b/internal/mcpserver/schema.go @@ -0,0 +1,47 @@ +package mcpserver + +import ( + "context" + "fmt" + "reflect" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/google/uuid" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var toolSchemaOptions = &jsonschema.ForOptions{ + TypeSchemas: map[reflect.Type]*jsonschema.Schema{ + reflect.TypeFor[uuid.UUID](): { + Type: "string", + Format: "uuid", + }, + }, +} + +func addTool[In any, Out any](server *mcp.Server, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) { + if err := setToolSchemas[In, Out](tool); err != nil { + panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err)) + } + mcp.AddTool(server, tool, handler) +} + +func setToolSchemas[In any, Out any](tool *mcp.Tool) error { + if tool.InputSchema == nil { + inputSchema, err := jsonschema.For[In](toolSchemaOptions) + if err != nil { + return fmt.Errorf("infer input schema: %w", err) + } + tool.InputSchema = inputSchema + } + + if tool.OutputSchema == nil { + outputSchema, err := jsonschema.For[Out](toolSchemaOptions) + if err != nil { + return fmt.Errorf("infer output schema: %w", err) + } + tool.OutputSchema = outputSchema + } + + return nil +} diff --git a/internal/mcpserver/schema_test.go b/internal/mcpserver/schema_test.go new file mode 100644 index 0000000..ec37986 --- /dev/null +++ b/internal/mcpserver/schema_test.go @@ -0,0 +1,42 @@ +package mcpserver + +import ( + "testing" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/tools" +) + +func TestSetToolSchemasUsesStringUUIDsInListOutput(t *testing.T) { + tool := &mcp.Tool{Name: "list_thoughts"} + + if err := setToolSchemas[tools.ListInput, tools.ListOutput](tool); err != nil { + t.Fatalf("set tool schemas: %v", err) + } + + schema, ok := tool.OutputSchema.(*jsonschema.Schema) + if !ok { + t.Fatalf("output schema type = %T, want *jsonschema.Schema", tool.OutputSchema) + } + + thoughtsSchema := schema.Properties["thoughts"] + if thoughtsSchema == nil { + t.Fatal("missing thoughts schema") + } + if thoughtsSchema.Items == nil { + t.Fatal("missing thoughts item schema") + } + + idSchema := thoughtsSchema.Items.Properties["id"] + if idSchema == nil { + t.Fatal("missing id schema") + } + if idSchema.Type != "string" { + t.Fatalf("id schema type = %q, want %q", idSchema.Type, "string") + } + if idSchema.Format != "uuid" { + t.Fatalf("id schema format = %q, want %q", idSchema.Format, "uuid") + } +} diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index ce0a513..7f41b40 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -32,87 +32,87 @@ func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler { Version: cfg.Version, }, nil) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "capture_thought", Description: "Store a thought with generated embeddings and extracted metadata.", }, toolSet.Capture.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "search_thoughts", Description: "Search stored thoughts by semantic similarity.", }, toolSet.Search.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "list_thoughts", Description: "List recent thoughts with optional metadata filters.", }, toolSet.List.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "thought_stats", Description: "Get counts and top metadata buckets across stored thoughts.", }, toolSet.Stats.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "get_thought", Description: "Retrieve a full thought by id.", }, toolSet.Get.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "update_thought", Description: "Update thought content or merge metadata.", }, toolSet.Update.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "delete_thought", Description: "Hard-delete a thought by id.", }, toolSet.Delete.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "archive_thought", Description: "Archive a thought so it is hidden from default search and listing.", }, toolSet.Archive.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "create_project", Description: "Create a named project container for thoughts.", }, toolSet.Projects.Create) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "list_projects", Description: "List projects and their current thought counts.", }, toolSet.Projects.List) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "set_active_project", Description: "Set the active project for the current MCP session.", }, toolSet.Projects.SetActive) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "get_active_project", Description: "Return the active project for the current MCP session.", }, toolSet.Projects.GetActive) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "get_project_context", Description: "Get recent and semantic context for a project.", }, toolSet.Context.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "recall_context", Description: "Recall semantically relevant and recent context.", }, toolSet.Recall.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "summarize_thoughts", Description: "Summarize a filtered or searched set of thoughts.", }, toolSet.Summarize.Handle) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "link_thoughts", Description: "Create a typed relationship between two thoughts.", }, toolSet.Links.Link) - mcp.AddTool(server, &mcp.Tool{ + addTool(server, &mcp.Tool{ Name: "related_thoughts", Description: "Retrieve explicit links and semantic neighbors for a thought.", }, toolSet.Links.Related) diff --git a/internal/store/db.go b/internal/store/db.go index 8235d53..c7b61cd 100644 --- a/internal/store/db.go +++ b/internal/store/db.go @@ -3,7 +3,6 @@ package store import ( "context" "fmt" - "regexp" "time" "github.com/jackc/pgx/v5" @@ -68,7 +67,7 @@ func (db *DB) Ready(ctx context.Context) error { return db.Ping(readyCtx) } -func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error { +func (db *DB) VerifyRequirements(ctx context.Context) error { var hasVector bool if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil { return fmt.Errorf("verify vector extension: %w", err) @@ -85,33 +84,12 @@ func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error { return fmt.Errorf("match_thoughts function is missing") } - var embeddingType string - err := db.pool.QueryRow(ctx, ` - select format_type(a.atttypid, a.atttypmod) - from pg_attribute a - join pg_class c on c.oid = a.attrelid - join pg_namespace n on n.oid = c.relnamespace - where n.nspname = 'public' - and c.relname = 'thoughts' - and a.attname = 'embedding' - and not a.attisdropped - `).Scan(&embeddingType) - if err != nil { - return fmt.Errorf("verify thoughts.embedding type: %w", err) + var hasEmbeddings bool + if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`).Scan(&hasEmbeddings); err != nil { + return fmt.Errorf("verify embeddings table: %w", err) } - - re := regexp.MustCompile(`vector\((\d+)\)`) - matches := re.FindStringSubmatch(embeddingType) - if len(matches) != 2 { - return fmt.Errorf("unexpected embedding type %q", embeddingType) - } - - var actualDimensions int - if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil { - return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err) - } - if actualDimensions != dimensions { - return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions) + if !hasEmbeddings { + return fmt.Errorf("embeddings table is missing — run migrations") } return nil diff --git a/internal/store/thoughts.go b/internal/store/thoughts.go index 36d54cb..584f172 100644 --- a/internal/store/thoughts.go +++ b/internal/store/thoughts.go @@ -15,27 +15,51 @@ import ( thoughttypes "git.warky.dev/wdevs/amcs/internal/types" ) -func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) { +func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) { metadata, err := json.Marshal(thought.Metadata) if err != nil { return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err) } - row := db.pool.QueryRow(ctx, ` - insert into thoughts (content, embedding, metadata, project_id) - values ($1, $2, $3::jsonb, $4) + tx, err := db.pool.Begin(ctx) + if err != nil { + return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback(ctx) + + row := tx.QueryRow(ctx, ` + insert into thoughts (content, metadata, project_id) + values ($1, $2::jsonb, $3) returning id, created_at, updated_at - `, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID) + `, thought.Content, metadata, thought.ProjectID) created := thought + created.Embedding = nil if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil { return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err) } + if len(thought.Embedding) > 0 && embeddingModel != "" { + if _, err := tx.Exec(ctx, ` + insert into embeddings (thought_id, model, dim, embedding) + values ($1, $2, $3, $4) + on conflict (thought_id, model) do update + set embedding = excluded.embedding, + dim = excluded.dim, + updated_at = now() + `, created.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding)); err != nil { + return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err) + } + } + + if err := tx.Commit(ctx); err != nil { + return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err) + } + return created, nil } -func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) { +func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) { filterJSON, err := json.Marshal(filter) if err != nil { return nil, fmt.Errorf("marshal search filter: %w", err) @@ -43,8 +67,8 @@ func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold rows, err := db.pool.Query(ctx, ` select id, content, metadata, similarity, created_at - from match_thoughts($1, $2, $3, $4::jsonb) - `, pgvector.NewVector(embedding), threshold, limit, filterJSON) + from match_thoughts($1, $2, $3, $4::jsonb, $5) + `, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel) if err != nil { return nil, fmt.Errorf("search thoughts: %w", err) } @@ -185,15 +209,14 @@ func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) { func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) { row := db.pool.QueryRow(ctx, ` - select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at + select id, content, metadata, project_id, archived_at, created_at, updated_at from thoughts where id = $1 `, id) var thought thoughttypes.Thought - var embedding pgvector.Vector var metadataBytes []byte - if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil { + if err := row.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil { if err == pgx.ErrNoRows { return thoughttypes.Thought{}, err } @@ -203,26 +226,30 @@ func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Though if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil { return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err) } - thought.Embedding = embedding.Slice() return thought, nil } -func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) { +func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) { metadataBytes, err := json.Marshal(metadata) if err != nil { return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err) } - tag, err := db.pool.Exec(ctx, ` + tx, err := db.pool.Begin(ctx) + if err != nil { + return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback(ctx) + + tag, err := tx.Exec(ctx, ` update thoughts - set content = $2, - embedding = $3, - metadata = $4::jsonb, - project_id = $5, + set content = $2, + metadata = $3::jsonb, + project_id = $4, updated_at = now() where id = $1 - `, id, content, pgvector.NewVector(embedding), metadataBytes, projectID) + `, id, content, metadataBytes, projectID) if err != nil { return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err) } @@ -230,6 +257,23 @@ func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, e return thoughttypes.Thought{}, pgx.ErrNoRows } + if len(embedding) > 0 && embeddingModel != "" { + if _, err := tx.Exec(ctx, ` + insert into embeddings (thought_id, model, dim, embedding) + values ($1, $2, $3, $4) + on conflict (thought_id, model) do update + set embedding = excluded.embedding, + dim = excluded.dim, + updated_at = now() + `, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil { + return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err) + } + } + + if err := tx.Commit(ctx); err != nil { + return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err) + } + return db.GetThought(ctx, id) } @@ -265,27 +309,29 @@ func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit in return db.ListThoughts(ctx, filter) } -func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) { - args := []any{pgvector.NewVector(embedding), threshold} +func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) { + args := []any{pgvector.NewVector(embedding), threshold, embeddingModel} conditions := []string{ - "archived_at is null", - "1 - (embedding <=> $1) > $2", + "t.archived_at is null", + "1 - (e.embedding <=> $1) > $2", + "e.model = $3", } if projectID != nil { args = append(args, *projectID) - conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args))) + conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args))) } if excludeID != nil { args = append(args, *excludeID) - conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args))) + conditions = append(conditions, fmt.Sprintf("t.id <> $%d", len(args))) } args = append(args, limit) query := ` - select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at - from thoughts + select t.id, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at + from thoughts t + join embeddings e on e.thought_id = t.id where ` + strings.Join(conditions, " and ") + fmt.Sprintf(` - order by embedding <=> $1 + order by e.embedding <=> $1 limit $%d`, len(args)) rows, err := db.pool.Query(ctx, query, args...) diff --git a/internal/tools/capture.go b/internal/tools/capture.go index 9c1b8fe..91d3e9a 100644 --- a/internal/tools/capture.go +++ b/internal/tools/capture.go @@ -83,7 +83,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C thought.ProjectID = &project.ID } - created, err := t.store.InsertThought(ctx, thought) + created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel()) if err != nil { return nil, CaptureOutput{}, err } diff --git a/internal/tools/context.go b/internal/tools/context.go index 3ca9e94..69dde22 100644 --- a/internal/tools/context.go +++ b/internal/tools/context.go @@ -76,7 +76,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P if err != nil { return nil, ProjectContextOutput{}, err } - semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, &project.ID, nil) + semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, &project.ID, nil) if err != nil { return nil, ProjectContextOutput{}, err } diff --git a/internal/tools/links.go b/internal/tools/links.go index 1940ea3..41c27e9 100644 --- a/internal/tools/links.go +++ b/internal/tools/links.go @@ -121,7 +121,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela if err != nil { return nil, RelatedOutput{}, err } - semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID) + semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID) if err != nil { return nil, RelatedOutput{}, err } diff --git a/internal/tools/recall.go b/internal/tools/recall.go index a7cddc0..d420caf 100644 --- a/internal/tools/recall.go +++ b/internal/tools/recall.go @@ -58,7 +58,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re projectID = &project.ID } - semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil) + semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil) if err != nil { return nil, RecallOutput{}, err } diff --git a/internal/tools/search.go b/internal/tools/search.go index 56aabb4..4ad0097 100644 --- a/internal/tools/search.go +++ b/internal/tools/search.go @@ -54,12 +54,13 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se return nil, SearchOutput{}, err } + model := t.provider.EmbeddingModel() var results []thoughttypes.SearchResult if project != nil { - results, err = t.store.SearchSimilarThoughts(ctx, embedding, threshold, limit, &project.ID, nil) + results, err = t.store.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, &project.ID, nil) _ = t.store.TouchProject(ctx, project.ID) } else { - results, err = t.store.SearchThoughts(ctx, embedding, threshold, limit, map[string]any{}) + results, err = t.store.SearchThoughts(ctx, embedding, model, threshold, limit, map[string]any{}) } if err != nil { return nil, SearchOutput{}, err diff --git a/internal/tools/summarize.go b/internal/tools/summarize.go index 11a7912..0417f5b 100644 --- a/internal/tools/summarize.go +++ b/internal/tools/summarize.go @@ -56,7 +56,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in if project != nil { projectID = &project.ID } - results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil) + results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil) if err != nil { return nil, SummarizeOutput{}, err } diff --git a/internal/tools/update.go b/internal/tools/update.go index 65bf75f..dc7d4a4 100644 --- a/internal/tools/update.go +++ b/internal/tools/update.go @@ -48,7 +48,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda } content := current.Content - embedding := current.Embedding + var embedding []float32 mergedMetadata := current.Metadata projectID := current.ProjectID @@ -79,7 +79,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda projectID = &project.ID } - updated, err := t.store.UpdateThought(ctx, id, content, embedding, mergedMetadata, projectID) + updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID) if err != nil { return nil, UpdateOutput{}, err } diff --git a/migrations/006_rls_and_grants.sql b/migrations/006_rls_and_grants.sql deleted file mode 100644 index c94fdb5..0000000 --- a/migrations/006_rls_and_grants.sql +++ /dev/null @@ -1,5 +0,0 @@ --- Grant these permissions to the database role used by the application. --- Replace amcs_user with the actual role in your deployment before applying. -grant select, insert, update, delete on table public.thoughts to amcs_user; -grant select, insert, update, delete on table public.projects to amcs_user; -grant select, insert, update, delete on table public.thought_links to amcs_user; diff --git a/migrations/007_embeddings_table.sql b/migrations/007_embeddings_table.sql new file mode 100644 index 0000000..b5242fc --- /dev/null +++ b/migrations/007_embeddings_table.sql @@ -0,0 +1,16 @@ +create table if not exists embeddings ( + id bigserial primary key, + guid uuid not null default gen_random_uuid(), + thought_id uuid not null references thoughts(id) on delete cascade, + model text not null, + dim int not null, + embedding vector not null, + created_at timestamptz default now(), + updated_at timestamptz default now(), + constraint embeddings_guid_unique unique (guid), + constraint embeddings_thought_model_unique unique (thought_id, model) +); + +create index if not exists embeddings_thought_id_idx on embeddings (thought_id); + +alter table thoughts drop column if exists embedding; diff --git a/migrations/008_update_match_thoughts.sql b/migrations/008_update_match_thoughts.sql new file mode 100644 index 0000000..0e2f470 --- /dev/null +++ b/migrations/008_update_match_thoughts.sql @@ -0,0 +1,34 @@ +create or replace function match_thoughts( + query_embedding vector, + match_threshold float default 0.7, + match_count int default 10, + filter jsonb default '{}'::jsonb, + embedding_model text default '' +) +returns table ( + id uuid, + content text, + metadata jsonb, + similarity float, + created_at timestamptz +) +language plpgsql +as $$ +begin + return query + select + t.id, + t.content, + t.metadata, + 1 - (e.embedding <=> query_embedding) as similarity, + t.created_at + from thoughts t + join embeddings e on e.thought_id = t.id + where 1 - (e.embedding <=> query_embedding) > match_threshold + and t.archived_at is null + and (embedding_model = '' or e.model = embedding_model) + and (filter = '{}'::jsonb or t.metadata @> filter) + order by e.embedding <=> query_embedding + limit match_count; +end; +$$; diff --git a/migrations/009_rls_and_grants.sql b/migrations/009_rls_and_grants.sql new file mode 100644 index 0000000..2365637 --- /dev/null +++ b/migrations/009_rls_and_grants.sql @@ -0,0 +1,7 @@ +-- Grant these permissions to the database role used by the application. +-- Replace amcs_user with the actual role in your deployment before applying. +grant ALL ON TABLE public.thoughts to amcs; +grant ALL ON TABLE public.projects to amcs; +grant ALL ON TABLE public.thought_links to amcs; +grant ALL ON TABLE public.embeddings to amcs; +GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO amcs; \ No newline at end of file