diff --git a/README.md b/README.md index 6fb31a5..540c174 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte | `recall_context` | Semantic + recency context block for injection | | `link_thoughts` | Create a typed relationship between thoughts | | `related_thoughts` | Explicit links + semantic neighbours | +| `backfill_embeddings` | Generate missing embeddings for stored thoughts | ## Configuration @@ -74,6 +75,89 @@ Alternatively, pass `client_id` and `client_secret` as body parameters instead o See `llm/plan.md` for full architecture and implementation plan. +## Backfill + +Run `backfill_embeddings` after switching embedding models or importing thoughts without vectors. + +```json +{ + "project": "optional-project-name", + "limit": 100, + "include_archived": false, + "older_than_days": 0, + "dry_run": false +} +``` + +- `dry_run: true` — report counts without calling the embedding provider +- `limit` — max thoughts per call (default 100) +- Embeddings are generated in parallel (4 workers) and upserted; one failure does not abort the run + +**Automatic backfill** (optional, config-gated): + +```yaml +backfill: + enabled: true + run_on_startup: true # run once on server start + interval: "15m" # repeat every 15 minutes + batch_size: 20 + max_per_run: 100 + include_archived: false +``` + +**Search fallback**: when no embeddings exist for the active model in scope, `search_thoughts`, `recall_context`, `get_project_context`, `summarize_thoughts`, and `related_thoughts` automatically fall back to Postgres full-text search so results are never silently empty. 
+ +## Client Setup + +### Claude Code + +```bash +# API key auth +claude mcp add --transport http amcs http://localhost:8080/mcp --header "x-brain-key: YOUR_API_KEY" + +# Bearer token auth +claude mcp add --transport http amcs http://localhost:8080/mcp --header "Authorization: Bearer YOUR_TOKEN" +``` + +### OpenAI Codex + +Add to `~/.codex/config.toml`: + +```toml +[[mcp_servers]] +name = "amcs" +url = "http://localhost:8080/mcp" + +[mcp_servers.headers] +x-brain-key = "YOUR_API_KEY" +``` + +### OpenCode + +```bash +# API key auth +opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "x-brain-key=YOUR_API_KEY" + +# Bearer token auth +opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "Authorization=Bearer YOUR_TOKEN" +``` + +Or add directly to `opencode.json` / `~/.config/opencode/config.json`: + +```json +{ + "mcp": { + "amcs": { + "type": "remote", + "url": "http://localhost:8080/mcp", + "headers": { + "x-brain-key": "YOUR_API_KEY" + } + } + } +} +``` + ## Development Run the SQL migrations against a local database with: diff --git a/configs/config.example.yaml b/configs/config.example.yaml index 78f3879..53afcc9 100644 --- a/configs/config.example.yaml +++ b/configs/config.example.yaml @@ -79,3 +79,11 @@ logging: observability: metrics_enabled: true pprof_enabled: false + +backfill: + enabled: false + run_on_startup: false + interval: "15m" + batch_size: 20 + max_per_run: 100 + include_archived: false diff --git a/internal/app/app.go b/internal/app/app.go index 3dba584..c4acd7d 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -74,6 +74,25 @@ func Run(ctx context.Context, configPath string) error { slog.String("provider", provider.Name()), ) + if cfg.Backfill.Enabled && cfg.Backfill.RunOnStartup { + go runBackfillPass(ctx, db, provider, cfg.Backfill, logger) + } + + if cfg.Backfill.Enabled && cfg.Backfill.Interval > 0 { + go func() { + ticker := time.NewTicker(cfg.Backfill.Interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return +
case <-ticker.C: + runBackfillPass(ctx, db, provider, cfg.Backfill, logger) + } + } + }() + } + server := &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects), @@ -123,6 +142,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects), Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects), Links: tools.NewLinksTool(db, provider, cfg.Search), + Backfill: tools.NewBackfillTool(db, provider, activeProjects, logger), } mcpHandler := mcpserver.New(cfg.MCP, toolSet) @@ -136,6 +156,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger)) } mux.HandleFunc("/favicon.ico", serveFavicon) + mux.HandleFunc("/images/project.jpg", serveHomeImage) mux.HandleFunc("/llm", serveLLMInstructions) mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { @@ -155,8 +176,49 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P }) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + + if r.Method != http.MethodGet && r.Method != http.MethodHead { + w.Header().Set("Allow", "GET, HEAD") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + const homePage = ` + + + + + AMCS + + + +
+ Avelon Memory Crystal project image +
+

Avelon Memory Crystal Server (AMCS)

+

AMCS is a memory server that captures, links, and retrieves structured project thoughts for AI assistants using semantic search, summaries, and MCP tools.

+
+
+ +` + w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("amcs is running")) + if r.Method == http.MethodHead { + return + } + + _, _ = w.Write([]byte(homePage)) }) return observability.Chain( @@ -167,3 +229,39 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P observability.Timeout(cfg.Server.WriteTimeout), ) } + +func runBackfillPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg config.BackfillConfig, logger *slog.Logger) { + backfiller := tools.NewBackfillTool(db, provider, nil, logger) + _, out, err := backfiller.Handle(ctx, nil, tools.BackfillInput{ + Limit: cfg.MaxPerRun, + IncludeArchived: cfg.IncludeArchived, + }) + if err != nil { + logger.Error("auto backfill failed", slog.String("error", err.Error())) + return + } + logger.Info("auto backfill pass", + slog.String("model", out.Model), + slog.Int("scanned", out.Scanned), + slog.Int("embedded", out.Embedded), + slog.Int("failed", out.Failed), + ) +} + +func serveHomeImage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + w.Header().Set("Allow", "GET, HEAD") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "image/jpeg") + w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + w.WriteHeader(http.StatusOK) + + if r.Method == http.MethodHead { + return + } + + _, _ = w.Write(homeImage) +} diff --git a/internal/app/favicon.go b/internal/app/favicon.go index b06f267..946d4ef 100644 --- a/internal/app/favicon.go +++ b/internal/app/favicon.go @@ -1,15 +1,9 @@ package app import ( - _ "embed" "net/http" ) -var ( - //go:embed static/favicon.ico - faviconICO []byte -) - func serveFavicon(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "image/x-icon") w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") diff --git 
a/internal/app/static/avelonmemorycrystal.jpg b/internal/app/static/avelonmemorycrystal.jpg new file mode 100644 index 0000000..f178f6a Binary files /dev/null and b/internal/app/static/avelonmemorycrystal.jpg differ diff --git a/internal/app/static_assets.go b/internal/app/static_assets.go new file mode 100644 index 0000000..d89eb72 --- /dev/null +++ b/internal/app/static_assets.go @@ -0,0 +1,24 @@ +package app + +import ( + "embed" + "fmt" + "io/fs" +) + +var ( + //go:embed static/* + staticFiles embed.FS + + faviconICO = mustReadStaticFile("favicon.ico") + homeImage = mustReadStaticFile("avelonmemorycrystal.jpg") +) + +func mustReadStaticFile(name string) []byte { + data, err := fs.ReadFile(staticFiles, "static/"+name) + if err != nil { + panic(fmt.Sprintf("failed to read embedded static file %q: %v", name, err)) + } + + return data +} diff --git a/internal/config/config.go b/internal/config/config.go index 62a1796..1a57682 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,6 +17,7 @@ type Config struct { Search SearchConfig `yaml:"search"` Logging LoggingConfig `yaml:"logging"` Observability ObservabilityConfig `yaml:"observability"` + Backfill BackfillConfig `yaml:"backfill"` } type ServerConfig struct { @@ -135,3 +136,12 @@ type ObservabilityConfig struct { MetricsEnabled bool `yaml:"metrics_enabled"` PprofEnabled bool `yaml:"pprof_enabled"` } + +type BackfillConfig struct { + Enabled bool `yaml:"enabled"` + RunOnStartup bool `yaml:"run_on_startup"` + Interval time.Duration `yaml:"interval"` + BatchSize int `yaml:"batch_size"` + MaxPerRun int `yaml:"max_per_run"` + IncludeArchived bool `yaml:"include_archived"` +} diff --git a/internal/config/loader.go b/internal/config/loader.go index 47d1589..0a27f04 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -95,6 +95,13 @@ func defaultConfig() Config { Level: "info", Format: "json", }, + Backfill: BackfillConfig{ + Enabled: false, + RunOnStartup: false, + Interval: 
15 * time.Minute, + BatchSize: 20, + MaxPerRun: 100, + }, } } diff --git a/internal/config/validate.go b/internal/config/validate.go index d0a7996..2280f7c 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -81,5 +81,14 @@ func (c Config) Validate() error { return fmt.Errorf("invalid config: logging.level is required") } + if c.Backfill.Enabled { + if c.Backfill.BatchSize <= 0 { + return fmt.Errorf("invalid config: backfill.batch_size must be greater than zero when backfill is enabled") + } + if c.Backfill.MaxPerRun < c.Backfill.BatchSize { + return fmt.Errorf("invalid config: backfill.max_per_run must be >= backfill.batch_size") + } + } + return nil } diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index 7f41b40..4be4296 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -24,6 +24,7 @@ type ToolSet struct { Recall *tools.RecallTool Summarize *tools.SummarizeTool Links *tools.LinksTool + Backfill *tools.BackfillTool } func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler { @@ -117,6 +118,11 @@ func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler { Description: "Retrieve explicit links and semantic neighbors for a thought.", }, toolSet.Links.Related) + addTool(server, &mcp.Tool{ + Name: "backfill_embeddings", + Description: "Generate missing embeddings for stored thoughts using the active embedding model.", + }, toolSet.Backfill.Handle) + return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { return server }, &mcp.StreamableHTTPOptions{ diff --git a/internal/store/thoughts.go b/internal/store/thoughts.go index 9384a0c..42f60b8 100644 --- a/internal/store/thoughts.go +++ b/internal/store/thoughts.go @@ -358,6 +358,140 @@ func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, em return results, nil } +func (db *DB) HasEmbeddingsForModel(ctx context.Context, model string, projectID *uuid.UUID) (bool, error) { + args := []any{model} + 
conditions := []string{ + "e.model = $1", + "t.archived_at is null", + } + if projectID != nil { + args = append(args, *projectID) + conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args))) + } + + query := `select exists(select 1 from embeddings e join thoughts t on t.guid = e.thought_id where ` + + strings.Join(conditions, " and ") + `)` + + var exists bool + if err := db.pool.QueryRow(ctx, query, args...).Scan(&exists); err != nil { + return false, fmt.Errorf("check embeddings for model: %w", err) + } + return exists, nil +} + +func (db *DB) ListThoughtsMissingEmbedding(ctx context.Context, model string, limit int, projectID *uuid.UUID, includeArchived bool, olderThanDays int) ([]thoughttypes.Thought, error) { + args := []any{model} + conditions := []string{"e.id is null"} + + if !includeArchived { + conditions = append(conditions, "t.archived_at is null") + } + if projectID != nil { + args = append(args, *projectID) + conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args))) + } + if olderThanDays > 0 { + args = append(args, time.Now().Add(-time.Duration(olderThanDays)*24*time.Hour)) + conditions = append(conditions, fmt.Sprintf("t.created_at < $%d", len(args))) + } + args = append(args, limit) + + query := ` + select t.guid, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at + from thoughts t + left join embeddings e on e.thought_id = t.guid and e.model = $1 + where ` + strings.Join(conditions, " and ") + ` + order by t.created_at asc + limit $` + fmt.Sprintf("%d", len(args)) + + rows, err := db.pool.Query(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("list thoughts missing embedding: %w", err) + } + defer rows.Close() + + thoughts := make([]thoughttypes.Thought, 0, limit) + for rows.Next() { + var thought thoughttypes.Thought + var metadataBytes []byte + if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan missing-embedding thought: %w", err) + } + if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil { + return nil, fmt.Errorf("decode missing-embedding metadata: %w", err) + } + thoughts = append(thoughts, thought) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate missing-embedding thoughts: %w", err) + } + return thoughts, nil +} + +func (db *DB) UpsertEmbedding(ctx context.Context, thoughtID uuid.UUID, model string, embedding []float32) error { + _, err := db.pool.Exec(ctx, ` + insert into embeddings (thought_id, model, dim, embedding) + values ($1, $2, $3, $4) + on conflict (thought_id, model) do update + set embedding = excluded.embedding, + dim = excluded.dim, + updated_at = now() + `, thoughtID, model, len(embedding), pgvector.NewVector(embedding)) + if err != nil { + return fmt.Errorf("upsert embedding: %w", err) + } + return nil +} + +func (db *DB) SearchThoughtsText(ctx context.Context, query string, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) { + args := []any{query} + conditions := []string{ + "t.archived_at is null", + "to_tsvector('simple', t.content) @@ websearch_to_tsquery('simple', $1)", + } + if projectID != nil { + args = append(args, *projectID) + conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args))) + } + if excludeID != nil { + args = append(args, *excludeID) + conditions = append(conditions, fmt.Sprintf("t.guid <> $%d", len(args))) + } + args = append(args, limit) + + q := ` + select t.guid, 
t.content, t.metadata, + ts_rank_cd(to_tsvector('simple', t.content), websearch_to_tsquery('simple', $1)) as similarity, + t.created_at + from thoughts t + where ` + strings.Join(conditions, " and ") + ` + order by similarity desc + limit $` + fmt.Sprintf("%d", len(args)) + + rows, err := db.pool.Query(ctx, q, args...) + if err != nil { + return nil, fmt.Errorf("text search thoughts: %w", err) + } + defer rows.Close() + + results := make([]thoughttypes.SearchResult, 0, limit) + for rows.Next() { + var result thoughttypes.SearchResult + var metadataBytes []byte + if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil { + return nil, fmt.Errorf("scan text search result: %w", err) + } + if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil { + return nil, fmt.Errorf("decode text search metadata: %w", err) + } + results = append(results, result) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate text search results: %w", err) + } + return results, nil +} + func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount { type kv struct { key string diff --git a/internal/tools/backfill.go b/internal/tools/backfill.go new file mode 100644 index 0000000..c92d9c7 --- /dev/null +++ b/internal/tools/backfill.go @@ -0,0 +1,141 @@ +package tools + +import ( + "context" + "log/slog" + "sync" + "time" + + "github.com/google/uuid" + "github.com/modelcontextprotocol/go-sdk/mcp" + "golang.org/x/sync/semaphore" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" +) + +const backfillConcurrency = 4 + +type BackfillTool struct { + store *store.DB + provider ai.Provider + sessions *session.ActiveProjects + logger *slog.Logger +} + +type BackfillInput struct { + Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the backfill"` + Limit int `json:"limit,omitempty" 
jsonschema:"maximum number of thoughts to process in one call; defaults to 100"` + IncludeArchived bool `json:"include_archived,omitempty" jsonschema:"whether to include archived thoughts; defaults to false"` + OlderThanDays int `json:"older_than_days,omitempty" jsonschema:"only backfill thoughts older than N days; 0 means no restriction"` + DryRun bool `json:"dry_run,omitempty" jsonschema:"report counts and sample ids without generating embeddings"` +} + +type BackfillFailure struct { + ID string `json:"id"` + Error string `json:"error"` +} + +type BackfillOutput struct { + Model string `json:"model"` + Scanned int `json:"scanned"` + Embedded int `json:"embedded"` + Skipped int `json:"skipped"` + Failed int `json:"failed"` + DryRun bool `json:"dry_run"` + Failures []BackfillFailure `json:"failures,omitempty"` +} + +func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool { + return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger} +} + +func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in BackfillInput) (*mcp.CallToolResult, BackfillOutput, error) { + limit := in.Limit + if limit <= 0 { + limit = 100 + } + + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) + if err != nil { + return nil, BackfillOutput{}, err + } + + var projectID *uuid.UUID + if project != nil { + projectID = &project.ID + } + + model := t.provider.EmbeddingModel() + + thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays) + if err != nil { + return nil, BackfillOutput{}, err + } + + out := BackfillOutput{ + Model: model, + Scanned: len(thoughts), + DryRun: in.DryRun, + } + + if in.DryRun || len(thoughts) == 0 { + return nil, out, nil + } + + start := time.Now() + sem := semaphore.NewWeighted(backfillConcurrency) + var mu sync.Mutex + var wg sync.WaitGroup + + for _, thought 
:= range thoughts { + if ctx.Err() != nil { + break + } + if err := sem.Acquire(ctx, 1); err != nil { + break + } + wg.Add(1) + go func(id uuid.UUID, content string) { + defer wg.Done() + defer sem.Release(1) + + vec, embedErr := t.provider.Embed(ctx, content) + if embedErr != nil { + mu.Lock() + out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()}) + mu.Unlock() + t.logger.Warn("backfill embed failed", slog.String("thought_id", id.String()), slog.String("error", embedErr.Error())) + return + } + + if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil { + mu.Lock() + out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()}) + mu.Unlock() + t.logger.Warn("backfill upsert failed", slog.String("thought_id", id.String()), slog.String("error", upsertErr.Error())) + return + } + + mu.Lock() + out.Embedded++ + mu.Unlock() + }(thought.ID, thought.Content) + } + + wg.Wait() + + out.Failed = len(out.Failures) + out.Skipped = out.Scanned - out.Embedded - out.Failed + + t.logger.Info("backfill completed", + slog.String("model", model), + slog.Int("scanned", out.Scanned), + slog.Int("embedded", out.Embedded), + slog.Int("failed", out.Failed), + slog.Duration("duration", time.Since(start)), + ) + + return nil, out, nil +} diff --git a/internal/tools/context.go b/internal/tools/context.go index 69dde22..e65168e 100644 --- a/internal/tools/context.go +++ b/internal/tools/context.go @@ -72,11 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P query := strings.TrimSpace(in.Query) if query != "" { - embedding, err := t.provider.Embed(ctx, query) - if err != nil { - return nil, ProjectContextOutput{}, err - } - semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, &project.ID, nil) + semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, 
t.search.DefaultThreshold, &project.ID, nil) if err != nil { return nil, ProjectContextOutput{}, err } diff --git a/internal/tools/links.go b/internal/tools/links.go index 41c27e9..2430016 100644 --- a/internal/tools/links.go +++ b/internal/tools/links.go @@ -117,11 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela } if includeSemantic { - embedding, err := t.provider.Embed(ctx, thought.Content) - if err != nil { - return nil, RelatedOutput{}, err - } - semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID) + semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID) if err != nil { return nil, RelatedOutput{}, err } diff --git a/internal/tools/recall.go b/internal/tools/recall.go index d420caf..eab6015 100644 --- a/internal/tools/recall.go +++ b/internal/tools/recall.go @@ -48,17 +48,13 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re } limit := normalizeLimit(in.Limit, t.search) - embedding, err := t.provider.Embed(ctx, query) - if err != nil { - return nil, RecallOutput{}, err - } var projectID *uuid.UUID if project != nil { projectID = &project.ID } - semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil) + semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil) if err != nil { return nil, RecallOutput{}, err } diff --git a/internal/tools/retrieval.go b/internal/tools/retrieval.go new file mode 100644 index 0000000..f937e74 --- /dev/null +++ b/internal/tools/retrieval.go @@ -0,0 +1,41 @@ +package tools + +import ( + "context" + + "github.com/google/uuid" + + "git.warky.dev/wdevs/amcs/internal/ai" + 
"git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +// semanticSearch runs vector similarity search if embeddings exist for the active model +// in the given scope, otherwise falls back to Postgres full-text search. +func semanticSearch( + ctx context.Context, + db *store.DB, + provider ai.Provider, + search config.SearchConfig, + query string, + limit int, + threshold float64, + projectID *uuid.UUID, + excludeID *uuid.UUID, +) ([]thoughttypes.SearchResult, error) { + hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID) + if err != nil { + return nil, err + } + + if hasEmbeddings { + embedding, err := provider.Embed(ctx, query) + if err != nil { + return nil, err + } + return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID) + } + + return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID) +} diff --git a/internal/tools/search.go b/internal/tools/search.go index 4ad0097..ca6684d 100644 --- a/internal/tools/search.go +++ b/internal/tools/search.go @@ -4,6 +4,7 @@ import ( "context" "strings" + "github.com/google/uuid" "github.com/modelcontextprotocol/go-sdk/mcp" "git.warky.dev/wdevs/amcs/internal/ai" @@ -44,24 +45,18 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se limit := normalizeLimit(in.Limit, t.search) threshold := normalizeThreshold(in.Threshold, t.search.DefaultThreshold) - embedding, err := t.provider.Embed(ctx, query) - if err != nil { - return nil, SearchOutput{}, err - } - project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) if err != nil { return nil, SearchOutput{}, err } - model := t.provider.EmbeddingModel() - var results []thoughttypes.SearchResult + var projectID *uuid.UUID if project != nil { - results, err = t.store.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, 
&project.ID, nil) + projectID = &project.ID _ = t.store.TouchProject(ctx, project.ID) - } else { - results, err = t.store.SearchThoughts(ctx, embedding, model, threshold, limit, map[string]any{}) } + + results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil) if err != nil { return nil, SearchOutput{}, err } diff --git a/internal/tools/summarize.go b/internal/tools/summarize.go index 0417f5b..6c60939 100644 --- a/internal/tools/summarize.go +++ b/internal/tools/summarize.go @@ -48,15 +48,11 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in count := 0 if query != "" { - embedding, err := t.provider.Embed(ctx, query) - if err != nil { - return nil, SummarizeOutput{}, err - } var projectID *uuid.UUID if project != nil { projectID = &project.ID } - results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil) + results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil) if err != nil { return nil, SummarizeOutput{}, err } diff --git a/migrations/010_fulltext_index.sql b/migrations/010_fulltext_index.sql new file mode 100644 index 0000000..48ebb70 --- /dev/null +++ b/migrations/010_fulltext_index.sql @@ -0,0 +1,3 @@ +-- Full-text search index on thought content for semantic fallback when no embeddings exist. +create index if not exists thoughts_content_fts_idx + on thoughts using gin(to_tsvector('simple', content));