feat(embeddings): add embedding model support and related changes
* Introduced EmbeddingModel method in Client and Provider interfaces * Updated InsertThought and SearchThoughts methods to handle embedding models * Created embeddings table and updated match_thoughts function for model filtering * Removed embedding column from thoughts table * Adjusted permissions for new embeddings table
This commit is contained in:
@@ -208,6 +208,10 @@ func (c *Client) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *Client) EmbeddingModel() string {
|
||||
return c.embeddingModel
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
||||
body, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,4 +11,5 @@ type Provider interface {
|
||||
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
|
||||
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
|
||||
Name() string
|
||||
EmbeddingModel() string
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func Run(ctx context.Context, configPath string) error {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := db.VerifyRequirements(ctx, cfg.AI.Embeddings.Dimensions); err != nil {
|
||||
if err := db.VerifyRequirements(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
47
internal/mcpserver/schema.go
Normal file
47
internal/mcpserver/schema.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
var toolSchemaOptions = &jsonschema.ForOptions{
|
||||
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
|
||||
reflect.TypeFor[uuid.UUID](): {
|
||||
Type: "string",
|
||||
Format: "uuid",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func addTool[In any, Out any](server *mcp.Server, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
|
||||
if err := setToolSchemas[In, Out](tool); err != nil {
|
||||
panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
|
||||
}
|
||||
mcp.AddTool(server, tool, handler)
|
||||
}
|
||||
|
||||
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
|
||||
if tool.InputSchema == nil {
|
||||
inputSchema, err := jsonschema.For[In](toolSchemaOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("infer input schema: %w", err)
|
||||
}
|
||||
tool.InputSchema = inputSchema
|
||||
}
|
||||
|
||||
if tool.OutputSchema == nil {
|
||||
outputSchema, err := jsonschema.For[Out](toolSchemaOptions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("infer output schema: %w", err)
|
||||
}
|
||||
tool.OutputSchema = outputSchema
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
42
internal/mcpserver/schema_test.go
Normal file
42
internal/mcpserver/schema_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/tools"
|
||||
)
|
||||
|
||||
func TestSetToolSchemasUsesStringUUIDsInListOutput(t *testing.T) {
|
||||
tool := &mcp.Tool{Name: "list_thoughts"}
|
||||
|
||||
if err := setToolSchemas[tools.ListInput, tools.ListOutput](tool); err != nil {
|
||||
t.Fatalf("set tool schemas: %v", err)
|
||||
}
|
||||
|
||||
schema, ok := tool.OutputSchema.(*jsonschema.Schema)
|
||||
if !ok {
|
||||
t.Fatalf("output schema type = %T, want *jsonschema.Schema", tool.OutputSchema)
|
||||
}
|
||||
|
||||
thoughtsSchema := schema.Properties["thoughts"]
|
||||
if thoughtsSchema == nil {
|
||||
t.Fatal("missing thoughts schema")
|
||||
}
|
||||
if thoughtsSchema.Items == nil {
|
||||
t.Fatal("missing thoughts item schema")
|
||||
}
|
||||
|
||||
idSchema := thoughtsSchema.Items.Properties["id"]
|
||||
if idSchema == nil {
|
||||
t.Fatal("missing id schema")
|
||||
}
|
||||
if idSchema.Type != "string" {
|
||||
t.Fatalf("id schema type = %q, want %q", idSchema.Type, "string")
|
||||
}
|
||||
if idSchema.Format != "uuid" {
|
||||
t.Fatalf("id schema format = %q, want %q", idSchema.Format, "uuid")
|
||||
}
|
||||
}
|
||||
@@ -32,87 +32,87 @@ func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
||||
Version: cfg.Version,
|
||||
}, nil)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "capture_thought",
|
||||
Description: "Store a thought with generated embeddings and extracted metadata.",
|
||||
}, toolSet.Capture.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "search_thoughts",
|
||||
Description: "Search stored thoughts by semantic similarity.",
|
||||
}, toolSet.Search.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "list_thoughts",
|
||||
Description: "List recent thoughts with optional metadata filters.",
|
||||
}, toolSet.List.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "thought_stats",
|
||||
Description: "Get counts and top metadata buckets across stored thoughts.",
|
||||
}, toolSet.Stats.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "get_thought",
|
||||
Description: "Retrieve a full thought by id.",
|
||||
}, toolSet.Get.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "update_thought",
|
||||
Description: "Update thought content or merge metadata.",
|
||||
}, toolSet.Update.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "delete_thought",
|
||||
Description: "Hard-delete a thought by id.",
|
||||
}, toolSet.Delete.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "archive_thought",
|
||||
Description: "Archive a thought so it is hidden from default search and listing.",
|
||||
}, toolSet.Archive.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "create_project",
|
||||
Description: "Create a named project container for thoughts.",
|
||||
}, toolSet.Projects.Create)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "list_projects",
|
||||
Description: "List projects and their current thought counts.",
|
||||
}, toolSet.Projects.List)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "set_active_project",
|
||||
Description: "Set the active project for the current MCP session.",
|
||||
}, toolSet.Projects.SetActive)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "get_active_project",
|
||||
Description: "Return the active project for the current MCP session.",
|
||||
}, toolSet.Projects.GetActive)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "get_project_context",
|
||||
Description: "Get recent and semantic context for a project.",
|
||||
}, toolSet.Context.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "recall_context",
|
||||
Description: "Recall semantically relevant and recent context.",
|
||||
}, toolSet.Recall.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "summarize_thoughts",
|
||||
Description: "Summarize a filtered or searched set of thoughts.",
|
||||
}, toolSet.Summarize.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "link_thoughts",
|
||||
Description: "Create a typed relationship between two thoughts.",
|
||||
}, toolSet.Links.Link)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "related_thoughts",
|
||||
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
||||
}, toolSet.Links.Related)
|
||||
|
||||
@@ -3,7 +3,6 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
@@ -68,7 +67,7 @@ func (db *DB) Ready(ctx context.Context) error {
|
||||
return db.Ping(readyCtx)
|
||||
}
|
||||
|
||||
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
||||
func (db *DB) VerifyRequirements(ctx context.Context) error {
|
||||
var hasVector bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
|
||||
return fmt.Errorf("verify vector extension: %w", err)
|
||||
@@ -85,33 +84,12 @@ func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
||||
return fmt.Errorf("match_thoughts function is missing")
|
||||
}
|
||||
|
||||
var embeddingType string
|
||||
err := db.pool.QueryRow(ctx, `
|
||||
select format_type(a.atttypid, a.atttypmod)
|
||||
from pg_attribute a
|
||||
join pg_class c on c.oid = a.attrelid
|
||||
join pg_namespace n on n.oid = c.relnamespace
|
||||
where n.nspname = 'public'
|
||||
and c.relname = 'thoughts'
|
||||
and a.attname = 'embedding'
|
||||
and not a.attisdropped
|
||||
`).Scan(&embeddingType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify thoughts.embedding type: %w", err)
|
||||
var hasEmbeddings bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_tables where schemaname = 'public' and tablename = 'embeddings')`).Scan(&hasEmbeddings); err != nil {
|
||||
return fmt.Errorf("verify embeddings table: %w", err)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`vector\((\d+)\)`)
|
||||
matches := re.FindStringSubmatch(embeddingType)
|
||||
if len(matches) != 2 {
|
||||
return fmt.Errorf("unexpected embedding type %q", embeddingType)
|
||||
}
|
||||
|
||||
var actualDimensions int
|
||||
if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil {
|
||||
return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
|
||||
}
|
||||
if actualDimensions != dimensions {
|
||||
return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
|
||||
if !hasEmbeddings {
|
||||
return fmt.Errorf("embeddings table is missing — run migrations")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -15,27 +15,51 @@ import (
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) {
|
||||
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, embeddingModel string) (thoughttypes.Thought, error) {
|
||||
metadata, err := json.Marshal(thought.Metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
insert into thoughts (content, embedding, metadata, project_id)
|
||||
values ($1, $2, $3::jsonb, $4)
|
||||
tx, err := db.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
row := tx.QueryRow(ctx, `
|
||||
insert into thoughts (content, metadata, project_id)
|
||||
values ($1, $2::jsonb, $3)
|
||||
returning id, created_at, updated_at
|
||||
`, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID)
|
||||
`, thought.Content, metadata, thought.ProjectID)
|
||||
|
||||
created := thought
|
||||
created.Embedding = nil
|
||||
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
|
||||
}
|
||||
|
||||
if len(thought.Embedding) > 0 && embeddingModel != "" {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
insert into embeddings (thought_id, model, dim, embedding)
|
||||
values ($1, $2, $3, $4)
|
||||
on conflict (thought_id, model) do update
|
||||
set embedding = excluded.embedding,
|
||||
dim = excluded.dim,
|
||||
updated_at = now()
|
||||
`, created.ID, embeddingModel, len(thought.Embedding), pgvector.NewVector(thought.Embedding)); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("insert embedding: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
|
||||
}
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
||||
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
||||
filterJSON, err := json.Marshal(filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal search filter: %w", err)
|
||||
@@ -43,8 +67,8 @@ func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold
|
||||
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select id, content, metadata, similarity, created_at
|
||||
from match_thoughts($1, $2, $3, $4::jsonb)
|
||||
`, pgvector.NewVector(embedding), threshold, limit, filterJSON)
|
||||
from match_thoughts($1, $2, $3, $4::jsonb, $5)
|
||||
`, pgvector.NewVector(embedding), threshold, limit, filterJSON, embeddingModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search thoughts: %w", err)
|
||||
}
|
||||
@@ -185,15 +209,14 @@ func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
|
||||
|
||||
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at
|
||||
select id, content, metadata, project_id, archived_at, created_at, updated_at
|
||||
from thoughts
|
||||
where id = $1
|
||||
`, id)
|
||||
|
||||
var thought thoughttypes.Thought
|
||||
var embedding pgvector.Vector
|
||||
var metadataBytes []byte
|
||||
if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
if err := row.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return thoughttypes.Thought{}, err
|
||||
}
|
||||
@@ -203,26 +226,30 @@ func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Though
|
||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
|
||||
}
|
||||
thought.Embedding = embedding.Slice()
|
||||
|
||||
return thought, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
||||
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, embeddingModel string, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
|
||||
}
|
||||
|
||||
tag, err := db.pool.Exec(ctx, `
|
||||
tx, err := db.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
tag, err := tx.Exec(ctx, `
|
||||
update thoughts
|
||||
set content = $2,
|
||||
embedding = $3,
|
||||
metadata = $4::jsonb,
|
||||
project_id = $5,
|
||||
set content = $2,
|
||||
metadata = $3::jsonb,
|
||||
project_id = $4,
|
||||
updated_at = now()
|
||||
where id = $1
|
||||
`, id, content, pgvector.NewVector(embedding), metadataBytes, projectID)
|
||||
`, id, content, metadataBytes, projectID)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
|
||||
}
|
||||
@@ -230,6 +257,23 @@ func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, e
|
||||
return thoughttypes.Thought{}, pgx.ErrNoRows
|
||||
}
|
||||
|
||||
if len(embedding) > 0 && embeddingModel != "" {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
insert into embeddings (thought_id, model, dim, embedding)
|
||||
values ($1, $2, $3, $4)
|
||||
on conflict (thought_id, model) do update
|
||||
set embedding = excluded.embedding,
|
||||
dim = excluded.dim,
|
||||
updated_at = now()
|
||||
`, id, embeddingModel, len(embedding), pgvector.NewVector(embedding)); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("upsert embedding: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("commit thought update: %w", err)
|
||||
}
|
||||
|
||||
return db.GetThought(ctx, id)
|
||||
}
|
||||
|
||||
@@ -265,27 +309,29 @@ func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit in
|
||||
return db.ListThoughts(ctx, filter)
|
||||
}
|
||||
|
||||
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||
args := []any{pgvector.NewVector(embedding), threshold}
|
||||
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, embeddingModel string, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||
args := []any{pgvector.NewVector(embedding), threshold, embeddingModel}
|
||||
conditions := []string{
|
||||
"archived_at is null",
|
||||
"1 - (embedding <=> $1) > $2",
|
||||
"t.archived_at is null",
|
||||
"1 - (e.embedding <=> $1) > $2",
|
||||
"e.model = $3",
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
||||
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||
}
|
||||
if excludeID != nil {
|
||||
args = append(args, *excludeID)
|
||||
conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args)))
|
||||
conditions = append(conditions, fmt.Sprintf("t.id <> $%d", len(args)))
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
query := `
|
||||
select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at
|
||||
from thoughts
|
||||
select t.id, t.content, t.metadata, 1 - (e.embedding <=> $1) as similarity, t.created_at
|
||||
from thoughts t
|
||||
join embeddings e on e.thought_id = t.id
|
||||
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
|
||||
order by embedding <=> $1
|
||||
order by e.embedding <=> $1
|
||||
limit $%d`, len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, query, args...)
|
||||
|
||||
@@ -83,7 +83,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
|
||||
thought.ProjectID = &project.ID
|
||||
}
|
||||
|
||||
created, err := t.store.InsertThought(ctx, thought)
|
||||
created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel())
|
||||
if err != nil {
|
||||
return nil, CaptureOutput{}, err
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, &project.ID, nil)
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, &project.ID, nil)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil)
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
|
||||
@@ -54,12 +54,13 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
model := t.provider.EmbeddingModel()
|
||||
var results []thoughttypes.SearchResult
|
||||
if project != nil {
|
||||
results, err = t.store.SearchSimilarThoughts(ctx, embedding, threshold, limit, &project.ID, nil)
|
||||
results, err = t.store.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, &project.ID, nil)
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
} else {
|
||||
results, err = t.store.SearchThoughts(ctx, embedding, threshold, limit, map[string]any{})
|
||||
results, err = t.store.SearchThoughts(ctx, embedding, model, threshold, limit, map[string]any{})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
|
||||
@@ -56,7 +56,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil)
|
||||
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
||||
}
|
||||
|
||||
content := current.Content
|
||||
embedding := current.Embedding
|
||||
var embedding []float32
|
||||
mergedMetadata := current.Metadata
|
||||
projectID := current.ProjectID
|
||||
|
||||
@@ -79,7 +79,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
updated, err := t.store.UpdateThought(ctx, id, content, embedding, mergedMetadata, projectID)
|
||||
updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user