feat(embeddings): add embedding model support and related changes
* Introduced EmbeddingModel method in Client and Provider interfaces
* Updated InsertThought and SearchThoughts methods to handle embedding models
* Created embeddings table and updated match_thoughts function for model filtering
* Removed embedding column from thoughts table
* Adjusted permissions for new embeddings table
This commit is contained in:
@@ -208,6 +208,10 @@ func (c *Client) Name() string {
|
|||||||
return c.name
|
return c.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) EmbeddingModel() string {
|
||||||
|
return c.embeddingModel
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
||||||
body, err := json.Marshal(requestBody)
|
body, err := json.Marshal(requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -11,4 +11,5 @@ type Provider interface {
|
|||||||
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
|
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
|
||||||
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
|
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
|
||||||
Name() string
|
Name() string
|
||||||
|
EmbeddingModel() string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
if err := db.VerifyRequirements(ctx, cfg.AI.Embeddings.Dimensions); err != nil {
|
if err := db.VerifyRequirements(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
47
internal/mcpserver/schema.go
Normal file
47
internal/mcpserver/schema.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package mcpserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/google/jsonschema-go/jsonschema"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var toolSchemaOptions = &jsonschema.ForOptions{
|
||||||
|
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
|
||||||
|
reflect.TypeFor[uuid.UUID](): {
|
||||||
|
Type: "string",
|
||||||
|
Format: "uuid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func addTool[In any, Out any](server *mcp.Server, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
|
||||||
|
if err := setToolSchemas[In, Out](tool); err != nil {
|
||||||
|
panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
|
||||||
|
}
|
||||||
|
mcp.AddTool(server, tool, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
|
||||||
|
if tool.InputSchema == nil {
|
||||||
|
inputSchema, err := jsonschema.For[In](toolSchemaOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("infer input schema: %w", err)
|
||||||
|
}
|
||||||
|
tool.InputSchema = inputSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool.OutputSchema == nil {
|
||||||
|
outputSchema, err := jsonschema.For[Out](toolSchemaOptions)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("infer output schema: %w", err)
|
||||||
|
}
|
||||||
|
tool.OutputSchema = outputSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
42
internal/mcpserver/schema_test.go
Normal file
42
internal/mcpserver/schema_test.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package mcpserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/jsonschema-go/jsonschema"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetToolSchemasUsesStringUUIDsInListOutput(t *testing.T) {
|
||||||
|
tool := &mcp.Tool{Name: "list_thoughts"}
|
||||||
|
|
||||||
|
if err := setToolSchemas[tools.ListInput, tools.ListOutput](tool); err != nil {
|
||||||
|
t.Fatalf("set tool schemas: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema, ok := tool.OutputSchema.(*jsonschema.Schema)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("output schema type = %T, want *jsonschema.Schema", tool.OutputSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
thoughtsSchema := schema.Properties["thoughts"]
|
||||||
|
if thoughtsSchema == nil {
|
||||||
|
t.Fatal("missing thoughts schema")
|
||||||
|
}
|
||||||
|
if thoughtsSchema.Items == nil {
|
||||||
|
t.Fatal("missing thoughts item schema")
|
||||||
|
}
|
||||||
|
|
||||||
|
idSchema := thoughtsSchema.Items.Properties["id"]
|
||||||
|
if idSchema == nil {
|
||||||
|
t.Fatal("missing id schema")
|
||||||
|
}
|
||||||
|
if idSchema.Type != "string" {
|
||||||
|
t.Fatalf("id schema type = %q, want %q", idSchema.Type, "string")
|
||||||
|
}
|
||||||
|
if idSchema.Format != "uuid" {
|
||||||
|
t.Fatalf("id schema format = %q, want %q", idSchema.Format, "uuid")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,87 +32,87 @@ func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
|||||||
Version: cfg.Version,
|
Version: cfg.Version,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "capture_thought",
|
Name: "capture_thought",
|
||||||
Description: "Store a thought with generated embeddings and extracted metadata.",
|
Description: "Store a thought with generated embeddings and extracted metadata.",
|
||||||
}, toolSet.Capture.Handle)
|
}, toolSet.Capture.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "search_thoughts",
|
Name: "search_thoughts",
|
||||||
Description: "Search stored thoughts by semantic similarity.",
|
Description: "Search stored thoughts by semantic similarity.",
|
||||||
}, toolSet.Search.Handle)
|
}, toolSet.Search.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "list_thoughts",
|
Name: "list_thoughts",
|
||||||
Description: "List recent thoughts with optional metadata filters.",
|
Description: "List recent thoughts with optional metadata filters.",
|
||||||
}, toolSet.List.Handle)
|
}, toolSet.List.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "thought_stats",
|
Name: "thought_stats",
|
||||||
Description: "Get counts and top metadata buckets across stored thoughts.",
|
Description: "Get counts and top metadata buckets across stored thoughts.",
|
||||||
}, toolSet.Stats.Handle)
|
}, toolSet.Stats.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "get_thought",
|
Name: "get_thought",
|
||||||
Description: "Retrieve a full thought by id.",
|
Description: "Retrieve a full thought by id.",
|
||||||
}, toolSet.Get.Handle)
|
}, toolSet.Get.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "update_thought",
|
Name: "update_thought",
|
||||||
Description: "Update thought content or merge metadata.",
|
Description: "Update thought content or merge metadata.",
|
||||||
}, toolSet.Update.Handle)
|
}, toolSet.Update.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "delete_thought",
|
Name: "delete_thought",
|
||||||
Description: "Hard-delete a thought by id.",
|
Description: "Hard-delete a thought by id.",
|
||||||
}, toolSet.Delete.Handle)
|
}, toolSet.Delete.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "archive_thought",
|
Name: "archive_thought",
|
||||||
Description: "Archive a thought so it is hidden from default search and listing.",
|
Description: "Archive a thought so it is hidden from default search and listing.",
|
||||||
}, toolSet.Archive.Handle)
|
}, toolSet.Archive.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "create_project",
|
Name: "create_project",
|
||||||
Description: "Create a named project container for thoughts.",
|
Description: "Create a named project container for thoughts.",
|
||||||
}, toolSet.Projects.Create)
|
}, toolSet.Projects.Create)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "list_projects",
|
Name: "list_projects",
|
||||||
Description: "List projects and their current thought counts.",
|
Description: "List projects and their current thought counts.",
|
||||||
}, toolSet.Projects.List)
|
}, toolSet.Projects.List)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "set_active_project",
|
Name: "set_active_project",
|
||||||
Description: "Set the active project for the current MCP session.",
|
Description: "Set the active project for the current MCP session.",
|
||||||
}, toolSet.Projects.SetActive)
|
}, toolSet.Projects.SetActive)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "get_active_project",
|
Name: "get_active_project",
|
||||||
Description: "Return the active project for the current MCP session.",
|
Description: "Return the active project for the current MCP session.",
|
||||||
}, toolSet.Projects.GetActive)
|
}, toolSet.Projects.GetActive)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "get_project_context",
|
Name: "get_project_context",
|
||||||
Description: "Get recent and semantic context for a project.",
|
Description: "Get recent and semantic context for a project.",
|
||||||
}, toolSet.Context.Handle)
|
}, toolSet.Context.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "recall_context",
|
Name: "recall_context",
|
||||||
Description: "Recall semantically relevant and recent context.",
|
Description: "Recall semantically relevant and recent context.",
|
||||||
}, toolSet.Recall.Handle)
|
}, toolSet.Recall.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "summarize_thoughts",
|
Name: "summarize_thoughts",
|
||||||
Description: "Summarize a filtered or searched set of thoughts.",
|
Description: "Summarize a filtered or searched set of thoughts.",
|
||||||
}, toolSet.Summarize.Handle)
|
}, toolSet.Summarize.Handle)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "link_thoughts",
|
Name: "link_thoughts",
|
||||||
Description: "Create a typed relationship between two thoughts.",
|
Description: "Create a typed relationship between two thoughts.",
|
||||||
}, toolSet.Links.Link)
|
}, toolSet.Links.Link)
|
||||||
|
|
||||||
mcp.AddTool(server, &mcp.Tool{
|
addTool(server, &mcp.Tool{
|
||||||
Name: "related_thoughts",
|
Name: "related_thoughts",
|
||||||
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
||||||
}, toolSet.Links.Related)
|
}, toolSet.Links.Related)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package store
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
@@ -68,7 +67,7 @@ func (db *DB) Ready(ctx context.Context) error {
|
|||||||
return db.Ping(readyCtx)
|
return db.Ping(readyCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
func (db *DB) VerifyRequirements(ctx context.Context) error {
|
||||||
var hasVector bool
|
var hasVector bool
|
||||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
|
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
|
||||||
return fmt.Errorf("verify vector extension: %w", err)
|
return fmt.Errorf("verify vector extension: %w", err)
|
||||||
@@ -85,33 +84,12 @@ func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
|||||||
return fmt.Errorf("match_thoughts function is missing")
|
return fmt.Errorf("match_thoughts function is missing")
|
||||||
}
|
}
|
||||||
|
|
||||||
var embeddingType string
|
var hasEmbeddings bool
|
||||||
err := db.pool.QueryRow(ctx, `
|
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`).Scan(&hasEmbeddings); err != nil {
|
||||||
select format_type(a.atttypid, a.atttypmod)
|
return fmt.Errorf("verify embeddings table: %w", err)
|
||||||
from pg_attribute a
|
|
||||||
join pg_class c on c.oid = a.attrelid
|
|
||||||
join pg_namespace n on n.oid = c.relnamespace
|
|
||||||
where n.nspname = 'public'
|
|
||||||
and c.relname = 'thoughts'
|
|
||||||
and a.attname = 'embedding'
|
|
||||||
and not a.attisdropped
|
|
||||||
`).Scan(&embeddingType)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("verify thoughts.embedding type: %w", err)
|
|
||||||
}
|
}
|
||||||
|
if !hasEmbeddings {
|
||||||
re := regexp.MustCompile(`vector\((\d+)\)`)
|
return fmt.Errorf("embeddings table is missing — run migrations")
|
||||||
matches := re.FindStringSubmatch(embeddingType)
|
|
||||||
if len(matches) != 2 {
|
|
||||||
return fmt.Errorf("unexpected embedding type %q", embeddingType)
|
|
||||||
}
|
|
||||||
|
|
||||||
var actualDimensions int
|
|
||||||
if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil {
|
|
||||||
return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
|
|
||||||
}
|
|
||||||
if actualDimensions != dimensions {
|
|
||||||
return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -15,27 +15,51 @@ import (
|
|||||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) {
|
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) {
|
||||||
metadata, err := json.Marshal(thought.Metadata)
|
metadata, err := json.Marshal(thought.Metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
|
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
row := db.pool.QueryRow(ctx, `
|
tx, err := db.pool.Begin(ctx)
|
||||||
insert into thoughts (content, embedding, metadata, project_id)
|
if err != nil {
|
||||||
values ($1, $2, $3::jsonb, $4)
|
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
row := tx.QueryRow(ctx, `
|
||||||
|
insert into thoughts (content, metadata, project_id)
|
||||||
|
values ($1, $2::jsonb, $3)
|
||||||
returning id, created_at, updated_at
|
returning id, created_at, updated_at
|
||||||
`, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID)
|
`, thought.Content, metadata, thought.ProjectID)
|
||||||
|
|
||||||
created := thought
|
created := thought
|
||||||
|
created.Embedding = nil
|
||||||
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
||||||
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
|
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(thought.Embedding) > 0 && embeddingModel != "" {
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
insert into embeddings (thought_id, model, dim, embedding)
|
||||||
|
values ($1, $2, $3, $4)
|
||||||
|
on conflict (thought_id, model) do update
|
||||||
|
set embedding = excluded.embedding,
|
||||||
|
dim = excluded.dim,
|
||||||
|
updated_at = now()
|
||||||
|
`, created.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding)); err != nil {
|
||||||
|
return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return created, nil
|
return created, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
||||||
filterJSON, err := json.Marshal(filter)
|
filterJSON, err := json.Marshal(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal search filter: %w", err)
|
return nil, fmt.Errorf("marshal search filter: %w", err)
|
||||||
@@ -43,8 +67,8 @@ func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold
|
|||||||
|
|
||||||
rows, err := db.pool.Query(ctx, `
|
rows, err := db.pool.Query(ctx, `
|
||||||
select id, content, metadata, similarity, created_at
|
select id, content, metadata, similarity, created_at
|
||||||
from match_thoughts($1, $2, $3, $4::jsonb)
|
from match_thoughts($1, $2, $3, $4::jsonb, $5)
|
||||||
`, pgvector.NewVector(embedding), threshold, limit, filterJSON)
|
`, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("search thoughts: %w", err)
|
return nil, fmt.Errorf("search thoughts: %w", err)
|
||||||
}
|
}
|
||||||
@@ -185,15 +209,14 @@ func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
|
|||||||
|
|
||||||
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
|
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
|
||||||
row := db.pool.QueryRow(ctx, `
|
row := db.pool.QueryRow(ctx, `
|
||||||
select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at
|
select id, content, metadata, project_id, archived_at, created_at, updated_at
|
||||||
from thoughts
|
from thoughts
|
||||||
where id = $1
|
where id = $1
|
||||||
`, id)
|
`, id)
|
||||||
|
|
||||||
var thought thoughttypes.Thought
|
var thought thoughttypes.Thought
|
||||||
var embedding pgvector.Vector
|
|
||||||
var metadataBytes []byte
|
var metadataBytes []byte
|
||||||
if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
if err := row.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||||
if err == pgx.ErrNoRows {
|
if err == pgx.ErrNoRows {
|
||||||
return thoughttypes.Thought{}, err
|
return thoughttypes.Thought{}, err
|
||||||
}
|
}
|
||||||
@@ -203,26 +226,30 @@ func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Though
|
|||||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||||
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
|
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
|
||||||
}
|
}
|
||||||
thought.Embedding = embedding.Slice()
|
|
||||||
|
|
||||||
return thought, nil
|
return thought, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
||||||
metadataBytes, err := json.Marshal(metadata)
|
metadataBytes, err := json.Marshal(metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
|
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tag, err := db.pool.Exec(ctx, `
|
tx, err := db.pool.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
tag, err := tx.Exec(ctx, `
|
||||||
update thoughts
|
update thoughts
|
||||||
set content = $2,
|
set content = $2,
|
||||||
embedding = $3,
|
metadata = $3::jsonb,
|
||||||
metadata = $4::jsonb,
|
project_id = $4,
|
||||||
project_id = $5,
|
|
||||||
updated_at = now()
|
updated_at = now()
|
||||||
where id = $1
|
where id = $1
|
||||||
`, id, content, pgvector.NewVector(embedding), metadataBytes, projectID)
|
`, id, content, metadataBytes, projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
|
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
|
||||||
}
|
}
|
||||||
@@ -230,6 +257,23 @@ func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, e
|
|||||||
return thoughttypes.Thought{}, pgx.ErrNoRows
|
return thoughttypes.Thought{}, pgx.ErrNoRows
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(embedding) > 0 && embeddingModel != "" {
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
insert into embeddings (thought_id, model, dim, embedding)
|
||||||
|
values ($1, $2, $3, $4)
|
||||||
|
on conflict (thought_id, model) do update
|
||||||
|
set embedding = excluded.embedding,
|
||||||
|
dim = excluded.dim,
|
||||||
|
updated_at = now()
|
||||||
|
`, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil {
|
||||||
|
return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return db.GetThought(ctx, id)
|
return db.GetThought(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,27 +309,29 @@ func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit in
|
|||||||
return db.ListThoughts(ctx, filter)
|
return db.ListThoughts(ctx, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||||
args := []any{pgvector.NewVector(embedding), threshold}
|
args := []any{pgvector.NewVector(embedding), threshold, embeddingModel}
|
||||||
conditions := []string{
|
conditions := []string{
|
||||||
"archived_at is null",
|
"t.archived_at is null",
|
||||||
"1 - (embedding <=> $1) > $2",
|
"1 - (e.embedding <=> $1) > $2",
|
||||||
|
"e.model = $3",
|
||||||
}
|
}
|
||||||
if projectID != nil {
|
if projectID != nil {
|
||||||
args = append(args, *projectID)
|
args = append(args, *projectID)
|
||||||
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||||
}
|
}
|
||||||
if excludeID != nil {
|
if excludeID != nil {
|
||||||
args = append(args, *excludeID)
|
args = append(args, *excludeID)
|
||||||
conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args)))
|
conditions = append(conditions, fmt.Sprintf("t.id <> $%d", len(args)))
|
||||||
}
|
}
|
||||||
args = append(args, limit)
|
args = append(args, limit)
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at
|
select t.id, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at
|
||||||
from thoughts
|
from thoughts t
|
||||||
|
join embeddings e on e.thought_id = t.id
|
||||||
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
|
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
|
||||||
order by embedding <=> $1
|
order by e.embedding <=> $1
|
||||||
limit $%d`, len(args))
|
limit $%d`, len(args))
|
||||||
|
|
||||||
rows, err := db.pool.Query(ctx, query, args...)
|
rows, err := db.pool.Query(ctx, query, args...)
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
|
|||||||
thought.ProjectID = &project.ID
|
thought.ProjectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := t.store.InsertThought(ctx, thought)
|
created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, CaptureOutput{}, err
|
return nil, CaptureOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ProjectContextOutput{}, err
|
return nil, ProjectContextOutput{}, err
|
||||||
}
|
}
|
||||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, &project.ID, nil)
|
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, &project.ID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ProjectContextOutput{}, err
|
return nil, ProjectContextOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RelatedOutput{}, err
|
return nil, RelatedOutput{}, err
|
||||||
}
|
}
|
||||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
|
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RelatedOutput{}, err
|
return nil, RelatedOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
|
|||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil)
|
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RecallOutput{}, err
|
return nil, RecallOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,12 +54,13 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
|
|||||||
return nil, SearchOutput{}, err
|
return nil, SearchOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model := t.provider.EmbeddingModel()
|
||||||
var results []thoughttypes.SearchResult
|
var results []thoughttypes.SearchResult
|
||||||
if project != nil {
|
if project != nil {
|
||||||
results, err = t.store.SearchSimilarThoughts(ctx, embedding, threshold, limit, &project.ID, nil)
|
results, err = t.store.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, &project.ID, nil)
|
||||||
_ = t.store.TouchProject(ctx, project.ID)
|
_ = t.store.TouchProject(ctx, project.ID)
|
||||||
} else {
|
} else {
|
||||||
results, err = t.store.SearchThoughts(ctx, embedding, threshold, limit, map[string]any{})
|
results, err = t.store.SearchThoughts(ctx, embedding, model, threshold, limit, map[string]any{})
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SearchOutput{}, err
|
return nil, SearchOutput{}, err
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
if project != nil {
|
if project != nil {
|
||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil)
|
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SummarizeOutput{}, err
|
return nil, SummarizeOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
|||||||
}
|
}
|
||||||
|
|
||||||
content := current.Content
|
content := current.Content
|
||||||
embedding := current.Embedding
|
var embedding []float32
|
||||||
mergedMetadata := current.Metadata
|
mergedMetadata := current.Metadata
|
||||||
projectID := current.ProjectID
|
projectID := current.ProjectID
|
||||||
|
|
||||||
@@ -79,7 +79,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
|||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := t.store.UpdateThought(ctx, id, content, embedding, mergedMetadata, projectID)
|
updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, UpdateOutput{}, err
|
return nil, UpdateOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
-- Grant these permissions to the database role used by the application.
|
|
||||||
-- Replace amcs_user with the actual role in your deployment before applying.
|
|
||||||
grant select, insert, update, delete on table public.thoughts to amcs_user;
|
|
||||||
grant select, insert, update, delete on table public.projects to amcs_user;
|
|
||||||
grant select, insert, update, delete on table public.thought_links to amcs_user;
|
|
||||||
16
migrations/007_embeddings_table.sql
Normal file
16
migrations/007_embeddings_table.sql
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
-- Embedding vectors for thoughts; at most one row per (thought_id, model).
create table if not exists embeddings (
    id         bigserial   primary key,
    guid       uuid        not null default gen_random_uuid(),
    -- Rows are deleted together with the thought they reference.
    thought_id uuid        not null references thoughts(id) on delete cascade,
    model      text        not null,
    dim        int         not null,
    -- The vector column declares no fixed dimension.
    embedding  vector      not null,
    created_at timestamptz default now(),
    updated_at timestamptz default now(),
    constraint embeddings_guid_unique unique (guid),
    constraint embeddings_thought_model_unique unique (thought_id, model)
);
|
||||||
|
|
||||||
|
create index if not exists embeddings_thought_id_idx on embeddings (thought_id);
|
||||||
|
|
||||||
|
alter table thoughts drop column if exists embedding;
|
||||||
34
migrations/008_update_match_thoughts.sql
Normal file
34
migrations/008_update_match_thoughts.sql
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
-- match_thoughts returns non-archived thoughts whose stored embedding is
-- similar to query_embedding, closest first.
--
--   query_embedding  vector to compare against embeddings.embedding
--   match_threshold  only rows with similarity strictly above this are returned
--   match_count      maximum number of rows returned
--   filter           '{}' disables filtering; otherwise thoughts.metadata must
--                    contain it (@>)
--   embedding_model  '' matches embeddings of any model; otherwise only rows
--                    with that exact model are considered
--
-- similarity is computed as 1 - (embedding <=> query_embedding).
create or replace function match_thoughts(
    query_embedding vector,
    match_threshold float default 0.7,
    match_count     int   default 10,
    filter          jsonb default '{}'::jsonb,
    embedding_model text  default ''
)
returns table (
    id         uuid,
    content    text,
    metadata   jsonb,
    similarity float,
    created_at timestamptz
)
language plpgsql
as $$
begin
    return query
    select
        t.id,
        t.content,
        t.metadata,
        1 - (e.embedding <=> query_embedding) as similarity,
        t.created_at
    from thoughts t
    join embeddings e on e.thought_id = t.id
    where 1 - (e.embedding <=> query_embedding) > match_threshold
      and t.archived_at is null
      and (embedding_model = '' or e.model = embedding_model)
      and (filter = '{}'::jsonb or t.metadata @> filter)
    order by e.embedding <=> query_embedding
    limit match_count;
end;
$$;
|
||||||
7
migrations/009_rls_and_grants.sql
Normal file
7
migrations/009_rls_and_grants.sql
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
-- Grant these permissions to the database role used by the application.
|
||||||
|
-- Replace amcs with the actual role in your deployment before applying.
|
||||||
|
grant ALL ON TABLE public.thoughts to amcs;
|
||||||
|
grant ALL ON TABLE public.projects to amcs;
|
||||||
|
grant ALL ON TABLE public.thought_links to amcs;
|
||||||
|
grant ALL ON TABLE public.embeddings to amcs;
|
||||||
|
GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO amcs;
|
||||||
Reference in New Issue
Block a user