feat(embeddings): add embedding model support and related changes

* Introduced EmbeddingModel method in Client and Provider interfaces
* Updated InsertThought and SearchThoughts methods to handle embedding models
* Created embeddings table and updated match_thoughts function for model filtering
* Removed embedding column from thoughts table
* Adjusted permissions for new embeddings table
This commit is contained in:
Hein
2026-03-25 16:25:41 +02:00
parent c8ca272b03
commit cebef3a07c
19 changed files with 259 additions and 88 deletions

View File

@@ -208,6 +208,10 @@ func (c *Client) Name() string {
return c.name return c.name
} }
// EmbeddingModel returns the embedding model name stored on the client.
func (c *Client) EmbeddingModel() string {
return c.embeddingModel
}
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error { func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
body, err := json.Marshal(requestBody) body, err := json.Marshal(requestBody)
if err != nil { if err != nil {

View File

@@ -11,4 +11,5 @@ type Provider interface {
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
Name() string Name() string
EmbeddingModel() string
} }

View File

@@ -40,7 +40,7 @@ func Run(ctx context.Context, configPath string) error {
} }
defer db.Close() defer db.Close()
if err := db.VerifyRequirements(ctx, cfg.AI.Embeddings.Dimensions); err != nil { if err := db.VerifyRequirements(ctx); err != nil {
return err return err
} }

View File

@@ -0,0 +1,47 @@
package mcpserver
import (
"context"
"fmt"
"reflect"
"github.com/google/jsonschema-go/jsonschema"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// toolSchemaOptions is passed to jsonschema.For when inferring tool schemas.
// It overrides the schema used for uuid.UUID so that UUID values are described
// as JSON strings with the "uuid" format.
var toolSchemaOptions = &jsonschema.ForOptions{
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
reflect.TypeFor[uuid.UUID](): {
Type: "string",
Format: "uuid",
},
},
}
// addTool fills in any schema the tool does not already carry (see
// setToolSchemas) and then registers the tool and its handler on server via
// mcp.AddTool.
//
// It panics when schema inference fails, with a message naming the tool.
func addTool[In any, Out any](server *mcp.Server, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
	schemaErr := setToolSchemas[In, Out](tool)
	if schemaErr != nil {
		panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, schemaErr))
	}

	mcp.AddTool(server, tool, handler)
}
// setToolSchemas populates tool.InputSchema and tool.OutputSchema from the In
// and Out type parameters using toolSchemaOptions. A schema that is already
// set on the tool is left untouched.
//
// It returns a wrapped error identifying whether input or output inference
// failed. If output inference fails, an input schema inferred by this call
// remains set on the tool.
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
	if tool.InputSchema == nil {
		in, inErr := jsonschema.For[In](toolSchemaOptions)
		if inErr != nil {
			return fmt.Errorf("infer input schema: %w", inErr)
		}
		tool.InputSchema = in
	}

	// Nothing left to do when the caller supplied an output schema.
	if tool.OutputSchema != nil {
		return nil
	}

	out, outErr := jsonschema.For[Out](toolSchemaOptions)
	if outErr != nil {
		return fmt.Errorf("infer output schema: %w", outErr)
	}
	tool.OutputSchema = out
	return nil
}

View File

@@ -0,0 +1,42 @@
package mcpserver
import (
"testing"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/tools"
)
// TestSetToolSchemasUsesStringUUIDsInListOutput checks that the output schema
// inferred for the list tool describes each thought's "id" property as a
// string with the "uuid" format.
func TestSetToolSchemasUsesStringUUIDsInListOutput(t *testing.T) {
	listTool := &mcp.Tool{Name: "list_thoughts"}
	err := setToolSchemas[tools.ListInput, tools.ListOutput](listTool)
	if err != nil {
		t.Fatalf("set tool schemas: %v", err)
	}

	outSchema, isSchema := listTool.OutputSchema.(*jsonschema.Schema)
	if !isSchema {
		t.Fatalf("output schema type = %T, want *jsonschema.Schema", listTool.OutputSchema)
	}

	// Walk thoughts -> items -> id, failing at the first missing level.
	thoughts := outSchema.Properties["thoughts"]
	switch {
	case thoughts == nil:
		t.Fatal("missing thoughts schema")
	case thoughts.Items == nil:
		t.Fatal("missing thoughts item schema")
	}

	id := thoughts.Items.Properties["id"]
	if id == nil {
		t.Fatal("missing id schema")
	}

	if got, want := id.Type, "string"; got != want {
		t.Fatalf("id schema type = %q, want %q", got, want)
	}
	if got, want := id.Format, "uuid"; got != want {
		t.Fatalf("id schema format = %q, want %q", got, want)
	}
}

View File

@@ -32,87 +32,87 @@ func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
Version: cfg.Version, Version: cfg.Version,
}, nil) }, nil)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "capture_thought", Name: "capture_thought",
Description: "Store a thought with generated embeddings and extracted metadata.", Description: "Store a thought with generated embeddings and extracted metadata.",
}, toolSet.Capture.Handle) }, toolSet.Capture.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "search_thoughts", Name: "search_thoughts",
Description: "Search stored thoughts by semantic similarity.", Description: "Search stored thoughts by semantic similarity.",
}, toolSet.Search.Handle) }, toolSet.Search.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "list_thoughts", Name: "list_thoughts",
Description: "List recent thoughts with optional metadata filters.", Description: "List recent thoughts with optional metadata filters.",
}, toolSet.List.Handle) }, toolSet.List.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "thought_stats", Name: "thought_stats",
Description: "Get counts and top metadata buckets across stored thoughts.", Description: "Get counts and top metadata buckets across stored thoughts.",
}, toolSet.Stats.Handle) }, toolSet.Stats.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "get_thought", Name: "get_thought",
Description: "Retrieve a full thought by id.", Description: "Retrieve a full thought by id.",
}, toolSet.Get.Handle) }, toolSet.Get.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "update_thought", Name: "update_thought",
Description: "Update thought content or merge metadata.", Description: "Update thought content or merge metadata.",
}, toolSet.Update.Handle) }, toolSet.Update.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "delete_thought", Name: "delete_thought",
Description: "Hard-delete a thought by id.", Description: "Hard-delete a thought by id.",
}, toolSet.Delete.Handle) }, toolSet.Delete.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "archive_thought", Name: "archive_thought",
Description: "Archive a thought so it is hidden from default search and listing.", Description: "Archive a thought so it is hidden from default search and listing.",
}, toolSet.Archive.Handle) }, toolSet.Archive.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "create_project", Name: "create_project",
Description: "Create a named project container for thoughts.", Description: "Create a named project container for thoughts.",
}, toolSet.Projects.Create) }, toolSet.Projects.Create)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "list_projects", Name: "list_projects",
Description: "List projects and their current thought counts.", Description: "List projects and their current thought counts.",
}, toolSet.Projects.List) }, toolSet.Projects.List)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "set_active_project", Name: "set_active_project",
Description: "Set the active project for the current MCP session.", Description: "Set the active project for the current MCP session.",
}, toolSet.Projects.SetActive) }, toolSet.Projects.SetActive)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "get_active_project", Name: "get_active_project",
Description: "Return the active project for the current MCP session.", Description: "Return the active project for the current MCP session.",
}, toolSet.Projects.GetActive) }, toolSet.Projects.GetActive)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "get_project_context", Name: "get_project_context",
Description: "Get recent and semantic context for a project.", Description: "Get recent and semantic context for a project.",
}, toolSet.Context.Handle) }, toolSet.Context.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "recall_context", Name: "recall_context",
Description: "Recall semantically relevant and recent context.", Description: "Recall semantically relevant and recent context.",
}, toolSet.Recall.Handle) }, toolSet.Recall.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "summarize_thoughts", Name: "summarize_thoughts",
Description: "Summarize a filtered or searched set of thoughts.", Description: "Summarize a filtered or searched set of thoughts.",
}, toolSet.Summarize.Handle) }, toolSet.Summarize.Handle)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "link_thoughts", Name: "link_thoughts",
Description: "Create a typed relationship between two thoughts.", Description: "Create a typed relationship between two thoughts.",
}, toolSet.Links.Link) }, toolSet.Links.Link)
mcp.AddTool(server, &mcp.Tool{ addTool(server, &mcp.Tool{
Name: "related_thoughts", Name: "related_thoughts",
Description: "Retrieve explicit links and semantic neighbors for a thought.", Description: "Retrieve explicit links and semantic neighbors for a thought.",
}, toolSet.Links.Related) }, toolSet.Links.Related)

View File

@@ -3,7 +3,6 @@ package store
import ( import (
"context" "context"
"fmt" "fmt"
"regexp"
"time" "time"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
@@ -68,7 +67,7 @@ func (db *DB) Ready(ctx context.Context) error {
return db.Ping(readyCtx) return db.Ping(readyCtx)
} }
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error { func (db *DB) VerifyRequirements(ctx context.Context) error {
var hasVector bool var hasVector bool
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil { if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
return fmt.Errorf("verify vector extension: %w", err) return fmt.Errorf("verify vector extension: %w", err)
@@ -85,33 +84,12 @@ func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
return fmt.Errorf("match_thoughts function is missing") return fmt.Errorf("match_thoughts function is missing")
} }
var embeddingType string var hasEmbeddings bool
err := db.pool.QueryRow(ctx, ` if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`).Scan(&hasEmbeddings); err != nil {
select format_type(a.atttypid, a.atttypmod) return fmt.Errorf("verify embeddings table: %w", err)
from pg_attribute a
join pg_class c on c.oid = a.attrelid
join pg_namespace n on n.oid = c.relnamespace
where n.nspname = 'public'
and c.relname = 'thoughts'
and a.attname = 'embedding'
and not a.attisdropped
`).Scan(&embeddingType)
if err != nil {
return fmt.Errorf("verify thoughts.embedding type: %w", err)
} }
if !hasEmbeddings {
re := regexp.MustCompile(`vector\((\d+)\)`) return fmt.Errorf("embeddings table is missing — run migrations")
matches := re.FindStringSubmatch(embeddingType)
if len(matches) != 2 {
return fmt.Errorf("unexpected embedding type %q", embeddingType)
}
var actualDimensions int
if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil {
return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
}
if actualDimensions != dimensions {
return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
} }
return nil return nil

View File

@@ -15,27 +15,51 @@ import (
thoughttypes "git.warky.dev/wdevs/amcs/internal/types" thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
) )
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) { func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) {
metadata, err := json.Marshal(thought.Metadata) metadata, err := json.Marshal(thought.Metadata)
if err != nil { if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err) return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
} }
row := db.pool.QueryRow(ctx, ` tx, err := db.pool.Begin(ctx)
insert into thoughts (content, embedding, metadata, project_id) if err != nil {
values ($1, $2, $3::jsonb, $4) return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
}
defer tx.Rollback(ctx)
row := tx.QueryRow(ctx, `
insert into thoughts (content, metadata, project_id)
values ($1, $2::jsonb, $3)
returning id, created_at, updated_at returning id, created_at, updated_at
`, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID) `, thought.Content, metadata, thought.ProjectID)
created := thought created := thought
created.Embedding = nil
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil { if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err) return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
} }
if len(thought.Embedding) > 0 && embeddingModel != "" {
if _, err := tx.Exec(ctx, `
insert into embeddings (thought_id, model, dim, embedding)
values ($1, $2, $3, $4)
on conflict (thought_id, model) do update
set embedding = excluded.embedding,
dim = excluded.dim,
updated_at = now()
`, created.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding)); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err)
}
}
if err := tx.Commit(ctx); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
}
return created, nil return created, nil
} }
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) { func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
filterJSON, err := json.Marshal(filter) filterJSON, err := json.Marshal(filter)
if err != nil { if err != nil {
return nil, fmt.Errorf("marshal search filter: %w", err) return nil, fmt.Errorf("marshal search filter: %w", err)
@@ -43,8 +67,8 @@ func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold
rows, err := db.pool.Query(ctx, ` rows, err := db.pool.Query(ctx, `
select id, content, metadata, similarity, created_at select id, content, metadata, similarity, created_at
from match_thoughts($1, $2, $3, $4::jsonb) from match_thoughts($1, $2, $3, $4::jsonb, $5)
`, pgvector.NewVector(embedding), threshold, limit, filterJSON) `, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel)
if err != nil { if err != nil {
return nil, fmt.Errorf("search thoughts: %w", err) return nil, fmt.Errorf("search thoughts: %w", err)
} }
@@ -185,15 +209,14 @@ func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) { func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
row := db.pool.QueryRow(ctx, ` row := db.pool.QueryRow(ctx, `
select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at select id, content, metadata, project_id, archived_at, created_at, updated_at
from thoughts from thoughts
where id = $1 where id = $1
`, id) `, id)
var thought thoughttypes.Thought var thought thoughttypes.Thought
var embedding pgvector.Vector
var metadataBytes []byte var metadataBytes []byte
if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil { if err := row.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
if err == pgx.ErrNoRows { if err == pgx.ErrNoRows {
return thoughttypes.Thought{}, err return thoughttypes.Thought{}, err
} }
@@ -203,26 +226,30 @@ func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Though
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil { if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err) return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
} }
thought.Embedding = embedding.Slice()
return thought, nil return thought, nil
} }
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) { func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
metadataBytes, err := json.Marshal(metadata) metadataBytes, err := json.Marshal(metadata)
if err != nil { if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err) return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
} }
tag, err := db.pool.Exec(ctx, ` tx, err := db.pool.Begin(ctx)
if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
}
defer tx.Rollback(ctx)
tag, err := tx.Exec(ctx, `
update thoughts update thoughts
set content = $2, set content = $2,
embedding = $3, metadata = $3::jsonb,
metadata = $4::jsonb, project_id = $4,
project_id = $5,
updated_at = now() updated_at = now()
where id = $1 where id = $1
`, id, content, pgvector.NewVector(embedding), metadataBytes, projectID) `, id, content, metadataBytes, projectID)
if err != nil { if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err) return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
} }
@@ -230,6 +257,23 @@ func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, e
return thoughttypes.Thought{}, pgx.ErrNoRows return thoughttypes.Thought{}, pgx.ErrNoRows
} }
if len(embedding) > 0 && embeddingModel != "" {
if _, err := tx.Exec(ctx, `
insert into embeddings (thought_id, model, dim, embedding)
values ($1, $2, $3, $4)
on conflict (thought_id, model) do update
set embedding = excluded.embedding,
dim = excluded.dim,
updated_at = now()
`, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err)
}
}
if err := tx.Commit(ctx); err != nil {
return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err)
}
return db.GetThought(ctx, id) return db.GetThought(ctx, id)
} }
@@ -265,27 +309,29 @@ func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit in
return db.ListThoughts(ctx, filter) return db.ListThoughts(ctx, filter)
} }
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) { func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
args := []any{pgvector.NewVector(embedding), threshold} args := []any{pgvector.NewVector(embedding), threshold, embeddingModel}
conditions := []string{ conditions := []string{
"archived_at is null", "t.archived_at is null",
"1 - (embedding <=> $1) > $2", "1 - (e.embedding <=> $1) > $2",
"e.model = $3",
} }
if projectID != nil { if projectID != nil {
args = append(args, *projectID) args = append(args, *projectID)
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args))) conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
} }
if excludeID != nil { if excludeID != nil {
args = append(args, *excludeID) args = append(args, *excludeID)
conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args))) conditions = append(conditions, fmt.Sprintf("t.id <> $%d", len(args)))
} }
args = append(args, limit) args = append(args, limit)
query := ` query := `
select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at select t.id, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at
from thoughts from thoughts t
join embeddings e on e.thought_id = t.id
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(` where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
order by embedding <=> $1 order by e.embedding <=> $1
limit $%d`, len(args)) limit $%d`, len(args))
rows, err := db.pool.Query(ctx, query, args...) rows, err := db.pool.Query(ctx, query, args...)

View File

@@ -83,7 +83,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
thought.ProjectID = &project.ID thought.ProjectID = &project.ID
} }
created, err := t.store.InsertThought(ctx, thought) created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel())
if err != nil { if err != nil {
return nil, CaptureOutput{}, err return nil, CaptureOutput{}, err
} }

View File

@@ -76,7 +76,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
if err != nil { if err != nil {
return nil, ProjectContextOutput{}, err return nil, ProjectContextOutput{}, err
} }
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, &project.ID, nil) semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, &project.ID, nil)
if err != nil { if err != nil {
return nil, ProjectContextOutput{}, err return nil, ProjectContextOutput{}, err
} }

View File

@@ -121,7 +121,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
if err != nil { if err != nil {
return nil, RelatedOutput{}, err return nil, RelatedOutput{}, err
} }
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID) semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
if err != nil { if err != nil {
return nil, RelatedOutput{}, err return nil, RelatedOutput{}, err
} }

View File

@@ -58,7 +58,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
projectID = &project.ID projectID = &project.ID
} }
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil) semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
if err != nil { if err != nil {
return nil, RecallOutput{}, err return nil, RecallOutput{}, err
} }

View File

@@ -54,12 +54,13 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
return nil, SearchOutput{}, err return nil, SearchOutput{}, err
} }
model := t.provider.EmbeddingModel()
var results []thoughttypes.SearchResult var results []thoughttypes.SearchResult
if project != nil { if project != nil {
results, err = t.store.SearchSimilarThoughts(ctx, embedding, threshold, limit, &project.ID, nil) results, err = t.store.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, &project.ID, nil)
_ = t.store.TouchProject(ctx, project.ID) _ = t.store.TouchProject(ctx, project.ID)
} else { } else {
results, err = t.store.SearchThoughts(ctx, embedding, threshold, limit, map[string]any{}) results, err = t.store.SearchThoughts(ctx, embedding, model, threshold, limit, map[string]any{})
} }
if err != nil { if err != nil {
return nil, SearchOutput{}, err return nil, SearchOutput{}, err

View File

@@ -56,7 +56,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
if project != nil { if project != nil {
projectID = &project.ID projectID = &project.ID
} }
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil) results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
if err != nil { if err != nil {
return nil, SummarizeOutput{}, err return nil, SummarizeOutput{}, err
} }

View File

@@ -48,7 +48,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
} }
content := current.Content content := current.Content
embedding := current.Embedding var embedding []float32
mergedMetadata := current.Metadata mergedMetadata := current.Metadata
projectID := current.ProjectID projectID := current.ProjectID
@@ -79,7 +79,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
projectID = &project.ID projectID = &project.ID
} }
updated, err := t.store.UpdateThought(ctx, id, content, embedding, mergedMetadata, projectID) updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID)
if err != nil { if err != nil {
return nil, UpdateOutput{}, err return nil, UpdateOutput{}, err
} }

View File

@@ -1,5 +0,0 @@
-- Grant these permissions to the database role used by the application.
-- Replace amcs_user with the actual role in your deployment before applying.
grant select, insert, update, delete on table public.thoughts to amcs_user;
grant select, insert, update, delete on table public.projects to amcs_user;
grant select, insert, update, delete on table public.thought_links to amcs_user;

View File

@@ -0,0 +1,16 @@
-- Stores embedding vectors separately from thoughts, at most one row per
-- (thought_id, model) pair. The embedding column is declared as vector without
-- a fixed dimension; the vector length is recorded in dim.
create table if not exists embeddings (
id bigserial primary key,
guid uuid not null default gen_random_uuid(),
-- Deleting a thought also deletes its embedding rows.
thought_id uuid not null references thoughts(id) on delete cascade,
model text not null,
dim int not null,
embedding vector not null,
created_at timestamptz default now(),
updated_at timestamptz default now(),
constraint embeddings_guid_unique unique (guid),
constraint embeddings_thought_model_unique unique (thought_id, model)
);
-- NOTE(review): the unique (thought_id, model) constraint presumably already
-- provides an index led by thought_id; confirm this extra index is needed.
create index if not exists embeddings_thought_id_idx on embeddings (thought_id);
-- NOTE(review): this drops thoughts.embedding without first copying existing
-- values into embeddings. Confirm existing vectors are migrated or
-- regenerated elsewhere before applying.
alter table thoughts drop column if exists embedding;

View File

@@ -0,0 +1,34 @@
-- match_thoughts returns non-archived thoughts joined to their embeddings,
-- ranked by distance to query_embedding (closest first).
--
-- Parameters:
--   query_embedding  vector to compare against embeddings.embedding
--   match_threshold  only rows with similarity strictly greater than this are returned
--   match_count      maximum number of rows returned
--   filter           jsonb that thoughts.metadata must contain; '{}' disables the filter
--   embedding_model  restricts rows to this embeddings.model; '' disables the filter
--
-- similarity is computed as 1 - (embedding <=> query_embedding); in pgvector
-- the <=> operator is cosine distance.
--
-- NOTE(review): with embedding_model = '' a thought that has embeddings for
-- several models can appear once per model, and rows whose dim differs from
-- the query vector are presumably rejected by the distance operator. Confirm
-- callers always pass a model.
-- NOTE(review): this signature adds a parameter. If an earlier match_thoughts
-- with a different argument list exists, PostgreSQL keeps it as a separate
-- overload rather than replacing it. Confirm the old one is dropped.
create or replace function match_thoughts(
query_embedding vector,
match_threshold float default 0.7,
match_count int default 10,
filter jsonb default '{}'::jsonb,
embedding_model text default ''
)
returns table (
id uuid,
content text,
metadata jsonb,
similarity float,
created_at timestamptz
)
language plpgsql
as $$
begin
return query
select
t.id,
t.content,
t.metadata,
1 - (e.embedding <=> query_embedding) as similarity,
t.created_at
from thoughts t
join embeddings e on e.thought_id = t.id
where 1 - (e.embedding <=> query_embedding) > match_threshold
and t.archived_at is null
and (embedding_model = '' or e.model = embedding_model)
and (filter = '{}'::jsonb or t.metadata @> filter)
order by e.embedding <=> query_embedding
limit match_count;
end;
$$;

View File

@@ -0,0 +1,7 @@
-- Grant these permissions to the database role used by the application.
-- Replace amcs with the actual role in your deployment before applying.
-- NOTE(review): ALL grants every table privilege, not only
-- select/insert/update/delete. Confirm the broader grant is intended.
grant ALL ON TABLE public.thoughts to amcs;
grant ALL ON TABLE public.projects to amcs;
grant ALL ON TABLE public.thought_links to amcs;
grant ALL ON TABLE public.embeddings to amcs;
-- Sequence access applies to every sequence in the public schema, presumably
-- to allow inserts into tables with serial/bigserial keys.
GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO amcs;